Fix Sparsity Logs on FSDP Model Save #2203

Merged
merged 5 commits into main on Apr 10, 2024

Conversation


Satrat commented Mar 28, 2024

Model sparsity was not being logged correctly on FSDP save. This was because when we save an FSDP model, we gather its full state dict onto the CPU in the main process and call save_pretrained on the wrapped model:

    # gather the full (unsharded) state dict onto CPU in the main process
    with FullyShardedDataParallel.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, full_state_dict_config
    ):
        state_dict = accelerator.get_state_dict(model, unwrap=False)

    if accelerator.is_main_process:
        # save_pretrained is called on the wrapped model; only the separately
        # gathered state_dict holds the full set of parameters
        accelerator.unwrap_model(model).save_pretrained(
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=state_dict,
        )

This results in the wrapped model being passed to the global sparsity function with incomplete information; we need the full state dict as well to properly calculate the sparsity. I also added some log cleanup so we aren't printing as much sparsity information between train stages.
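
For reference, the global sparsity being reported is just the fraction of zero-valued elements across the gathered weights. A minimal sketch of that calculation over a state dict in plain PyTorch (calculate_sparsity is a hypothetical helper for illustration, not the SparseML implementation):

import torch

def calculate_sparsity(state_dict):
    # fraction of zero-valued elements across all floating-point tensors
    total_elements = 0
    total_zeros = 0
    for tensor in state_dict.values():
        if not torch.is_tensor(tensor) or not tensor.is_floating_point():
            continue
        total_elements += tensor.numel()
        total_zeros += (tensor == 0).sum().item()
    return total_zeros / total_elements if total_elements else 0.0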

Example

Launch with FSDP: accelerate launch --config_file integrations/huggingface-transformers/finetuning/example_fsdp_config.yaml test.py

from sparseml.transformers.finetune.text_generation import train

model_name = "neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4"
output_dir = "llama-1.1b_compressed_test_out"
recipe_stub = """
test_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      targets: ['re:.*q_proj.weight', 're:.*k_proj.weight', 're:.*v_proj.weight', 're:.*o_proj.weight',
        're:.*gate_proj.weight', 're:.*up_proj.weight', 're:.*down_proj.weight']
      start: 0
"""

train(
  save_compressed=True,
  model=model_name,
  recipe=recipe_stub,
  dataset="open_platypus",
  max_steps=5,
  output_dir=output_dir,
  overwrite_output_dir=True,
  splits="train",
  logging_steps=25,
  precision="bfloat16",
  gradient_checkpointing=True,
  bf16=True,
)

Output:
All the parameters now get picked up; previously only the 23 wrapped transformer layers were showing up as parameters.

2024-03-28 21:18:09 sparseml.transformers.compression.utils.compress_save INFO     Inferring a sparsity configuration requires a global sparsity calculation. This can be costly for large models. To skip the calculation of compression statistics set skip_compression_stats=True
Calculating model sparsity: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:04<00:00, 48.54it/s]
Compressing model: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:16<00:00, 12.24it/s]
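
The 23 parameters mentioned above come from the FSDP wrapping itself: with the default flattening behavior, each wrapped unit exposes a single flattened parameter, while the gathered full state dict has one entry per original weight tensor. A rough way to see the difference (a sketch only; model and state_dict are assumed to be the objects from the save snippet above):

def compare_parameter_views(model, state_dict):
    # parameters visible on the wrapped module: roughly one flattened
    # parameter per FSDP unit, not the individual per-layer weights
    wrapped_count = sum(1 for _ in model.named_parameters())
    # entries in the gathered full state dict: one per original tensor
    full_count = len(state_dict)
    print(f"wrapped module: {wrapped_count} params, state dict: {full_count} tensors")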

bfineran merged commit 2de5c92 into main Apr 10, 2024
11 of 17 checks passed
bfineran deleted the fsdp_sparsity_log_fixes branch April 10, 2024 14:41