remove refined recompute deep copy #9617

Open
Wants to merge 17 commits into base: develop
41 changes: 41 additions & 0 deletions docs/trainer.md
@@ -601,6 +601,47 @@ Trainer is a simple but fully featured Paddle training and evaluation module, and

Recompute the forward pass to calculate gradients. Used for saving memory (default: False)

--refined_recompute
Refined recompute setting, used to strike the best balance between GPU memory usage and computational speed.
It gives fine-grained control over the recomputation process so that resource usage can be tuned. An example configuration:
`"attention_column_ln:-1, attention_row_ln:-1, flash_attn:-1, mlp_column_ln:5, mlp_row_ln:-1"`

The supported parameters are:
`attention_column_ln`
`attention_row_ln`
`mlp_column_ln`
`mlp_row_ln`
`flash_attn`

The number following each parameter, `skip_num`, determines how many times recomputation is bypassed for that operation:
a `skip_num` of `-1` skips recomputation in all stages, maximizing memory usage;
a `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.

`skip_num` may also be set to any value in `[1, ..., num_layers]`; if it exceeds `num_layers`, it behaves the same as `-1`.
If a parameter is omitted from the configuration, it defaults to `xxx:0`.

(Type: `str`, optional, default: "")
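
A minimal sketch of how such a configuration string could be turned into per-operation skip counts. This is an illustrative helper, not the actual PaddleNLP parser; the function name and the clamping rule (treating `-1`, or any value above `num_layers`, as "all layers") are assumptions derived from the description above.

```python
# Illustrative parser for the refined_recompute string described above.
# Hypothetical helper; the real PaddleNLP implementation may differ.
SUPPORTED_OPS = {
    "attention_column_ln",
    "attention_row_ln",
    "mlp_column_ln",
    "mlp_row_ln",
    "flash_attn",
}


def parse_refined_recompute(spec: str, num_layers: int) -> dict:
    """Return {op_name: skip_num}, with skip_num resolved to a value in [0, num_layers]."""
    skip = {op: 0 for op in SUPPORTED_OPS}  # omitted ops default to xxx:0
    if not spec.strip():
        return skip
    for item in spec.split(","):
        op, _, value = item.strip().partition(":")
        if op not in SUPPORTED_OPS:
            raise ValueError(f"unsupported refined_recompute op: {op}")
        skip_num = int(value)
        if skip_num < 0 or skip_num > num_layers:
            skip_num = num_layers  # -1 (or anything above num_layers) covers every layer
        skip[op] = skip_num
    return skip


print(parse_refined_recompute("attention_column_ln:-1, flash_attn:-1, mlp_column_ln:5", num_layers=32))
```

With the counts resolved, the natural reading is that the first `skip_num` layers bypass recomputation for that operation while the remaining layers recompute as usual (again, an assumption about how the counts map onto layers).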

--minimum_eval_times
Minimum number of evaluations. If the current eval_steps setting results in fewer evaluations than minimum_eval_times,
this option overrides the eval_steps parameter.
1 change: 1 addition & 0 deletions llm/docs/finetune.md
@@ -279,6 +279,7 @@ python ./predict/reft_predictor.py \
- `do_train`: Whether to enable training. Defaults to False.
- `do_eval`: Whether to enable evaluation. Defaults to False.
- `recompute`: Recomputation; currently only the full strategy is supported. Enabling it lowers GPU memory usage so that a larger batch size can be used. Defaults to False.
- `refined_recompute`: Fine-grained recomputation that balances GPU memory and performance by controlling exactly which parts are recomputed. Currently only the `llama` and `qwen` model families are supported; see the [TrainingArguments documentation](https://paddlenlp.readthedocs.io/zh/latest/trainer.html) for details (an illustrative configuration sketch follows this list).
- `tensor_parallel_degree`: The number of shards each transformer layer is split into. This approach has a high communication overhead; tensor_parallel_degree<=8 is recommended, keeping communication within a single machine whenever possible. Defaults to -1, which disables tensor parallelism.
- `pipeline_parallel_degree`: The number of pipeline stages. (With a value of 4 and a 12-layer model, each pp stage holds 3 layers.) Defaults to -1, which disables pipeline parallelism.
- `sharding_parallel_degree`: The data-parallel degree for grouped parameter sharding. Defaults to 1, which disables sharded data parallelism.
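
A sketch of how these options might appear together in a fine-tuning configuration; the keys follow the list above, while every value (and the idea of dumping it to JSON) is a placeholder rather than a file taken from the repository.

```python
# Illustrative only: a fragment of a fine-tuning configuration expressed as a Python dict.
import json

finetune_config = {
    "do_train": True,
    "do_eval": True,
    "recompute": True,
    "refined_recompute": "attention_column_ln:-1,flash_attn:-1,mlp_column_ln:5",
    "tensor_parallel_degree": 2,
    "pipeline_parallel_degree": 1,
    "sharding_parallel_degree": 4,
}

print(json.dumps(finetune_config, indent=2))  # e.g. to write the JSON config consumed by the finetune script
```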
13 changes: 13 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -740,6 +740,19 @@ class TrainingArguments:
"Only support for networks with transformer blocks."
},
)
refined_recompute: str = field(
default="",
metadata={
"help": "The refined recompute parameter is designed to optimize the balance between GPU memory usage and computational speed.\n"
"An example configuration could be: `attention_column_ln:-1,attention_row_ln:-1,flash_attn:-1,mlp_column_ln:5,mlp_row_ln:-1`.\n"
"The supported parameters for refining recompute are `attention_column_ln`, `attention_row_ln`, `flash_attn`, `mlp_column_ln`, and `mlp_row_ln`.\n"
"The associated number, `skip_num`, determines how many times to bypass recomputation for the specified operation.\n"
"A `skip_num` of `-1` indicates no recomputation across all stages, maximizing memory usage;\n"
"A `skip_num` of `0` enforces recomputation at every stage, minimizing memory usage.\n"
"You can also set `skip_num` to a value within the range [1, ..., num_layers]. If `skip_num` exceeds `num_layers`, it will behave as if set to `-1`.\n"
"If a parameter is omitted, it defaults to `xxx:0`."
},
)
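
The example below is an editorial illustration of how the new flag might be set from Python, not part of the diff; the `output_dir` value and the specific flag combination are placeholders.

```python
# Hypothetical usage sketch of the refined_recompute field; all values are illustrative.
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",  # assumed placeholder path
    recompute=True,              # refined recompute is only meaningful when recompute is enabled
    refined_recompute="attention_column_ln:-1,flash_attn:-1,mlp_column_ln:5",
)
# Per the NOTE comments in the modeling files below, the skips only take effect
# when recompute_use_reentrant is False.
print(args.refined_recompute)
```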
scale_loss: float = field(default=2**15, metadata={"help": "The value of initial scale_loss for fp16."})

13 changes: 5 additions & 8 deletions paddlenlp/transformers/configuration_utils.py
@@ -268,14 +268,6 @@ class LlmMetaConfig:
"Recompute granularity, Choose among ['full', 'core_attn', 'full_attn']",
),
("recompute_use_reentrant", bool, False, "recompute_use_reentrant"),
# refined_recompute attributes
(
"refined_recompute",
str,
"",
"refined_recompute, Choose from 'mlp_row_ln', 'mlp_column_ln', 'attention_row_ln', 'attention_column_ln', 'flash_attn']",
),
("skip_recompute_ops", Optional[Dict[str, int]], None, "skip_recompute_ops"),
Collaborator comment: `skip_recompute_ops` has been removed from this list; where is it added now?
]

@classmethod
@@ -569,6 +561,11 @@ def __init__(self, **kwargs):
self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
self.fuse_attention_ffn = kwargs.pop("fuse_attention_ffn", False)

# for refined_recompute
self.refined_recompute = kwargs.pop("refined_recompute", {})
self.skip_recompute_ops = kwargs.pop("skip_recompute_ops", {})
self.register_unsavable_keys(["refined_recompute", "skip_recompute_ops"])
Author reply: They are placed here now, as attributes of the config base class; both default to empty dicts.

Collaborator comment: Please add them to the unsavable config keys, otherwise they may affect config loading for downstream tasks such as inference.

if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], Dict):
kwargs["quantization_config"] = QuantizationConfig.from_dict(kwargs["quantization_config"])
self.quantization_config = kwargs.pop("quantization_config", QuantizationConfig())
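
To make the intent of the two new attributes concrete, here is a small sketch; it assumes that `LlamaConfig` inherits this base `__init__` and that `register_unsavable_keys` is what keeps the two fields out of the serialized config, as the review discussion above requests.

```python
# Sketch only: shows the defaults added to the config base class above.
from paddlenlp.transformers import LlamaConfig  # any config class inheriting the base __init__

cfg = LlamaConfig()
print(cfg.refined_recompute)   # {} by default
print(cfg.skip_recompute_ops)  # {} by default

# Because both keys are registered as unsavable, the expectation (per the review
# thread) is that they never show up in the config file written by
# cfg.save_pretrained(...), so downstream inference loading is unaffected.
```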
34 changes: 18 additions & 16 deletions paddlenlp/transformers/llama/modeling.py
@@ -605,7 +605,7 @@


class LlamaMLP(nn.Layer):
def __init__(self, config):
def __init__(self, config, layer_idx: int = 0):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
@@ -618,19 +618,19 @@

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False):

ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False):

RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False):

ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False):

RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
@@ -682,7 +682,7 @@
class LlamaAttention(nn.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, layer_idx: int = 0):
super().__init__()

self.config = config
@@ -746,18 +746,18 @@

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False):

ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False):

RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False):

ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False):

RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
@@ -862,7 +862,7 @@
if (
config.recompute
and not config.recompute_use_reentrant
and config.skip_recompute_ops.get("flash_attn", False)
and config.skip_recompute_ops[layer_idx].get("flash_attn", False)
):
self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True)

@@ -1168,12 +1168,12 @@


class LlamaDecoderLayer(nn.Layer):
def __init__(self, config, layerwise_recompute: bool = False):
def __init__(self, config, layerwise_recompute: bool = False, layer_idx: int = 0):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config, layerwise_recompute)
self.mlp = LlamaMLP(config)
self.self_attn = LlamaAttention(config, layerwise_recompute, layer_idx=layer_idx)
self.mlp = LlamaMLP(config, layer_idx=layer_idx)
self.input_layernorm = LlamaRMSNorm(config)
self.post_attention_layernorm = LlamaRMSNorm(config)
self.sequence_parallel = config.sequence_parallel
@@ -1518,9 +1518,11 @@
self.layers = nn.LayerList(
[
LlamaDecoderLayer(
create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers
config=create_skip_config_for_refined_recompute(layer_idx, config),
layerwise_recompute=layer_idx not in self.no_recompute_layers,
layer_idx=layer_idx,
)
for i in range(config.num_hidden_layers)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config)
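
The loop above now builds one skip config per layer, and the attention and MLP blocks index `config.skip_recompute_ops[layer_idx]`. Below is a minimal sketch of what such a per-layer table could look like; the helper mirrors the role of `create_skip_config_for_refined_recompute`, but its body, in particular the choice to skip recomputation in the first `skip_num` layers, is an assumption rather than the actual PaddleNLP implementation.

```python
# Illustrative sketch: expand {op: skip_num} into a per-layer lookup so that
# skip_recompute_ops[layer_idx][op] answers "should this layer skip recompute for op?".
def build_skip_recompute_ops(refined_recompute: dict, num_layers: int) -> dict:
    table = {}
    for layer_idx in range(num_layers):
        table[layer_idx] = {
            # Assumption: the first skip_num layers skip recomputation for this op.
            op: layer_idx < skip_num
            for op, skip_num in refined_recompute.items()
        }
    return table


ops = build_skip_recompute_ops({"mlp_column_ln": 5, "flash_attn": 32}, num_layers=32)
print(ops[0]["mlp_column_ln"], ops[10]["mlp_column_ln"])  # True False
print(ops[31]["flash_attn"])                              # True: skip_num == num_layers covers all layers
```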
35 changes: 18 additions & 17 deletions paddlenlp/transformers/qwen/modeling.py
@@ -138,9 +138,9 @@


class QWenAttention(nn.Layer):
def __init__(self, config):
def __init__(self, config, layer_idx: int = 0):
Collaborator comment: Is adding layer_idx here strictly necessary?

Author reply: Yes, it is required; otherwise there is no way to know whether a given layer should enable refined recompute (rr).

super().__init__()

self.layer_idx = layer_idx
self.config = config
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
@@ -166,18 +166,18 @@

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False):

ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False):

RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("attention_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_column_ln", False):

ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("attention_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("attention_row_ln", False):

RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
@@ -252,7 +252,7 @@
skip_recompute = (
self.config.recompute
and not self.config.recompute_use_reentrant
and self.config.skip_recompute_ops.get("flash_attn", False)
and self.config.skip_recompute_ops[self.layer_idx].get("flash_attn", False)
)
attn_output = no_recompute(
F.scaled_dot_product_attention,
@@ -409,7 +409,7 @@


class QWenMLP(nn.Layer):
def __init__(self, config):
def __init__(self, config, layer_idx: int = 0):
super().__init__()
ff_dim_in = config.intermediate_size // 2
self.fuse_attention_ffn = config.fuse_attention_ffn
@@ -420,18 +420,18 @@

# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False):

ColumnParallelLinear = RRColumnSequenceParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False):

RowParallelLinear = RRRowSequenceParallelLinear
else:
ColumnParallelLinear = linear_utils.ColumnParallelLinear
RowParallelLinear = linear_utils.RowParallelLinear
# NOTE: refined_recompute is only supported when `recompute_use_reentrant=False`
if config.recompute and not config.recompute_use_reentrant:
if config.skip_recompute_ops.get("mlp_column_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_column_ln", False):

ColumnParallelLinear = RRColumnParallelLinear
if config.skip_recompute_ops.get("mlp_row_ln", False):
if config.skip_recompute_ops[layer_idx].get("mlp_row_ln", False):

RowParallelLinear = RRRowParallelLinear

if config.tensor_parallel_degree > 1:
@@ -484,13 +484,13 @@


class QWenBlock(nn.Layer):
def __init__(self, config):
def __init__(self, config, layer_idx: int = 0):
Contributor comment: The external callers of QWenBlock do not pass layer_idx; please double-check.

super().__init__()
self.sequence_parallel = config.sequence_parallel
self.ln_1 = QWenRMSNorm(config)
self.attn = QWenAttention(config)
self.attn = QWenAttention(config, layer_idx=layer_idx)
self.ln_2 = QWenRMSNorm(config)
self.mlp = QWenMLP(config)
self.mlp = QWenMLP(config, layer_idx=layer_idx)

def forward(
self,
@@ -726,9 +726,10 @@
self.h = nn.LayerList(
[
QWenBlock(
create_skip_config_for_refined_recompute(i, config),
config=create_skip_config_for_refined_recompute(layer_idx, config),
layer_idx=layer_idx,
)
for i in range(config.num_hidden_layers)
for layer_idx in range(config.num_hidden_layers)
]
)
self.ln_f = QWenRMSNorm(config)
5 changes: 4 additions & 1 deletion paddlenlp/transformers/qwen/modeling_pp.py
@@ -173,7 +173,10 @@ def get_hcg():
self.add_sequential_layer(LayerDesc(QWenEmbeddingPipe, config=config), "qwen")
for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(QWenBlockPipe, config=create_skip_config_for_refined_recompute(i, config)),
LayerDesc(
QWenBlockPipe,
config=create_skip_config_for_refined_recompute(i, config),
),
f"qwen.h.{i}",
)
self.add_sequential_layer(LayerDesc(QWenRMSNormPipe, config=config), "qwen.ln_f")