fix recompute index bug (#69342)
chen2016013 authored Nov 15, 2024
1 parent 4043b8c commit 3a8618e
Showing 2 changed files with 27 additions and 12 deletions.
9 changes: 9 additions & 0 deletions python/paddle/base/core.py
@@ -571,6 +571,15 @@ def _enable_dist_prim_all():
 
 def _enable_auto_recompute():
     flag = os.getenv("FLAGS_enable_auto_recompute")
+
+    # NOTE(chenxi67): open recompute when cinn is enabled
+    from paddle.base.framework import in_cinn_mode
+
+    if in_cinn_mode():
+        if flag and flag.lower() in ("0", "false"):
+            return False
+        else:
+            return True
     if flag and flag.lower() in ("1", "true"):
         return True
     else:
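For reference, the decision logic this hunk adds can be summarized by the standalone sketch below (the helper name auto_recompute_enabled is hypothetical, not Paddle's API): with CINN enabled, FLAGS_enable_auto_recompute is treated as default-on with an explicit "0"/"false" opt-out; without CINN it stays default-off with an explicit "1"/"true" opt-in.

    def auto_recompute_enabled(flag, cinn):
        # Mirrors the precedence in _enable_auto_recompute above.
        if cinn:
            # Explicit opt-out wins; unset or any other value keeps it on.
            return not (flag and flag.lower() in ("0", "false"))
        # Outside CINN mode, only an explicit opt-in enables it.
        return bool(flag and flag.lower() in ("1", "true"))

    assert auto_recompute_enabled(None, cinn=True) is True      # default-on under CINN
    assert auto_recompute_enabled("false", cinn=True) is False  # explicit opt-out
    assert auto_recompute_enabled(None, cinn=False) is False    # default-off otherwise
    assert auto_recompute_enabled("1", cinn=False) is True      # explicit opt-in
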
30 changes: 18 additions & 12 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -926,9 +926,10 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
         inputs = train_runnable_program.x_values
         params = train_runnable_program.param_values
         combined_inputs = list(itertools.chain(inputs, params))
-        forward_end_idx = len(program.global_block().ops)
+        forward_prog_len = len(program.global_block().ops)
+        forward_end_idx = forward_prog_len - 1
         forward_end_op = None
-        if forward_end_idx > 0:
+        if forward_prog_len > 0:
             forward_end_op = program.global_block().ops[-1]
         grad_info_map = [None] * len(combined_inputs)
         with backend_guard(self._backend):
@@ -958,7 +959,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
                 "grad_input_",
             )
             op_between_forward_and_backward = (
-                len(program.global_block().ops) - forward_end_idx
+                len(program.global_block().ops) - forward_prog_len
             )
 
             # call grad to get backward ops.
@@ -985,7 +986,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
             if forward_end_op is not None:
                 for idx, op in enumerate(program.global_block().ops):
                     if op == forward_end_op:
-                        forward_end_idx = idx + 1
+                        forward_end_idx = idx
                         break
 
         for hooker in self._hookers:
@@ -1019,11 +1020,12 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
         output_grads_to_append = list(
             filter(lambda x: not is_fake_value(x), x_grad_value + p_grad_value)
         )
-        backward_end_op_index = len(program.global_block().ops)
+        backward_prog_len = len(program.global_block().ops)
+        backward_end_op_index = backward_prog_len - 1
         paddle.base.libpaddle.pir.append_shadow_outputs(
             program,
             output_grads_to_append,
-            backward_end_op_index,
+            backward_prog_len,
             "grad_output_",
         )
 
@@ -1036,7 +1038,11 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
             [inputs, params, targets, x_grad_value, p_grad_value, o_grad_value]
         )
         forward_index_pass = IndicesPreservePass(
-            [forward_end_idx, backward_start_op_index, backward_end_op_index],
+            [
+                forward_end_idx + 1,
+                backward_start_op_index + 1,
+                backward_end_op_index + 1,
+            ],
             fused_bn_add_act_pass,
         )
         program = forward_index_pass(program)
@@ -1049,17 +1055,17 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram):
             o_grad_value,
         ) = fused_bn_add_act_pass.values
         (
-            forward_end_idx,
-            backward_start_op_index,
-            backward_end_op_index,
+            forward_end_range,
+            backward_start_range,
+            backward_end_op_range,
         ) = forward_index_pass.new_indices
 
         return RunnableProgram(
             program,
             (inputs, params, targets),
             (x_grad_value, p_grad_value, o_grad_value),
-            (0, forward_end_idx),
-            (backward_start_op_index, backward_end_op_index),
+            (0, forward_end_range),
+            (backward_start_range, backward_end_op_range),
         )
 
     def _prepare_attributes(self, in_sot_mode=False):
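A note on the index fix: the hunks above switch forward_end_idx and backward_end_op_index from exclusive end positions (len(ops)) to inclusive last-op indices (len(ops) - 1), then add the + 1 back when building the ranges handed to IndicesPreservePass and RunnableProgram. One way to read the change, sketched below with a toy op list (hypothetical op names, not Paddle's pass machinery): an exclusive end index points past the last op, so it goes stale once a pass inserts ops, while the position of a remembered op can always be re-found by identity.

    ops = ["matmul", "add", "relu"]      # forward section of a toy program
    forward_end_op = ops[-1]             # remember the last forward op itself
    stale_exclusive_end = len(ops)       # 3: one past the last op

    # A pass inserts an op ahead of the forward section, shifting indices:
    ops = ["fill_constant"] + ops

    # Relocating by op identity still finds the true boundary...
    forward_end_idx = ops.index(forward_end_op)      # 3, inclusive
    assert ops[: forward_end_idx + 1] == ["fill_constant", "matmul", "add", "relu"]

    # ...while the stale exclusive index now cuts the forward range short:
    assert ops[:stale_exclusive_end] == ["fill_constant", "matmul", "add"]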