From 3a8618e0e135434fbe11dd27d99818696def24c5 Mon Sep 17 00:00:00 2001 From: chen2016013 <111894720+chen2016013@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:54:25 +0800 Subject: [PATCH] fix recompute index bug (#69342) --- python/paddle/base/core.py | 9 ++++++ .../jit/dy2static/pir_partial_program.py | 30 +++++++++++-------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py index 7d800d31294c8..33c31c549dd39 100644 --- a/python/paddle/base/core.py +++ b/python/paddle/base/core.py @@ -571,6 +571,15 @@ def _enable_dist_prim_all(): def _enable_auto_recompute(): flag = os.getenv("FLAGS_enable_auto_recompute") + + # NOTE(chenxi67): open recompute when cinn is enabled + from paddle.base.framework import in_cinn_mode + + if in_cinn_mode(): + if flag and flag.lower() in ("0", "false"): + return False + else: + return True if flag and flag.lower() in ("1", "true"): return True else: diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index ced9dbcb1cb71..54d15cf33b91e 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -926,9 +926,10 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram): inputs = train_runnable_program.x_values params = train_runnable_program.param_values combined_inputs = list(itertools.chain(inputs, params)) - forward_end_idx = len(program.global_block().ops) + forward_prog_len = len(program.global_block().ops) + forward_end_idx = forward_prog_len - 1 forward_end_op = None - if forward_end_idx > 0: + if forward_prog_len > 0: forward_end_op = program.global_block().ops[-1] grad_info_map = [None] * len(combined_inputs) with backend_guard(self._backend): @@ -958,7 +959,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram): "grad_input_", ) op_between_forward_and_backward = ( - len(program.global_block().ops) - forward_end_idx + len(program.global_block().ops) - forward_prog_len ) # call grad to get backward ops. @@ -985,7 +986,7 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram): if forward_end_op is not None: for idx, op in enumerate(program.global_block().ops): if op == forward_end_op: - forward_end_idx = idx + 1 + forward_end_idx = idx break for hooker in self._hookers: @@ -1019,11 +1020,12 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram): output_grads_to_append = list( filter(lambda x: not is_fake_value(x), x_grad_value + p_grad_value) ) - backward_end_op_index = len(program.global_block().ops) + backward_prog_len = len(program.global_block().ops) + backward_end_op_index = backward_prog_len - 1 paddle.base.libpaddle.pir.append_shadow_outputs( program, output_grads_to_append, - backward_end_op_index, + backward_prog_len, "grad_output_", ) @@ -1036,7 +1038,11 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram): [inputs, params, targets, x_grad_value, p_grad_value, o_grad_value] ) forward_index_pass = IndicesPreservePass( - [forward_end_idx, backward_start_op_index, backward_end_op_index], + [ + forward_end_idx + 1, + backward_start_op_index + 1, + backward_end_op_index + 1, + ], fused_bn_add_act_pass, ) program = forward_index_pass(program) @@ -1049,17 +1055,17 @@ def _append_backward_desc(self, train_runnable_program: RunnableProgram): o_grad_value, ) = fused_bn_add_act_pass.values ( - forward_end_idx, - backward_start_op_index, - backward_end_op_index, + forward_end_range, + backward_start_range, + backward_end_op_range, ) = forward_index_pass.new_indices return RunnableProgram( program, (inputs, params, targets), (x_grad_value, p_grad_value, o_grad_value), - (0, forward_end_idx), - (backward_start_op_index, backward_end_op_index), + (0, forward_end_range), + (backward_start_range, backward_end_op_range), ) def _prepare_attributes(self, in_sot_mode=False):