diff --git a/.github/workflows/regression_test.yaml b/.github/workflows/regression_test.yaml index 80ee645f47..4e27f79ba7 100644 --- a/.github/workflows/regression_test.yaml +++ b/.github/workflows/regression_test.yaml @@ -26,6 +26,8 @@ jobs: python-version: ['3.11'] torch-version: ["stable", "nightly"] fail-fast: false + env: + PYTORCH_CUDA_ALLOC_CONF: expandable_segments:True steps: - name: Check out repo uses: actions/checkout@v3 diff --git a/README.md b/README.md index 31fc280e04..2d885a3779 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ torchtune currently supports the following models. | [Code-Llama2](https://ai.meta.com/blog/code-llama-large-language-model-coding/) | 7B, 13B, 70B [[models](torchtune/models/code_llama2/_model_builders.py), [configs](recipes/configs/code_llama2/)] | | [Mistral](https://huggingface.co/mistralai) | 7B [[models](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] | | [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B, 7B [[models](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] | +| [Gemma2](https://huggingface.co/docs/transformers/main/en/model_doc/gemma2) | 2B, 9B, 27B [[models](torchtune/models/gemma2/_model_builders.py), [configs](recipes/configs/gemma2/)] | | [Microsoft Phi3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) | Mini [[models](torchtune/models/phi3/), [configs](recipes/configs/phi3/)] | [Qwen2](https://qwenlm.github.io/blog/qwen2/) | 0.5B, 1.5B, 7B [[models](torchtune/models/qwen2/), [configs](recipes/configs/qwen2/)] diff --git a/docs/source/api_ref_datasets.rst b/docs/source/api_ref_datasets.rst index 98d328ee54..74b047953e 100644 --- a/docs/source/api_ref_datasets.rst +++ b/docs/source/api_ref_datasets.rst @@ -37,6 +37,7 @@ Image + Text datasets multimodal.llava_instruct_dataset multimodal.the_cauldron_dataset + multimodal.vqa_dataset .. _dataset_builders: diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst index fe94104484..b2d74022b1 100644 --- a/docs/source/api_ref_models.rst +++ b/docs/source/api_ref_models.rst @@ -208,6 +208,47 @@ To download the CodeLlama-7B model: code_llama2.lora_code_llama2_70b code_llama2.qlora_code_llama2_70b +qwen-2.5 +-------- + +Models of size 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B from the `Qwen2.5 family `_. + +To download the Qwen2.5 1.5B model, for example: + +.. code-block:: bash + + tune download Qwen/Qwen2.5-1.5B-Instruct --output-dir /tmp/Qwen2_5-1_5B-Instruct --ignore-patterns None + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + qwen2_5.qwen2_5_0_5b + qwen2_5.lora_qwen2_5_0_5b + qwen2_5.qwen2_5_1_5b_base + qwen2_5.qwen2_5_1_5b_instruct + qwen2_5.lora_qwen2_5_1_5b_base + qwen2_5.lora_qwen2_5_1_5b_instruct + qwen2_5.qwen2_5_3b + qwen2_5.lora_qwen2_5_3b + qwen2_5.qwen2_5_7b_base + qwen2_5.qwen2_5_7b_instruct + qwen2_5.lora_qwen2_5_7b_base + qwen2_5.lora_qwen2_5_7b_instruct + qwen2_5.qwen2_5_14b_base + qwen2_5.qwen2_5_14b_instruct + qwen2_5.lora_qwen2_5_14b_base + qwen2_5.lora_qwen2_5_14b_instruct + qwen2_5.qwen2_5_32b_base + qwen2_5.qwen2_5_32b_instruct + qwen2_5.lora_qwen2_5_32b_base + qwen2_5.lora_qwen2_5_32b_instruct + qwen2_5.qwen2_5_72b_base + qwen2_5.qwen2_5_72b_instruct + qwen2_5.lora_qwen2_5_72b_base + qwen2_5.lora_qwen2_5_72b_instruct + qwen2_5.qwen2_5_tokenizer + qwen-2 ------ @@ -225,12 +266,12 @@ To download the Qwen2 1.5B model, for example: qwen2.qwen2 qwen2.lora_qwen2 - qwen2.qwen2_7b qwen2.qwen2_0_5b - qwen2.qwen2_1_5b - qwen2.lora_qwen2_7b qwen2.lora_qwen2_0_5b + qwen2.qwen2_1_5b qwen2.lora_qwen2_1_5b + qwen2.qwen2_7b + qwen2.lora_qwen2_7b qwen2.qwen2_tokenizer phi-3 @@ -320,8 +361,39 @@ To download the Gemma 7B model: gemma.gemma_tokenizer +gemma2 : +-------- + +Models of size 2B, 9B, 27B from the `Gemma family `_. + +Important: You need to request access on `Hugging Face `__ to use this model. + +To download the Gemma2 2B, 9B, 27B models : + +.. code-block:: bash + + tune download google/gemma-2-b --ignore-patterns "gemma-2-b.gguf" --hf-token + + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + gemma2.gemma2 + gemma2.lora_gemma2 + gemma2.gemma2_2b + gemma2.lora_gemma2_2b + gemma2.qlora_gemma2_2b + gemma2.gemma2_9b + gemma2.lora_gemma2_9b + gemma2.qlora_gemma2_9b + gemma2.gemma2_27b + gemma2.lora_gemma2_27b + gemma2.qlora_gemma2_27b + gemma.gemma_tokenizer + clip ------ +---- Vision components to support multimodality using `CLIP encoder `_. diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index f360b4f02c..2153e4fd20 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -75,6 +75,7 @@ PEFT Components peft.AdapterModule peft.get_adapter_params peft.set_trainable_params + peft.get_adapter_state_dict peft.validate_missing_and_unexpected_for_lora peft.validate_state_dict_for_lora peft.disable_adapter diff --git a/docs/source/api_ref_training.rst b/docs/source/api_ref_training.rst index 4e7f186523..c904ddc32b 100644 --- a/docs/source/api_ref_training.rst +++ b/docs/source/api_ref_training.rst @@ -56,6 +56,7 @@ Utilities for enabling and working with distributed training. get_world_size_and_rank get_full_finetune_fsdp_wrap_policy lora_fsdp_wrap_policy + gather_cpu_state_dict .. _ac_label: diff --git a/docs/source/deep_dives/checkpointer.rst b/docs/source/deep_dives/checkpointer.rst index 024e555483..13aac698c6 100644 --- a/docs/source/deep_dives/checkpointer.rst +++ b/docs/source/deep_dives/checkpointer.rst @@ -443,7 +443,7 @@ For this section we'll use the Llama2 13B model in HF format. checkpoint_dir=checkpoint_dir, checkpoint_files=pytorch_files, output_dir=checkpoint_dir, - model_type=ModelType.LLAMA2 + model_type="LLAMA2" ) torchtune_sd = checkpointer.load_checkpoint() diff --git a/docs/source/tutorials/memory_optimizations.rst b/docs/source/tutorials/memory_optimizations.rst index 321ea30333..a0f6d16c91 100644 --- a/docs/source/tutorials/memory_optimizations.rst +++ b/docs/source/tutorials/memory_optimizations.rst @@ -14,16 +14,16 @@ To make things easy, we've summarized these components in the following table: :header: "Component", "When to use?" :widths: auto - ":ref:`glossary_precision`", "You'll usually want to leave this as its default ``bfloat16``. If you're struggling with training stability or accuracy due to precision, fp32 may help, but will significantly increase memory usage and decrease training speed." - ":ref:`glossary_act_ckpt`", "Use when you're memory constrained and need to handle larger batch sizes or longer context lengths. Be aware that it may slow down training speed." - ":ref:`glossary_act_off`", "Similar to activation checkpointing, this can be used when memory constrained, but comes at the cost of training speed due to the overhead of moving tensors between GPU VRAM and CPU. This can also be used alongside activation checkpointing." - ":ref:`glossary_grad_accm`", "Helpful when memory-constrained to simulate larger batch sizes. Often preferable to activation checkpointing for better training speed." - ":ref:`glossary_low_precision_opt`", "When you need to further reduce memory usage beyond using ``bf16`` by reducing the precision in the optimizer states. Note that lower precision optimizers may reduce training stability/accuracy." - ":ref:`glossary_opt_in_bwd`", "Helps reduce memory usage when using stateful optimizers, particularly when full-finetuning large models with high gradient memory usage. This is not compatible with ``gradient_accumulation_steps``, so training may slow down due to reduced model throughput." - ":ref:`glossary_cpu_offload`", "Offloads optimizer states and (optionally) gradients to CPU, and performs optimizer steps on CPU. This can be used to significantly reduce GPU memory usage at the cost of CPU RAM and training speed, as CPU optimizer steps can be slow and bottleneck training performance." - ":ref:`glossary_lora`", "When you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training." - ":ref:`glossary_qlora`", "When you need even more memory savings than LoRA, at the potential cost of some training speed. Useful for very large models or limited hardware." - ":ref:`glossary_dora`", "Like LoRA, DoRA can provide significant memory savings and training speed-ups. DoRA may improve performance over LoRA, particularly when using small rank updates." + ":ref:`glossary_precision`", "You'll usually want to leave this as its default ``bfloat16``. It uses 2 bytes per model parameter instead of 4 bytes when using ``float32``." + ":ref:`glossary_act_ckpt`", "Use when you're memory constrained and want to use a larger model, batch size or context length. Be aware that it will slow down training speed." + ":ref:`glossary_act_off`", "Similar to activation checkpointing, this can be used when memory constrained, but may decrease training speed. This **should** be used alongside activation checkpointing." + ":ref:`glossary_grad_accm`", "Helpful when memory-constrained to simulate larger batch sizes. Not compatible with optimizer in backward. Use it when you can already fit at least one sample without OOMing, but not enough of them." + ":ref:`glossary_low_precision_opt`", "Use when you want to reduce the size of the optimizer state. This is relevant when training large models and using optimizers with momentum, like Adam. Note that lower precision optimizers may reduce training stability/accuracy." + ":ref:`glossary_opt_in_bwd`", "Use it when you have large gradients and can fit a large enough batch size, since this is not compatible with ``gradient_accumulation_steps``." + ":ref:`glossary_cpu_offload`", "Offloads optimizer states and (optionally) gradients to CPU, and performs optimizer steps on CPU. This can be used to significantly reduce GPU memory usage at the cost of CPU RAM and training speed. Prioritize using it only if the other techniques are not enough." + ":ref:`glossary_lora`", "When you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training. This may reduce training accuracy" + ":ref:`glossary_qlora`", "When you are training a large model, since quantization will save 1.5 bytes * (# of model parameters), at the potential cost of some training speed and accuracy." + ":ref:`glossary_dora`", "a variant of LoRA that may improve model performance at the cost of slightly more memory." .. note:: @@ -83,8 +83,7 @@ and in most cases training can slow down quite a bit as a result of this activat *Sounds great! How do I use it?* -To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag -in any of our recipes, e.g. ``enable_activation_checkpointing=True``. +To enable activation checkpointing, use ``enable_activation_checkpointing=True``. .. _glossary_act_off: @@ -104,14 +103,13 @@ This setting is especially helpful for larger batch sizes, or longer context len While of course it takes runtime and resources to move Tensors from GPU to CPU and back, the implementation in torchtune uses multiple CUDA streams (when available) in order to overlap the extra communication with the computation to hide the extra runtime. As the communication workload is variable depending on the number and size of tensors being -offloaded, it is common to not offload every single activation. In fact, one can use offloading in conjunction with activations -checkpointing, where all activations will either be recomputed later in the backward or brought back from the CPU. +offloaded, we do not recommend using it unless :ref:`glossary_act_ckpt` is also enabled, in which case only the checkpointed +tensors will be offloaded. *Sounds great! How do I use it?* -To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag -in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``. To allow -usage of streams, make sure you are on a torch version later than PyTorch 2.5.0. +To enable activation offloading, use ``enable_activation_offloading=True``. If you are on torch +version later than PyTorch 2.5.0, it will allow the usage of multiple CUDA streams automatically. .. _glossary_grad_accm: @@ -143,10 +141,8 @@ If you're using one of our distributed recipes, simply multiply by the number of ``total_batch_size = batch_size * gradient_accumulation_steps * num_devices`` -Gradient accumulation is especially useful when you are memory constrained. In this case, -accumulating gradients might give you better training speed than enabling :ref:`activation -checkpointing `, since activation checkpointing reduces memory consumption at the cost of repeated -computations. +Gradient accumulation is especially useful when you can fit at least one sample in your GPU. In this case, artificially increasing the batch by +accumulating gradients might give you faster training speeds than using other memory optimization techniques that trade-off memory for speed, like :ref:`activation checkpointing `. *Sounds great! How do I use it?* @@ -168,25 +164,35 @@ Lower Precision Optimizers *What's going on here?* In addition to :ref:`reducing model and optimizer precision ` during training, we can further reduce precision in our optimizer states. -All of our single-device fine-tuning recipes support lower-precision optimizers from the `bitsandbytes `_ library - -a good place to start might be the ``AdamW8bit`` and ``PagedAdamW8bit`` optimizers, which we've tested our recipes with. +All of our recipes support lower-precision optimizers from the `torchao `_ library. +For single device recipes, we also support `bitsandbytes `_. + +A good place to start might be the :class:`torchao.prototype.low_bit_optim.AdamW8bit` and :class:`bitsandbytes.optim.PagedAdamW8bit` optimizers. +Both reduce memory by quantizing the optimizer state dict. Paged optimizers will also offload to CPU if there isn't enough GPU memory available. In practice, +you can expect higher memory savings from bnb's PagedAdamW8bit but higher training speed from torchao's AdamW8bit. *Sounds great! How do I use it?* -To use this in your recipes, make sure you have installed bitsandbytes (``pip install bitsandbytes``). Then, enable +To use this in your recipes, make sure you have installed torchao (``pip install torchao``) or bitsandbytes (``pip install bitsandbytes``). Then, enable a low precision optimizer using the :ref:`cli_label`: + .. code-block:: bash tune run --config \ - optimizer=bitsandbytes.optim.PagedAdamW + optimizer=torchao.prototype.low_bit_optim.AdamW8bit + +.. code-block:: bash + + tune run --config \ + optimizer=bitsandbytes.optim.PagedAdamW8bit or by directly :ref:`modifying a config file`: .. code-block:: yaml optimizer: - _component_: bitsandbytes.optim.PagedAdamW + _component_: bitsandbytes.optim.PagedAdamW8bit lr: 2e-5 .. _glossary_opt_in_bwd: @@ -213,10 +219,9 @@ To understand how this works, we encourage you to read through the relevant PyTo .. todo ref full finetune recipe doc -In torchtune, you can enable this feature using the ``optimizer_in_bwd`` flag, which is currently only supported in our -single-device full finetune recipe. This feature works best when optimizer memory is particularly large; -e.g. when using a stateful optimizer with a model with a lot of parameters, and when you don't need to use -:ref:`gradient accumulation `. +In torchtune, you can enable this feature using the ``optimizer_in_bwd`` flag. This feature works best when using a stateful optimizer +with a model with a lot of parameters, and when you don't need to use :ref:`gradient accumulation `. +You won't see meaningful impact when finetuning LoRA recipes, since in this case the number of parameters being updated are small. .. _glossary_cpu_offload: @@ -232,6 +237,9 @@ through the `CPUOffloadOptimizer `, which will *only* offload to CPU +when there is not enough GPU available. + *Sounds great! How do I use it?* To use this optimizer in your recipes, set the ``optimizer`` key in your config to :class:`torchao.prototype.low_bit_optim.CPUOffloadOptimizer`, which @@ -272,10 +280,10 @@ or using it directly in your code, which allows you to change the base optimizer Some helpful hints from the ``torchao`` `CPUOffloadOptimizer page `_: -* The CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) use full ``bf16`` training so that parameters, gradients, and optimizer states are in ``bf16``; and (2) give GPU more work per optimizer step (e.g. larger batch size with activation checkpointing, gradient accumulation). +* The CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) use full ``bf16`` training so that parameters, gradients, and optimizer states are in ``bf16``; and (2) give GPU more work per optimizer step to amortize the offloading time (e.g. larger batch size with activation checkpointing, gradient accumulation). * Gradient accumulation should always be set to 1 when ``offload_gradients=True``, as gradients are cleared on GPU every backward pass. * This optimizer works by keeping a copy of parameters and pre-allocating gradient memory on CPU. Therefore, expect your RAM usage to increase by 4x model size. -* This optimizer is only supported for single-device recipes. To use CPU-offloading in distributed recipes, use ``fsdp_cpu_offload=True`` in any distributed recipe. See :class:`torch.distributed.fsdp.FullyShardedDataParallel` for more details +* This optimizer is only supported for single-device recipes. To use CPU-offloading in distributed recipes, use ``fsdp_cpu_offload=True`` instead. See :class:`torch.distributed.fsdp.FullyShardedDataParallel` for more details and `FSDP1 vs FSDP2 `_ to see how they differ. .. _glossary_peft: @@ -339,20 +347,20 @@ These are all specified under the ``model`` flag or config entry, i.e: tune run lora_finetune_single_device --config llama3/8B_lora_single_device \ model.apply_lora_to_mlp=True \ - model.lora_attn_modules=["q_proj","k_proj","v_proj"] + model.lora_attn_modules=["q_proj","k_proj","v_proj","output_proj"] .. code-block:: yaml model: _component_: torchtune.models.llama3.lora_llama3_8b apply_lora_to_mlp: True - model.lora_attn_modules: ["q_proj", "k_proj", "v_proj"] + model.lora_attn_modules: ["q_proj", "k_proj", "v_proj","output_proj"] Secondly, parameters which control the scale of the impact of LoRA on the model: * ``lora_rank: int`` affects the scale of the LoRA decomposition, where ``lora_rank << in_dim`` and ``lora_rank << out_dim`` \- the dimensions of an arbitrary linear layer in the model. Concretely, ``lora_rank`` reduces the number of gradients stored - in a linear fashion from ``in_dim * out_dim`` to ``lora_rank * (in_dim + out_dim)``. Typically, we have ``lora_rank in [8, 128]``. + in a linear fashion from ``in_dim * out_dim`` to ``lora_rank * (in_dim + out_dim)``. Typically, we have ``lora_rank in [8, 256]``. * ``lora_alpha: float`` affects the magnitude of the LoRA updates. A larger alpha results in larger updates to the base model weights , potentially at the cost of training stability, conversely, smaller alpha can stabilize training at the cost of slower learning. We provide default settings for these parameters which we've tested with all of our models, but we encourage you to adjust them @@ -365,7 +373,7 @@ As above, these parameters are also specified under the ``model`` flag or config tune run lora_finetune_single_device --config llama3/8B_lora_single_device \ model.apply_lora_to_mlp=True \ - model.lora_attn_modules=["q_proj","k_proj","v_proj"] \ + model.lora_attn_modules=["q_proj","k_proj","v_proj","output_proj"] \ model.lora_rank=32 \ model.lora_alpha=64 @@ -374,7 +382,7 @@ As above, these parameters are also specified under the ``model`` flag or config model: _component_: torchtune.models.llama3.lora_llama3_8b apply_lora_to_mlp: True - lora_attn_modules: ["q_proj", "k_proj", "v_proj"] + lora_attn_modules: ["q_proj", "k_proj", "v_proj","output_proj"] lora_rank: 32 lora_alpha: 64 @@ -390,16 +398,16 @@ Quantized Low Rank Adaptation (QLoRA) *What's going on here?* -`QLoRA `_ is an enhancement on top of `LoRA `_ +`QLoRA `_ is a memory enhancement on top of `LoRA `_ that maintains the frozen model parameters from LoRA in 4-bit quantized precision, thereby reducing memory usage. This is enabled through a novel 4-bit NormalFloat (NF4) data type proposed by the authors, which allows for 4-8x less parameter memory usage whilst retaining model accuracy. You can read our tutorial on :ref:`finetuning Llama2 with QLoRA` for a deeper understanding of how it works. -When considering using QLoRA to reduce memory usage, it's worth noting that QLoRA prevents accuracy degradation during quantization -by up-casting quantized parameters to the original higher precision datatype during model forward passes - this up-casting may -incur penalties to training speed. The :ref:`relevant section ` in our QLoRA tutorial demonstrates the usage of ``torch.compile`` -to address this by speeding up training. +When considering using QLoRA to reduce memory usage, it's worth noting that QLoRA is slower than LoRA and may not be worth it if +the model you are finetuning is small. In numbers, QLoRA saves roughly 1.5 bytes * (# of model parameters). Also, although QLoRA quantizes the model, +it minimizes accuracy degradation by up-casting quantized parameters to the original higher precision datatype during model forward passes - this up-casting may incur penalties to training speed. +The :ref:`relevant section ` in our QLoRA tutorial demonstrates the usage of ``torch.compile`` to address this by speeding up training. *Sounds great! How do I use it?* diff --git a/docs/source/tutorials/qat_finetune.rst b/docs/source/tutorials/qat_finetune.rst index 0c259bd731..6f83bb6c02 100644 --- a/docs/source/tutorials/qat_finetune.rst +++ b/docs/source/tutorials/qat_finetune.rst @@ -136,7 +136,7 @@ used for inference or generation. QAT finetuning recipe in torchtune ---------------------------------- -Putting it all together, we can now fine-tune a model using torchtune’s `QAT recipe `. +Putting it all together, we can now fine-tune a model using torchtune’s :ref:`QAT recipe`. Make sure that you have first downloaded the Llama3 weights and tokenizer by following :ref:`these instructions`. In this tutorial, we use the following settings to demonstrate QAT’s effectiveness in recovering diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index ffe48249a7..b6586e8b5a 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -45,8 +45,8 @@ resume_from_checkpoint: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -55,20 +55,20 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 2e-5 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: True # True reduces memory dtype: bf16 @@ -79,3 +79,28 @@ metric_logger: log_dir: /tmp/CodeLlama-7b-hf/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 6533420441..11f2ffc6c6 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.code_llama2.lora_code_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -49,8 +49,8 @@ save_adapter_weights_only: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -59,7 +59,7 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -70,13 +70,13 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index afda975b9f..ad21d8074d 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.code_llama2.qlora_code_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -49,8 +49,8 @@ save_adapter_weights_only: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -58,7 +58,7 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -69,13 +69,13 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml index f70ec01004..288c55e105 100644 --- a/recipes/configs/dev/8B_full_experimental.yaml +++ b/recipes/configs/dev/8B_full_experimental.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -57,14 +57,14 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory ac_mode: 'selective' # ['selective', 'full'] ac_option: 2 # [int] = ac every positive int layer @@ -81,3 +81,28 @@ metric_logger: output_dir: /tmp/alpaca-llama3-finetune log_every_n_steps: null log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index 2bfe5995be..0e34bb205c 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -23,8 +23,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -54,14 +54,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -74,3 +75,28 @@ metric_logger: output_dir: /tmp/alpaca-gemma-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 7169236759..9895736c35 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -22,18 +22,18 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.gemma.lora_gemma_2b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -66,14 +66,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -86,3 +86,28 @@ metric_logger: output_dir: /tmp/alpaca-gemma-lora log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index 9bf463181e..ed7aa11360 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -22,18 +22,18 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.gemma.lora_gemma_2b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -65,14 +65,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 250d6ef178..ea288595ba 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -22,18 +22,18 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.gemma.qlora_gemma_2b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -65,14 +65,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index 8c7ff001fd..4555235385 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -23,8 +23,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -56,14 +56,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -76,3 +77,28 @@ metric_logger: output_dir: /tmp/alpaca-gemma-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index 209277c9d5..97685e66e1 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -23,18 +23,18 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.gemma.lora_gemma_7b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -68,14 +68,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -88,3 +88,28 @@ metric_logger: output_dir: /tmp/alpaca-gemma-lora log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index 57be9a3be0..82d1399b20 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -22,18 +22,18 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.gemma.lora_gemma_7b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -67,14 +67,14 @@ loss: batch_size: 8 epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 2 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 0b52716d60..985ab6cae8 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -22,18 +22,18 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.gemma.qlora_gemma_7b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -67,14 +67,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml new file mode 100644 index 0000000000..ddc89b38b2 --- /dev/null +++ b/recipes/configs/gemma2/27B_full.yaml @@ -0,0 +1,74 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/27B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-27b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.gemma2_27b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-27b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00024" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-27b + model_type: GEMMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-finetune +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml new file mode 100644 index 0000000000..a138441199 --- /dev/null +++ b/recipes/configs/gemma2/27B_lora.yaml @@ -0,0 +1,86 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/27B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-27b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_27b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-27b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00024" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-27b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-lora +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml new file mode 100644 index 0000000000..577b0715c5 --- /dev/null +++ b/recipes/configs/gemma2/27B_lora_single_device.yaml @@ -0,0 +1,112 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run (torchtune does not use gguf so you can ignore it to save time and space): +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/27B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-27b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_27b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-27b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00024" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-27b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 5e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 2 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-lora +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml new file mode 100644 index 0000000000..14d9b75ba7 --- /dev/null +++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml @@ -0,0 +1,115 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma2 27B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-27b --ignore-patterns "gemma-2-27b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/27B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/27B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-27b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.qlora_gemma2_27b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-27b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00024" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-27b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-27b-lora +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 + +# For colab use True +low_cpu_ram: False diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml new file mode 100644 index 0000000000..e302dd759d --- /dev/null +++ b/recipes/configs/gemma2/2B_full.yaml @@ -0,0 +1,76 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/2B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-2b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.gemma2_2b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2-2b + model_type: GEMMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-finetune +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml new file mode 100644 index 0000000000..9a439ee0a3 --- /dev/null +++ b/recipes/configs/gemma2/2B_lora.yaml @@ -0,0 +1,88 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/2B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/2B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-2b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2-2b + model_type: GEMMA2 +resume_from_checkpoint: False + +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-lora +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml new file mode 100644 index 0000000000..1a2703fb47 --- /dev/null +++ b/recipes/configs/gemma2/2B_lora_single_device.yaml @@ -0,0 +1,114 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/2B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/2B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-2b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2-2b + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 8 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-lora +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml new file mode 100644 index 0000000000..c2525460ff --- /dev/null +++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml @@ -0,0 +1,114 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma2 2B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-2b --ignore-patterns "gemma-2-2b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/2B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-2b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.qlora_gemma2_2b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-2b/ + checkpoint_files: [ + model-00001-of-00003.safetensors, + model-00002-of-00003.safetensors, + model-00003-of-00003.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/gemma-2-2b + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-lora +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml new file mode 100644 index 0000000000..0fc7e6e4e4 --- /dev/null +++ b/recipes/configs/gemma2/9B_full.yaml @@ -0,0 +1,74 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/9B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 full_finetune_distributed --config gemma2/9B_full checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-9b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.gemma2_9b + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-9b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00008" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-9b + model_type: GEMMA2 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 1 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-finetune +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml new file mode 100644 index 0000000000..960e4fa881 --- /dev/null +++ b/recipes/configs/gemma2/9B_lora.yaml @@ -0,0 +1,86 @@ +# Config for multi-device LoRA finetuning in lora_finetune_distributed.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/9B_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma2/9B_lora checkpointer.checkpoint_dir= +# +# This config works only when the model is being fine-tuned on 2+ GPUs. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-9b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_9b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-9b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00008" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-9b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-lora +log_every_n_steps: 1 +log_peak_memory_stats: True diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml new file mode 100644 index 0000000000..e9d6c22a73 --- /dev/null +++ b/recipes/configs/gemma2/9B_lora_single_device.yaml @@ -0,0 +1,112 @@ +# Config for multi-device LoRA finetuning in lora_finetune_single_device.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run (torchtune does not use gguf so you can ignore it to save time and space): +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/9B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/9B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-9b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.lora_gemma2_9b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-9b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00008" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-9b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 5e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 8 +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-lora +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml new file mode 100644 index 0000000000..8991ba9ece --- /dev/null +++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml @@ -0,0 +1,115 @@ +# Config for multi-device QLoRA finetuning in lora_finetune_single_device.py +# using a gemma2 9B model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download google/gemma-2-9b --ignore-patterns "gemma-2-9b.gguf" --hf-token +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config gemma2/9B_qlora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config gemma2/9B_qlora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + +# Tokenizer +tokenizer: + _component_: torchtune.models.gemma.gemma_tokenizer + path: /tmp/gemma-2-9b/tokenizer.model + +# Dataset +dataset: + packed: False # Set to true for great speed ups + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.gemma2.qlora_gemma2_9b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: True + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/gemma-2-9b/ + checkpoint_files: + filename_format: model-{}-of-{}.safetensors + max_filename: "00008" + recipe_checkpoint: null + output_dir: /tmp/gemma-2-9b/ + model_type: GEMMA2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +optimizer: + _component_: torch.optim.AdamW + fused: True + lr: 2e-5 + +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 10 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Fine-tuning arguments +batch_size: 4 +epochs: 3 +max_steps_per_epoch: null +gradient_accumulation_steps: 4 +compile: False # pytorch compile, set to true for perf/memory improvement + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/alpaca-gemma2-9b-lora +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 + +# For colab use True +low_cpu_ram: False diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index fef60b7c21..d02ce13c0b 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -43,8 +43,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -58,14 +58,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -78,3 +79,28 @@ metric_logger: output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index 6dd3017c06..7a6fa600d2 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -22,11 +22,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.lora_llama2_13b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: True - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -52,8 +52,8 @@ tokenizer: # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -74,8 +74,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -88,5 +88,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 5e37ee820a..a10285544a 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.qlora_llama2_13b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -47,8 +47,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -69,8 +69,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output/ @@ -84,7 +84,7 @@ log_peak_memory_stats: True device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 7b936696ad..a67bfc9da2 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -12,11 +12,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.lora_llama2_70b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 16 - lora_alpha: 32 + lora_rank: 16 # higher increases accuracy and memory + lora_alpha: 32 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -52,8 +52,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -74,7 +74,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 +compile: False # pytorch compile, set to true for better perf/memory +gradient_accumulation_steps: 1 # Use to increase virtual batch size # Logging output_dir: /tmp/lora_finetune_output @@ -87,5 +88,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml index 5d778e13e3..d04b7c6753 100644 --- a/recipes/configs/llama2/70B_qlora.yaml +++ b/recipes/configs/llama2/70B_qlora.yaml @@ -17,11 +17,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.qlora_llama2_70b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 16 - lora_alpha: 32 + lora_rank: 16 # higher increases accuracy and memory + lora_alpha: 32 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -57,8 +57,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed train_on_input: True seed: null shuffle: True @@ -83,8 +83,8 @@ fsdp: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output @@ -97,5 +97,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index eea691ea86..99e7fcc30b 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -57,14 +57,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -77,3 +78,28 @@ metric_logger: output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index 7380bd0756..c5300c0a90 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -58,18 +58,18 @@ optimizer: lr_scheduler: _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup num_warmup_steps: 100 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: True # True reduces memory # Reduced precision @@ -82,3 +82,28 @@ metric_logger: output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 7841eea584..8e64a3fc11 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -21,11 +21,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.lora_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -49,8 +49,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -71,7 +71,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 +compile: False # pytorch compile, set to true for better perf/memory +gradient_accumulation_steps: 8 # Use to increase virtual batch size # Logging output_dir: /tmp/lora_finetune_output @@ -84,7 +85,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml index 1a0b4bc390..f3b827ae3b 100644 --- a/recipes/configs/llama2/7B_lora_dpo.yaml +++ b/recipes/configs/llama2/7B_lora_dpo.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.lora_llama2_7b - lora_attn_modules: ["q_proj", "v_proj"] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -69,8 +69,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: 1000 -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_dpo_output/ @@ -83,4 +83,4 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml index bfe8185f06..6483219e9b 100644 --- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.lora_llama2_7b - lora_attn_modules: ["q_proj", "v_proj"] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -66,8 +66,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: 1000 -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_dpo_output/ @@ -80,4 +80,4 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index b96d139174..481fed1a7e 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.lora_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -47,8 +47,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -69,8 +69,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -85,7 +85,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml index d1a408aca5..e404b0c4dc 100644 --- a/recipes/configs/llama2/7B_qat_full.yaml +++ b/recipes/configs/llama2/7B_qat_full.yaml @@ -22,8 +22,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -53,8 +53,9 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # QAT arguments quantizer: @@ -65,8 +66,8 @@ quantizer: device: cuda # Memory management -enable_activation_checkpointing: True -memory_efficient_fsdp_wrap: False +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -78,3 +79,28 @@ metric_logger: output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/7B_qlora.yaml b/recipes/configs/llama2/7B_qlora.yaml index 97cdae7dac..80cee9853c 100644 --- a/recipes/configs/llama2/7B_qlora.yaml +++ b/recipes/configs/llama2/7B_qlora.yaml @@ -20,11 +20,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.qlora_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -48,8 +48,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed train_on_input: True seed: null shuffle: True @@ -74,8 +74,8 @@ fsdp: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 2 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output/ @@ -88,5 +88,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index ad6667b2fb..b1f119d7db 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.llama2.qlora_llama2_7b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -46,8 +46,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -68,8 +68,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output/ @@ -84,7 +84,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index e950b91dab..fde65da8c6 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -25,8 +25,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -86,17 +86,18 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory -custom_sharded_layers: ['tok_embeddings', 'output'] +custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed fsdp_cpu_offload: True -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Reduced precision dtype: bf16 @@ -108,3 +109,28 @@ metric_logger: output_dir: /tmp/full-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index 4ab6c13793..2e4e718f62 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -12,11 +12,11 @@ # Model Arguments model: _component_: torchtune.models.llama3.lora_llama3_70b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 16 - lora_alpha: 32 + lora_rank: 16 # higher increases accuracy and memory + lora_alpha: 32 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -67,8 +67,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -89,8 +89,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -103,5 +103,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/8B_dora.yaml b/recipes/configs/llama3/8B_dora.yaml index 43b0fa6066..ee7a8d07f6 100644 --- a/recipes/configs/llama3/8B_dora.yaml +++ b/recipes/configs/llama3/8B_dora.yaml @@ -17,11 +17,11 @@ # Model Arguments model: _component_: torchtune.models.llama3.lora_llama3_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank use_dora: True # Tokenizer @@ -42,8 +42,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -64,8 +64,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/dora_finetune_output @@ -78,5 +78,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml index 20f5804082..82c7c765b5 100644 --- a/recipes/configs/llama3/8B_dora_single_device.yaml +++ b/recipes/configs/llama3/8B_dora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.llama3.lora_llama3_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank use_dora: True # Tokenizer @@ -44,8 +44,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -66,8 +66,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/dora_finetune_output @@ -80,7 +80,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 27f569aa16..4d7f7e7b8e 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -57,16 +57,17 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory -custom_sharded_layers: ['tok_embeddings', 'output'] +custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Reduced precision dtype: bf16 @@ -78,3 +79,28 @@ metric_logger: output_dir: /tmp/full-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index b86272842e..26f635fac0 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -60,15 +60,15 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -81,3 +81,28 @@ metric_logger: output_dir: /tmp/full-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 41537ccdbb..3ced0899e4 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -26,11 +26,11 @@ tokenizer: # Model Arguments model: _component_: torchtune.models.llama3.lora_llama3_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -47,8 +47,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -69,8 +69,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -83,5 +83,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 6c6aefa525..4535758ac9 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.llama3.lora_llama3_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -46,8 +46,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -68,8 +68,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -84,7 +84,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -92,14 +92,14 @@ profiler: _component_: torchtune.training.setup_torch_profiler enabled: False - # Output directory of trace artifacts + #Output directory of trace artifacts output_dir: ${output_dir}/profiling_outputs #`torch.profiler.ProfilerActivity` types to trace cpu: True cuda: True - # trace options passed to `torch.profiler.profile` + #trace options passed to `torch.profiler.profile` profile_memory: False with_stack: False record_shapes: True @@ -108,6 +108,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml index 07461e8243..2b08cbb10f 100644 --- a/recipes/configs/llama3/8B_qat_full.yaml +++ b/recipes/configs/llama3/8B_qat_full.yaml @@ -21,8 +21,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -44,7 +44,6 @@ resume_from_checkpoint: False # Fine-tuning arguments batch_size: 2 epochs: 3 -compile: False # QAT arguments quantizer: @@ -58,14 +57,17 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True -memory_efficient_fsdp_wrap: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory +custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Reduced precision dtype: bf16 @@ -74,6 +76,31 @@ dtype: bf16 metric_logger: _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} -output_dir: /tmp/alpaca-llama3-finetune +output_dir: /tmp/full-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3/8B_qdora_single_device.yaml b/recipes/configs/llama3/8B_qdora_single_device.yaml index 18c625a956..8eb1b5151c 100644 --- a/recipes/configs/llama3/8B_qdora_single_device.yaml +++ b/recipes/configs/llama3/8B_qdora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.llama3.lora_llama3_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank use_dora: True quantize_base: True @@ -45,8 +45,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -67,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qdora_finetune_output @@ -81,7 +81,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index 5486ae9f1a..0c4ab423b8 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.llama3.qlora_llama3_8b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -45,8 +45,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -67,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output/ @@ -83,7 +83,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -107,7 +107,7 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml index ed978c1a51..cc4eead534 100644 --- a/recipes/configs/llama3_1/405B_qlora.yaml +++ b/recipes/configs/llama3_1/405B_qlora.yaml @@ -17,11 +17,11 @@ # Model Arguments model: _component_: torchtune.models.llama3_1.qlora_llama3_1_405b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 16 - lora_alpha: 32 + lora_rank: 16 # higher increases accuracy and memory + lora_alpha: 32 # usually alpha=2*rank tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer @@ -41,8 +41,8 @@ save_adapter_weights_only: True # Set to false to save the whole model + adapter # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed train_on_input: True seed: null shuffle: True @@ -67,8 +67,8 @@ fsdp: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output @@ -81,5 +81,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 34fabe663f..8e70706414 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -24,8 +24,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -87,18 +87,19 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory -custom_sharded_layers: ['tok_embeddings', 'output'] +custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. fsdp_cpu_offload: True -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Reduced precision dtype: bf16 @@ -110,3 +111,28 @@ metric_logger: output_dir: /tmp/full-llama3_1-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index ee19446238..a89c01b4c1 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -11,11 +11,11 @@ # Model Arguments model: _component_: torchtune.models.llama3_1.lora_llama3_1_70b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 16 - lora_alpha: 32 + lora_rank: 16 # higher increases accuracy and memory + lora_alpha: 32 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -66,8 +66,8 @@ save_adapter_weights_only: True # Set to false to save the whole model + adapter # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -88,8 +88,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora-llama3_1-finetune-output @@ -102,5 +102,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index 71ab8eedeb..b85c70ed1c 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -60,17 +60,17 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory -custom_sharded_layers: ['tok_embeddings', 'output'] -compile: False # pytorch compile, set to true for perf/memory improvement +custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. # Reduced precision dtype: bf16 @@ -82,3 +82,28 @@ metric_logger: output_dir: /tmp/full-llama3.1-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index b26df8cb67..7e06ca4a6d 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -60,15 +60,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 1 # Use to increase virtual batch size +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -95,14 +95,14 @@ profiler: cuda: True #trace options passed to `torch.profiler.profile` - profile_memory: True + profile_memory: False with_stack: False record_shapes: True with_flops: False # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 1 - warmup_steps: 2 - active_steps: 1 + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index 0793b8a57c..b889f20fe2 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -26,11 +26,11 @@ tokenizer: # Model Arguments model: _component_: torchtune.models.llama3_1.lora_llama3_1_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -50,8 +50,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -72,8 +72,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -86,5 +86,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 12ef984db9..f631dcfd7e 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.llama3_1.lora_llama3_1_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -49,8 +49,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -71,8 +71,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -87,7 +87,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -111,6 +111,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 0b44eaf383..57c8cdb513 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -18,11 +18,11 @@ # Model Arguments model: _component_: torchtune.models.llama3_1.qlora_llama3_1_8b - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -48,8 +48,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -70,8 +70,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qlora_finetune_output/ @@ -86,7 +86,7 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -110,7 +110,7 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml index 694a14b573..437c222d28 100644 --- a/recipes/configs/llama3_2/1B_full.yaml +++ b/recipes/configs/llama3_2/1B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -57,16 +57,17 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size # Training env device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Reduced precision dtype: bf16 @@ -78,3 +79,28 @@ metric_logger: output_dir: /tmp/full-llama3.2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml index fe641f3479..4f367f03a5 100644 --- a/recipes/configs/llama3_2/1B_full_single_device.yaml +++ b/recipes/configs/llama3_2/1B_full_single_device.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -57,15 +57,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 1 # Use to increase virtual batch size +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -92,14 +92,14 @@ profiler: cuda: True #trace options passed to `torch.profiler.profile` - profile_memory: True + profile_memory: False with_stack: False record_shapes: True with_flops: False # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 1 - warmup_steps: 2 - active_steps: 1 + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/1B_lora.yaml b/recipes/configs/llama3_2/1B_lora.yaml index 17ee6a8625..4903e482ba 100644 --- a/recipes/configs/llama3_2/1B_lora.yaml +++ b/recipes/configs/llama3_2/1B_lora.yaml @@ -28,9 +28,8 @@ model: _component_: torchtune.models.llama3_2.lora_llama3_2_1b lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -47,8 +46,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -69,8 +68,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -83,5 +82,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml index 3e23a6e56a..911129987c 100644 --- a/recipes/configs/llama3_2/1B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml @@ -21,9 +21,8 @@ model: _component_: torchtune.models.llama3_2.lora_llama3_2_1b lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -46,8 +45,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -68,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -84,7 +83,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -108,6 +107,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/1B_qlora_single_device.yaml b/recipes/configs/llama3_2/1B_qlora_single_device.yaml index d4530df081..3573ae38fc 100644 --- a/recipes/configs/llama3_2/1B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/1B_qlora_single_device.yaml @@ -20,9 +20,8 @@ model: _component_: torchtune.models.llama3_2.qlora_llama3_2_1b lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -45,8 +44,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -67,8 +66,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -83,7 +82,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -107,6 +106,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 2d9e9d2f3a..54f810c33a 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -58,15 +58,16 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Reduced precision dtype: bf16 @@ -78,3 +79,28 @@ metric_logger: output_dir: /tmp/full-llama3.2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml index 16f5840edf..cffa1fb83e 100644 --- a/recipes/configs/llama3_2/3B_full_single_device.yaml +++ b/recipes/configs/llama3_2/3B_full_single_device.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -58,15 +58,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 1 # Use to increase virtual batch size +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -93,14 +93,14 @@ profiler: cuda: True #trace options passed to `torch.profiler.profile` - profile_memory: True + profile_memory: False with_stack: False record_shapes: True with_flops: False # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 1 - warmup_steps: 2 - active_steps: 1 + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml index a2f00ad19e..0e790b20cb 100644 --- a/recipes/configs/llama3_2/3B_lora.yaml +++ b/recipes/configs/llama3_2/3B_lora.yaml @@ -28,9 +28,8 @@ model: _component_: torchtune.models.llama3_2.lora_llama3_2_3b lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -48,8 +47,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -70,8 +69,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -84,5 +83,30 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml index 4add5d63aa..29e021d150 100644 --- a/recipes/configs/llama3_2/3B_lora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml @@ -21,9 +21,8 @@ model: _component_: torchtune.models.llama3_2.lora_llama3_2_3b lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -47,8 +46,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -69,8 +68,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -85,7 +84,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -109,6 +108,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/3B_qlora_single_device.yaml b/recipes/configs/llama3_2/3B_qlora_single_device.yaml index 520f616a79..7ffa146e51 100644 --- a/recipes/configs/llama3_2/3B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2/3B_qlora_single_device.yaml @@ -20,9 +20,8 @@ model: _component_: torchtune.models.llama3_2.qlora_llama3_2_3b lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True - apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 128 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -46,8 +45,8 @@ save_adapter_weights_only: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -68,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False # pytorch compile, set to true for perf/memory improvement +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/lora_finetune_output @@ -84,7 +83,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Profiler (disabled) @@ -108,6 +107,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml index 1cc864b900..8ef1bcbea3 100644 --- a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml +++ b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml @@ -63,6 +63,7 @@ teacher_checkpointer: # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -87,7 +88,8 @@ kd_ratio: 0.5 # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 +compile: False # pytorch compile, set to true for better perf/memory +gradient_accumulation_steps: 8 # Use to increase virtual batch size # Logging output_dir: /tmp/kd_output @@ -100,7 +102,7 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml b/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml index 6a3f85f257..e08fb8ad7a 100644 --- a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml +++ b/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml @@ -7,7 +7,6 @@ # tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" # # You get better results using KD if the teacher model has already been fine-tuned on the target dataset: - packed: False # Set to true for great speed ups # tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device # # To launch on a single device, run the following command from root: @@ -63,8 +62,8 @@ teacher_checkpointer: # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -89,8 +88,8 @@ kd_ratio: 0.5 # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/kd_output @@ -105,7 +104,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory # Profiler (disabled) profiler: @@ -128,6 +127,6 @@ profiler: # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat wait_steps: 5 - warmup_steps: 5 + warmup_steps: 3 active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/11B_evaluation.yaml b/recipes/configs/llama3_2_vision/11B_evaluation.yaml index 44f7c5d925..13bbabf549 100644 --- a/recipes/configs/llama3_2_vision/11B_evaluation.yaml +++ b/recipes/configs/llama3_2_vision/11B_evaluation.yaml @@ -7,7 +7,7 @@ # pip install lm_eval==0.4.5 # # To launch, run the following command from root torchtune directory: -# tune run eleuther_eval --config llama3_2_vision/evaluation +# tune run eleuther_eval --config llama3_2_vision/11B_evaluation # Model arguments model: diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml index f40cde9f90..51173f162a 100644 --- a/recipes/configs/llama3_2_vision/11B_full.yaml +++ b/recipes/configs/llama3_2_vision/11B_full.yaml @@ -44,8 +44,8 @@ resume_from_checkpoint: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed subset: ocrvqa seed: null shuffle: True @@ -55,24 +55,24 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW lr: 2e-5 fused: True -optimizer_in_bwd: False # Set to True to use less memory. Requires gradient_accumulation_steps=1. +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True -custom_sharded_layers: ['decoder.tok_embeddings'] +enable_activation_checkpointing: True # True reduces memory +custom_sharded_layers: ['decoder.tok_embeddings'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. dtype: bf16 # Logging @@ -82,3 +82,28 @@ metric_logger: log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml index b5c4141675..d10afdcbfe 100644 --- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml @@ -46,8 +46,8 @@ resume_from_checkpoint: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed subset: ocrvqa seed: null shuffle: True @@ -57,22 +57,22 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: bitsandbytes.optim.PagedAdamW8bit lr: 2e-5 -optimizer_in_bwd: False # Set to True to use less memory. Requires gradient_accumulation_steps=1. +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml index 94a394965c..b394b9ffbf 100644 --- a/recipes/configs/llama3_2_vision/11B_lora.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora.yaml @@ -21,11 +21,11 @@ model: decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 image_size: 560 # Make sure this matches the image_size in tokenizer @@ -51,8 +51,8 @@ save_adapter_weights_only: False # PeFT formatting not available yet. This will # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed subset: ocrvqa seed: null shuffle: True @@ -62,7 +62,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -75,13 +75,13 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging @@ -91,3 +91,28 @@ metric_logger: log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml index 121a50416e..050c6b0383 100644 --- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml @@ -19,11 +19,11 @@ model: decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 image_size: 560 # Make sure this matches the image_size in tokenizer @@ -49,8 +49,8 @@ save_adapter_weights_only: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.multimodal.the_cauldron_dataset + packed: False # True increases speed subset: ocrvqa seed: null shuffle: True @@ -60,7 +60,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -73,13 +73,13 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # pytorch compile, set to true for perf/memory improvement +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging @@ -103,14 +103,14 @@ profiler: cuda: True #trace options passed to `torch.profiler.profile` - profile_memory: True + profile_memory: False with_stack: False record_shapes: True with_flops: False # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 1 - warmup_steps: 2 - active_steps: 1 + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml index 5bd77570cd..d18209adfe 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml @@ -20,11 +20,11 @@ model: decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 image_size: 560 # Make sure this matches the image_size in tokenizer @@ -60,7 +60,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -72,13 +72,13 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # set it to True for better memory and performance +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging @@ -88,3 +88,28 @@ metric_logger: log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml index 5a02ff2406..2829cb4d43 100644 --- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml @@ -19,11 +19,11 @@ model: decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 image_size: 560 # Make sure this matches the image_size in tokenizer @@ -59,7 +59,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -72,13 +72,13 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # set it to True for better memory and performance +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging @@ -102,14 +102,14 @@ profiler: cuda: True #trace options passed to `torch.profiler.profile` - profile_memory: True + profile_memory: False with_stack: False record_shapes: True with_flops: False # `torch.profiler.schedule` options: # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 1 - warmup_steps: 2 - active_steps: 1 + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/90B_full.yaml b/recipes/configs/llama3_2_vision/90B_full.yaml index a76ff89405..09a7a22769 100644 --- a/recipes/configs/llama3_2_vision/90B_full.yaml +++ b/recipes/configs/llama3_2_vision/90B_full.yaml @@ -1,17 +1,17 @@ -# Config for single device full finetuning in full_finetune_single_device.py +# Config for multi-device full finetuning in full_finetune_distributed.py # using a Llama3.2 90B Vision Instruct model # # This config assumes that you've run the following command before launching: # tune download meta-llama/Llama-3.2-90B-Vision-Instruct --output-dir /tmp/Llama-3.2-90B-Vision-Instruct --ignore-patterns "original/consolidated*" # -# To launch on a single device, run the following command from root: +# To launch on 8 devices, run the following command from root: # tune run --nproc_per_node 8 full_finetune_distributed --config llama3_2_vision/90B_full # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training: -# tune run --nproc_per_node 8 full_finetune_distributed --config llama3_2_vision/90B_full checkpointer.checkpoint_dir= +# tune run --nproc_per_node 8 full_finetune_distributed --config llama3_2_vision/90B_full checkpointer.checkpoint_dir= # -# This config works best when the model is being fine-tuned on 2+ GPUs. +# This config needs 8 GPUs to run. # Model arguments model: @@ -52,24 +52,24 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW lr: 2e-5 fused: True -optimizer_in_bwd: False # Set to True to use less memory. Requires gradient_accumulation_steps=1. +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # set it to True for better memory and performance +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True -custom_sharded_layers: ['decoder.tok_embeddings'] +enable_activation_checkpointing: True # True reduces memory +custom_sharded_layers: ['decoder.tok_embeddings'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed. dtype: bf16 # Logging @@ -79,3 +79,28 @@ metric_logger: log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/90B_lora.yaml b/recipes/configs/llama3_2_vision/90B_lora.yaml index 49a498915d..14388cc4ea 100644 --- a/recipes/configs/llama3_2_vision/90B_lora.yaml +++ b/recipes/configs/llama3_2_vision/90B_lora.yaml @@ -4,14 +4,14 @@ # This config assumes that you've run the following command before launching: # tune download meta-llama/Llama-3.2-90B-Vision-Instruct --output-dir /tmp/Llama-3.2-90B-Vision-Instruct --ignore-patterns "original/consolidated*" # -# To launch on 2 devices, run the following command from root: -# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_2_vision/90B_lora +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 lora_finetune_distributed --config llama3_2_vision/90B_lora # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training: -# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_2_vision/90B_lora checkpointer.checkpoint_dir= +# tune run --nproc_per_node 4 lora_finetune_distributed --config llama3_2_vision/90B_lora checkpointer.checkpoint_dir= # -# This config works best when the model is being fine-tuned on 2+ GPUs. +# This config works best when the model is being fine-tuned on 4+ GPUs. # Model arguments model: @@ -19,11 +19,11 @@ model: decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 image_size: 560 # Make sure this matches the image_size in tokenizer @@ -59,7 +59,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -72,13 +72,13 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # set it to True for better memory and performance +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging @@ -88,3 +88,28 @@ metric_logger: log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/llama3_2_vision/90B_qlora.yaml b/recipes/configs/llama3_2_vision/90B_qlora.yaml index 8cd6a73324..30810e90b1 100644 --- a/recipes/configs/llama3_2_vision/90B_qlora.yaml +++ b/recipes/configs/llama3_2_vision/90B_qlora.yaml @@ -4,15 +4,14 @@ # This config assumes that you've run the following command before launching: # tune download meta-llama/Llama-3.2-90B-Vision-Instruct --output-dir /tmp/Llama-3.2-90B-Vision-Instruct --ignore-patterns "original/consolidated*" # -# To launch on 2 devices, run the following command from root: -# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/90B_qlora +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 lora_finetune_distributed --config llama3_2_vision/90B_qlora # # You can add specific overrides through the command line. For example # to override the checkpointer directory while launching training: -# tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_2_vision/90B_qlora checkpointer.checkpoint_dir= +# tune run --nproc_per_node 4 lora_finetune_distributed --config llama3_2_vision/90B_qlora checkpointer.checkpoint_dir= # -# This config works best when the model is being fine-tuned on 2+ GPUs. -# For single device QLoRA finetuning please use 90B_qlora_single_device.yaml +# This config works best when the model is being fine-tuned on 4+ GPUs. # Model arguments model: @@ -20,11 +19,11 @@ model: decoder_trainable: "frozen" encoder_trainable: "lora" fusion_trainable: "lora" - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 image_size: 560 # Make sure this matches the image_size in tokenizer @@ -60,7 +59,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 4 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -72,13 +71,13 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss clip_grad_norm: 1.0 -compile: False # set it to True for better memory and performance +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory dtype: bf16 # Logging @@ -88,3 +87,28 @@ metric_logger: log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index db242d2b6f..2452ef275b 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -29,8 +29,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -60,14 +60,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -80,3 +81,28 @@ metric_logger: output_dir: /tmp/Mistral-7B-v0.1/ log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index f25c150325..7ae9f916ab 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -31,8 +31,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True @@ -61,21 +61,21 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True +gradient_accumulation_steps: 1 # Use to increase virtual batch size +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: True # True reduces memory # Reduced precision dtype: bf16 # Model compilation -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Logging metric_logger: @@ -84,3 +84,28 @@ metric_logger: output_dir: /tmp/Mistral-7B-v0.1/ log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml index db3b3f5e86..e05be85ff6 100644 --- a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml @@ -127,16 +127,16 @@ batch_size: 64 num_steps: 10000 ppo_epochs: 2 ppo_batch_size: 32 -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size # Memory management and performance -compile: True +compile: True # pytorch compile, set to true for better perf/memory optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 3e-6 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 log_peak_memory_stats: True -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index 9ba9976f2a..2724a0754d 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -30,19 +30,19 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.mistral.lora_mistral_7b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: True - lora_rank: 64 - lora_alpha: 16 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -74,14 +74,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -94,3 +94,28 @@ metric_logger: output_dir: /tmp/Mistral-7B-v0.1 log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index 6380448331..be143ce480 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -27,19 +27,19 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.mistral.lora_mistral_7b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: True - lora_rank: 64 - lora_alpha: 16 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -71,14 +71,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index 42c88af742..b3c1337901 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -28,19 +28,19 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_dataset + packed: False # True increases speed seed: null shuffle: True # Model Arguments model: _component_: torchtune.models.mistral.qlora_mistral_7b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 64 - lora_alpha: 16 + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 checkpointer: @@ -72,14 +72,14 @@ loss: batch_size: 4 epochs: 3 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index bd5b00702c..1319ab816d 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -42,8 +42,8 @@ resume_from_checkpoint: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -51,20 +51,21 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True lr: 5e-6 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 @@ -75,3 +76,28 @@ metric_logger: log_dir: /tmp/Phi-3-mini-4k-instruct/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index 1fbb10d10f..ad7e0f4046 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -44,8 +44,8 @@ resume_from_checkpoint: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -53,20 +53,20 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 1 # Use to increase virtual batch size optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 5e-6 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: True # True reduces memory dtype: bf16 @@ -77,3 +77,28 @@ metric_logger: log_dir: /tmp/Phi-3-mini-4k-instruct/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index 2391f9f383..5547be21e0 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -20,11 +20,11 @@ # Model arguments model: _component_: torchtune.models.phi3.lora_phi3_mini - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -49,8 +49,8 @@ save_adapter_weights_only: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -58,7 +58,7 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -69,13 +69,13 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory dtype: bf16 @@ -86,3 +86,28 @@ metric_logger: log_dir: /tmp/Phi-3-mini-4k-instruct/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index cec51773dc..533972b0e1 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -18,11 +18,11 @@ # Model arguments model: _component_: torchtune.models.phi3.lora_phi3_mini - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -47,8 +47,8 @@ save_adapter_weights_only: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -56,7 +56,7 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -67,13 +67,13 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index ceaa5b3530..e89bd1a542 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -18,11 +18,11 @@ # Model arguments model: _component_: torchtune.models.phi3.qlora_phi3_mini - lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj'] + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 # Tokenizer @@ -47,8 +47,8 @@ save_adapter_weights_only: False # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -56,7 +56,7 @@ shuffle: True epochs: 1 max_steps_per_epoch: null batch_size: 2 -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 # Use to increase virtual batch size optimizer: _component_: torch.optim.AdamW fused: True @@ -67,13 +67,13 @@ lr_scheduler: num_warmup_steps: 100 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -compile: False +compile: False # pytorch compile, set to true for better perf/memory # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 133e24b1cc..ca5863c37c 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -56,14 +56,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -76,3 +77,28 @@ metric_logger: output_dir: /tmp/Qwen2-0.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 14ed13e213..7e491216c1 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -24,8 +24,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -54,17 +54,17 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -optimizer_in_bwd: False +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 max_steps_per_epoch: null -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -77,3 +77,28 @@ metric_logger: output_dir: /tmp/Qwen2-0.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index a605229d2b..9f54c5fdbe 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -21,10 +21,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2.lora_qwen2_0_5b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] - apply_lora_to_mlp: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -46,8 +46,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -70,8 +70,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune @@ -85,7 +85,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index 0052086a03..e9907ec939 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -19,10 +19,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2.lora_qwen2_0_5b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] - apply_lora_to_mlp: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -45,8 +45,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -68,8 +68,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune @@ -84,7 +84,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index 725d7fa65f..bae27e0a70 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -56,14 +56,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -76,3 +77,28 @@ metric_logger: output_dir: /tmp/Qwen2-1.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index 6e140085c4..3b7642cf24 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -56,20 +56,20 @@ optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 2e-5 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -82,3 +82,28 @@ metric_logger: output_dir: /tmp/Qwen2-1.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index d5a23b571e..d006b29cce 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -19,10 +19,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2.lora_qwen2_1_5b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -44,8 +44,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -66,8 +66,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune @@ -80,7 +80,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index 88e18352b8..1943be6cb9 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -19,10 +19,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2.lora_qwen2_1_5b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -44,8 +44,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -66,8 +66,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune @@ -82,7 +82,7 @@ device: cuda dtype: bf16 # Activations Memory -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index 3c159f90fc..d0a6726826 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -26,8 +26,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -59,14 +59,15 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -79,3 +80,28 @@ metric_logger: output_dir: /tmp/Qwen2-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index 5cc2c8b4b5..25e4a1b72b 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -28,8 +28,8 @@ tokenizer: # Dataset dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True @@ -57,18 +57,18 @@ epochs: 1 optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 5e-6 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Reduced precision @@ -81,3 +81,28 @@ metric_logger: output_dir: /tmp/Qwen2-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index 612b48d156..c853a7e39f 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -21,11 +21,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2.lora_qwen2_7b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -50,8 +50,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -72,8 +72,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune @@ -86,7 +86,7 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index 1297d1bbe1..97204f8a1d 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2.lora_qwen2_7b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -48,8 +48,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -70,8 +70,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune @@ -86,7 +86,7 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler diff --git a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml index 9727860ca7..94d528e004 100644 --- a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml +++ b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml @@ -57,6 +57,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 8 @@ -80,7 +81,8 @@ kd_ratio: 0.5 # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 2 +compile: False # pytorch compile, set to true for better perf/memory +gradient_accumulation_steps: 8 # Use to increase virtual batch size # Logging output_dir: /tmp/qwen_kd @@ -93,7 +95,7 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/knowledge_distillation_single_device.yaml b/recipes/configs/qwen2/knowledge_distillation_single_device.yaml index f7d1b191cd..8246cf2e01 100644 --- a/recipes/configs/qwen2/knowledge_distillation_single_device.yaml +++ b/recipes/configs/qwen2/knowledge_distillation_single_device.yaml @@ -56,8 +56,8 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: - packed: False # Set to true for great speed ups _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed seed: null shuffle: True batch_size: 8 @@ -81,8 +81,8 @@ kd_ratio: 0.5 # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 2 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/qwen_kd @@ -95,4 +95,29 @@ log_peak_memory_stats: True # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/0_5B_full.yaml b/recipes/configs/qwen2_5/0_5B_full.yaml index 341c054991..93c94c666a 100644 --- a/recipes/configs/qwen2_5/0_5B_full.yaml +++ b/recipes/configs/qwen2_5/0_5B_full.yaml @@ -27,7 +27,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -56,14 +56,16 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -75,3 +77,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-0_5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/0_5B_full_single_device.yaml b/recipes/configs/qwen2_5/0_5B_full_single_device.yaml index 58059e06a9..707cbaa0f2 100644 --- a/recipes/configs/qwen2_5/0_5B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/0_5B_full_single_device.yaml @@ -29,7 +29,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -58,17 +58,18 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -optimizer_in_bwd: False +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 max_steps_per_epoch: null -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -80,3 +81,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-0_5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/0_5B_lora.yaml b/recipes/configs/qwen2_5/0_5B_lora.yaml index c6a4af1ee4..63ec87897c 100644 --- a/recipes/configs/qwen2_5/0_5B_lora.yaml +++ b/recipes/configs/qwen2_5/0_5B_lora.yaml @@ -20,11 +20,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_0_5b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -47,7 +46,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -70,8 +69,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-0_5B-Instruct-lora-finetune @@ -84,7 +83,8 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/0_5B_lora_single_device.yaml b/recipes/configs/qwen2_5/0_5B_lora_single_device.yaml index 2d9c089774..e11e34bcb7 100644 --- a/recipes/configs/qwen2_5/0_5B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/0_5B_lora_single_device.yaml @@ -19,11 +19,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_0_5b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -46,7 +45,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 4 @@ -68,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 4 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-0_5B-Instruct-lora-finetune @@ -84,8 +83,8 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/14B_lora_single_device.yaml b/recipes/configs/qwen2_5/14B_lora_single_device.yaml index d89710d1a6..002129641a 100644 --- a/recipes/configs/qwen2_5/14B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/14B_lora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_14b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -53,7 +53,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -74,8 +74,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-14B-Instruct-lora-finetune @@ -90,8 +90,8 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/1_5B_full.yaml b/recipes/configs/qwen2_5/1_5B_full.yaml index 9456200422..be01ab8670 100644 --- a/recipes/configs/qwen2_5/1_5B_full.yaml +++ b/recipes/configs/qwen2_5/1_5B_full.yaml @@ -27,7 +27,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -56,14 +56,16 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -75,3 +77,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-1_5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/1_5B_full_single_device.yaml b/recipes/configs/qwen2_5/1_5B_full_single_device.yaml index 6a78521c80..9d23055ab5 100644 --- a/recipes/configs/qwen2_5/1_5B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/1_5B_full_single_device.yaml @@ -29,7 +29,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -55,20 +55,21 @@ optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 2e-5 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -80,3 +81,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-1_5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/1_5B_lora.yaml b/recipes/configs/qwen2_5/1_5B_lora.yaml index 9e3cfad1b6..d47835d0b5 100644 --- a/recipes/configs/qwen2_5/1_5B_lora.yaml +++ b/recipes/configs/qwen2_5/1_5B_lora.yaml @@ -20,11 +20,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_1_5b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -47,7 +46,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -68,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-1_5B-Instruct-lora-finetune @@ -82,7 +81,8 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/1_5B_lora_single_device.yaml b/recipes/configs/qwen2_5/1_5B_lora_single_device.yaml index f35989fa4f..e9583ea62a 100644 --- a/recipes/configs/qwen2_5/1_5B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/1_5B_lora_single_device.yaml @@ -19,11 +19,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_1_5b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 32 - lora_alpha: 64 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 32 # higher increases accuracy and memory + lora_alpha: 64 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -46,7 +45,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -67,8 +66,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 8 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-1_5B-Instruct-lora-finetune @@ -83,8 +82,8 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/32B_lora.yaml b/recipes/configs/qwen2_5/32B_lora.yaml index 19a9356c27..28cda4f662 100644 --- a/recipes/configs/qwen2_5/32B_lora.yaml +++ b/recipes/configs/qwen2_5/32B_lora.yaml @@ -17,11 +17,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_32b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -60,7 +60,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -81,8 +81,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-32B-Instruct-lora-finetune @@ -95,7 +95,8 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/3B_full.yaml b/recipes/configs/qwen2_5/3B_full.yaml index 79343ca457..3fb2d23df0 100644 --- a/recipes/configs/qwen2_5/3B_full.yaml +++ b/recipes/configs/qwen2_5/3B_full.yaml @@ -27,7 +27,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -57,14 +57,16 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -76,3 +78,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-3B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/3B_full_single_device.yaml b/recipes/configs/qwen2_5/3B_full_single_device.yaml index 09494d6c28..a5b028c659 100644 --- a/recipes/configs/qwen2_5/3B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/3B_full_single_device.yaml @@ -29,7 +29,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -55,18 +55,19 @@ epochs: 1 optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 5e-6 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -78,3 +79,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-3B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/3B_lora.yaml b/recipes/configs/qwen2_5/3B_lora.yaml index b987330a6d..ffd3b6c494 100644 --- a/recipes/configs/qwen2_5/3B_lora.yaml +++ b/recipes/configs/qwen2_5/3B_lora.yaml @@ -20,11 +20,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_3b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -48,7 +47,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -69,8 +68,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-3B-Instruct-lora-finetune @@ -83,7 +82,8 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/3B_lora_single_device.yaml b/recipes/configs/qwen2_5/3B_lora_single_device.yaml index 8caf08d063..b6c5be1a0a 100644 --- a/recipes/configs/qwen2_5/3B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/3B_lora_single_device.yaml @@ -19,11 +19,10 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_3b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -47,7 +46,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -68,8 +67,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-3B-Instruct-lora-finetune @@ -84,8 +83,8 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/72B_lora.yaml b/recipes/configs/qwen2_5/72B_lora.yaml index 906e52dfde..99019e6c43 100644 --- a/recipes/configs/qwen2_5/72B_lora.yaml +++ b/recipes/configs/qwen2_5/72B_lora.yaml @@ -17,11 +17,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_72b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -80,7 +80,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -101,8 +101,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-72B-Instruct-lora-finetune @@ -115,7 +115,8 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/7B_full.yaml b/recipes/configs/qwen2_5/7B_full.yaml index 78313ca921..f6bab9f108 100644 --- a/recipes/configs/qwen2_5/7B_full.yaml +++ b/recipes/configs/qwen2_5/7B_full.yaml @@ -27,7 +27,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -59,14 +59,16 @@ optimizer: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory +optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1 # Training env device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -78,3 +80,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/7B_full_single_device.yaml b/recipes/configs/qwen2_5/7B_full_single_device.yaml index c4f464e97e..0986591e53 100644 --- a/recipes/configs/qwen2_5/7B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/7B_full_single_device.yaml @@ -29,7 +29,7 @@ tokenizer: # Dataset dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True @@ -57,18 +57,19 @@ epochs: 1 optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 5e-6 -optimizer_in_bwd: True +optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1 loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null -gradient_accumulation_steps: 1 -compile: False +gradient_accumulation_steps: 1 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Training environment device: cuda # Memory management -enable_activation_checkpointing: True +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Reduced precision dtype: bf16 @@ -80,3 +81,28 @@ metric_logger: output_dir: /tmp/Qwen2_5-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2_5/7B_lora.yaml b/recipes/configs/qwen2_5/7B_lora.yaml index 61365316be..b59ac69bcd 100644 --- a/recipes/configs/qwen2_5/7B_lora.yaml +++ b/recipes/configs/qwen2_5/7B_lora.yaml @@ -20,11 +20,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_7b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -50,7 +50,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -71,8 +71,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 32 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-7B-Instruct-lora-finetune @@ -85,7 +85,8 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: False +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2_5/7B_lora_single_device.yaml b/recipes/configs/qwen2_5/7B_lora_single_device.yaml index 53949bc307..a030c2fba2 100644 --- a/recipes/configs/qwen2_5/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2_5/7B_lora_single_device.yaml @@ -19,11 +19,11 @@ # Model Arguments model: _component_: torchtune.models.qwen2_5.lora_qwen2_5_7b_instruct - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank lora_dropout: 0.0 tokenizer: @@ -49,7 +49,7 @@ resume_from_checkpoint: False # Dataset and Sampler dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset - packed: False + packed: False # True increases speed seed: null shuffle: True batch_size: 2 @@ -70,8 +70,8 @@ loss: # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 64 -compile: False +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory # Logging output_dir: /tmp/Qwen2_5-7B-Instruct-lora-finetune @@ -86,8 +86,8 @@ device: cuda dtype: bf16 # Activations Offloading -enable_activation_checkpointing: True -enable_activation_offloading: False +enable_activation_checkpointing: True # True reduces memory +enable_activation_offloading: False # True reduces memory # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 1ce0db98d8..98d34b5f94 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -182,10 +182,14 @@ def __init__(self, cfg: DictConfig) -> None: raise RuntimeError( "enable_activation_offloading should only be True when enable_activation_checkpointing is True" ) - elif self._enable_activation_checkpointing: - log.info( + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." + "Enabling activation offloading should reduce memory further.", ) # These are public properties which are updated by the checkpoint loader @@ -641,8 +645,8 @@ def save_checkpoint( # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), self._is_rank_zero, device=self._device, ) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 4a44b233e8..0ab6ff3e63 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -172,10 +172,14 @@ def __init__(self, cfg: DictConfig) -> None: raise RuntimeError( "enable_activation_offloading should only be True when enable_activation_checkpointing is True" ) - elif self._enable_activation_checkpointing: - log.info( + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." + "Enabling activation offloading should reduce memory further.", ) # These are public properties which are updated by the checkpoint loader @@ -257,6 +261,10 @@ def setup(self, cfg: DictConfig) -> None: # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model self._compile = cfg.compile + if cfg.device == "npu" and cfg.compile: + raise ValueError( + "NPU does not support model compilation. Please set `compile: False` in the config." + ) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=self._enable_activation_checkpointing, @@ -435,7 +443,7 @@ def _setup_model( log.info(f"Model is initialized with precision {self._dtype}.") - if self._device.type == "cuda": + if self._device.type != "cpu": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) @@ -734,7 +742,7 @@ def train(self) -> None: ), "tokens_per_second_per_gpu": num_tokens / time_per_step, } - if self._device.type == "cuda" and self._log_peak_memory_stats: + if self._device.type != "cpu" and self._log_peak_memory_stats: log_dict.update( training.get_memory_stats(device=self._device) ) diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index b40be9e89e..e5f1047923 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -25,6 +25,7 @@ from torchtune.modules.peft import ( DoRALinear, get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None: intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), self._is_rank_zero, device=self._device, ) @@ -728,10 +729,7 @@ def save_checkpoint(self, epoch: int) -> None: # Filter out the adapter keys and weights from the model state dict. These will # be saved separately - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) - } + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # merge the adapter weights and base weights to create the model checkpoint diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py index 1a2c3f0e4b..4521f42da3 100644 --- a/recipes/knowledge_distillation_single_device.py +++ b/recipes/knowledge_distillation_single_device.py @@ -23,6 +23,7 @@ from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None: ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) # Construct the adapter weights - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) - } + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) adapter_config = { "r": self._lora_rank, diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 1ab88deaf8..ee7ca5e729 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -25,6 +25,7 @@ disable_adapter, DoRALinear, get_adapter_params, + get_adapter_state_dict, get_merged_lora_ckpt, load_dora_magnitudes, LoRALinear, @@ -504,8 +505,12 @@ def save_checkpoint( intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, self._is_rank_zero, device=self._device, ) @@ -521,23 +526,21 @@ def save_checkpoint( # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: - - # Filter out the adapter keys and weights from the model state dict. These will - # be saved separately - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) - } - checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) - - # merge the adapter weights and base weights to create the model checkpoint - if not self._save_adapter_weights_only: + if self._save_adapter_weights_only: + adapter_state_dict = cpu_state_dict + else: + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) + + # merge the adapter weights and base weights to create the model checkpoint merged_state_dict = get_merged_lora_ckpt( cpu_state_dict, rank=self._lora_rank, alpha=self._lora_alpha, ) checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe state # as well. diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index f34694ccc8..26e6a236bc 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -23,6 +23,7 @@ from torchtune.modules.peft import ( disable_adapter, get_adapter_params, + get_adapter_state_dict, get_merged_lora_ckpt, set_trainable_params, validate_missing_and_unexpected_for_lora, @@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None: } ) - adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) if not self._save_adapter_weights_only: # Construct the full state dict with LoRA weights merged into base LLM weights diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 418c823344..a900cea103 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -26,6 +26,7 @@ from torchtune.modules.peft import ( DoRALinear, get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -181,10 +182,14 @@ def __init__(self, cfg: DictConfig) -> None: raise RuntimeError( "enable_activation_offloading should only be True when enable_activation_checkpointing is True" ) - elif self._enable_activation_checkpointing: - log.info( + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." + "Enabling activation offloading should reduce memory further.", ) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -448,8 +453,7 @@ def _setup_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(cfg_model) - self.adapter_params = get_adapter_params(model) - set_trainable_params(model, self.adapter_params) + set_trainable_params(model, get_adapter_params(model)) if self._compile: training.compile_model(model, verbose=self._is_rank_zero) @@ -660,11 +664,14 @@ def save_checkpoint( # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, self._is_rank_zero, device=self._device, - trainable_only=self._save_adapter_weights_only, ) if self._is_rank_zero: log.info( @@ -690,22 +697,22 @@ def save_checkpoint( # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: start = time.perf_counter() - # Filter out the adapter keys and weights from the model state dict. These will - # be saved separately - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) - } - checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) - # merge the adapter weights and base weights to create the model checkpoint - if not self._save_adapter_weights_only: + if self._save_adapter_weights_only: + adapter_state_dict = cpu_state_dict + else: + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) + + # merge the adapter weights and base weights to create the model checkpoint merged_state_dict = get_merged_lora_ckpt( cpu_state_dict, rank=self._lora_rank, alpha=self._lora_alpha, ) checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe state # as well. diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index c28b1ebb37..fcdb3e4ea5 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -24,6 +24,7 @@ from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -32,6 +33,7 @@ ) from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY + from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -169,10 +171,14 @@ def __init__(self, cfg: DictConfig) -> None: raise RuntimeError( "enable_activation_offloading should only be True when enable_activation_checkpointing is True" ) - elif self._enable_activation_checkpointing: - log.info( + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " - "Enabling activation offloading should reduce memory further." + "Enabling activation offloading should reduce memory further.", ) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -248,6 +254,10 @@ def setup(self, cfg: DictConfig) -> None: self._metric_logger.log_config(cfg) self._compile = cfg.compile + if cfg.device == "npu" and cfg.compile: + raise ValueError( + "NPU does not support model compilation. Please set `compile: False` in the config." + ) checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) # hack to toggle to the low cpu ram version of the reparametrize_as_dtype @@ -470,7 +480,7 @@ def _setup_model( log.info(f"Model is initialized with precision {self._dtype}.") - if self._device.type == "cuda": + if self._device.type != "cpu": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) return model @@ -583,21 +593,7 @@ def save_checkpoint(self, epoch: int) -> None: } ) - # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Testing remove this !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - adapter_state_dict = {} - for k, v in self._model.named_modules(): - if hasattr(v, "adapter_params") and callable(v.adapter_params): - import pdb - - pdb.set_trace() - adapter_params = v.adapter_params() - for n, p in v.state_dict().items(): - if any(n.endswith(param) for param in adapter_params): - full_key = f"{k}.{n}" - adapter_state_dict[n] = p.cpu() - - # adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} - # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! End Testing !!!!!!!! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) if not self._save_adapter_weights_only: @@ -636,7 +632,6 @@ def save_checkpoint(self, epoch: int) -> None: def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - # run model with self.activations_handling_ctx: logits = self._model(**batch) diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index 4126f95bd5..1aa622ba63 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time @@ -21,11 +20,13 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY from torchtune.training.activations import apply_selective_activation_checkpointing +from torchtune.training.lr_schedulers import get_lr from tqdm import tqdm @@ -50,18 +51,30 @@ class QATRecipeDistributed(FTRecipeInterface): to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``. - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states - is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). DDP is currently not supported. Training on CPU is not supported. - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5 or later and will + be enabled by default if an acceptable torch version is found. Activation offloading can be + used in conjunction with activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -93,6 +106,10 @@ class QATRecipeDistributed(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -102,6 +119,9 @@ class QATRecipeDistributed(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. """ def __init__(self, cfg: DictConfig) -> None: @@ -141,12 +161,50 @@ def __init__(self, cfg: DictConfig) -> None: # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[ - cfg.get("fsdp_sharding_strategy", "FULL_SHARD") - ] + self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False) + self._clip_grad_norm = cfg.get("clip_grad_norm", None) self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None) self._quantizer_mode = None + # Optimizer in backward is not compatible with gradient accumulation or gradient clipping + if self._optimizer_in_bwd: + if self._clip_grad_norm is not None: + raise RuntimeError( + "Gradient clipping is not supported with optimizer in bwd." + "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further.", + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -223,10 +281,11 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - self._model_compile = cfg.get("compile", False) + self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -239,6 +298,7 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, opt_state_dict=( checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint @@ -248,30 +308,25 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - # For CEWithChunkedOutputLoss, if we compile the entire class - # we lose the benefits from the chunked loss. - # Therefore, we only compile the cross entropy function + upcasting - self._loss_fn.compute_cross_entropy = torch.compile( - self._loss_fn.compute_cross_entropy, backend=backend - ) - else: - if self._model_compile: - log.info("Compiling loss with torch.compile...") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) - log.info("Loss is initialized.") + + if self._is_rank_zero: + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._sampler, self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, batch_size=cfg.batch_size, + collate_fn=collate_name, ) # Finally update the recipe state which can only be correctly set after all of the @@ -371,6 +426,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, fsdp_cpu_offload: bool, reshard_after_forward: bool, model_state_dict: Dict[str, Any], @@ -396,6 +452,9 @@ def _setup_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(cfg_model) + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls # the older version of AC and this behavior is unchanged @@ -451,7 +510,17 @@ def _setup_model( # This method will convert the full model state dict into a sharded state # dict and load into the model training.load_from_full_model_state_dict( - model, model_state_dict, self._device, self._is_rank_zero, strict=True + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading ) # Ensure no params and buffers are on meta device @@ -470,25 +539,64 @@ def _setup_model( return model def _setup_optimizer( - self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None - ) -> Optimizer: - optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - training.load_from_full_optimizer_state_dict( - optimizer, - opt_state_dict, - self._device, + self, + cfg_optimizer: DictConfig, + optimizer_in_bwd: bool = False, + opt_state_dict: Optional[Dict[str, Any]] = None, + ) -> Optional[Optimizer]: + if optimizer_in_bwd: + # Maintain a dict of optims for every parameter. + optim_dict = { + param: config.instantiate(cfg_optimizer, [param]) + for param in self._model.parameters() + } + + # Register optimizer step hooks on the model to run optimizer in backward. + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict ) + # Create a wrapper for checkpoint save/load of optimizer states when running in backward. + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + model=self._model, optim_dict=optim_dict + ) + # Load optimizer states for each param. If optimizer states are being restored in an optimizer in + # backward run, these need to have been saved with the same setting. Cannot restore from runs that + # did not use optimizer in backward. + if opt_state_dict is not None: + for param in opt_state_dict.keys(): + try: + training.load_from_full_optimizer_state_dict( + self._optim_ckpt_wrapper.state_dict()[param], + opt_state_dict[param], + self._device, + ) + except BaseException as e: + raise RuntimeError( + "Failed loading in-backward optimizer checkpoints." + "Please make sure run being restored from was using in-backward optimizer." + ) from e + if self._is_rank_zero: + log.info("In-backward optimizers are set up.") + return None + else: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) - if self._is_rank_zero: - log.info("Optimizer is initialized.") - return optimizer + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer def _setup_data( self, cfg_dataset: DictConfig, shuffle: bool, batch_size: int, + collate_fn: str, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup happens here. Currently this recipe only supports the @@ -499,15 +607,20 @@ def _setup_data( if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + sampler = DistributedSampler( ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 ) @@ -519,14 +632,12 @@ def _setup_data( drop_last=True, collate_fn=( partial( - padded_collate_sft, + collate_fn, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ) + else padded_collate_packed ), ) @@ -553,25 +664,54 @@ def save_checkpoint( checkpoint_dict = {} intermediate_checkpoint = epoch + 1 < self.total_epochs + + if self._is_rank_zero: + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." + ) + start = time.perf_counter() + # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), self._is_rank_zero, + device=self._device, ) - if intermediate_checkpoint: - opt_state_dict = training.get_full_optimizer_state_dict( - self._optimizer, - self._is_rank_zero, + if self._is_rank_zero: + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" ) + + if intermediate_checkpoint: + start = time.perf_counter() + if self._is_rank_zero: + log.info("Getting optimizer state dict...") + if not self._optimizer_in_bwd: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = {} + for param, opt in self._optim_ckpt_wrapper.optim_map.items(): + opt_state_dict[param] = training.get_full_optimizer_state_dict( + opt, self._is_rank_zero, device=self._device + ) + if self._is_rank_zero: + log.info( + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" + ) else: opt_state_dict = None # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file - if self._is_rank_zero: + if self._is_rank_zero: + start = time.perf_counter() checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe state @@ -592,6 +732,9 @@ def save_checkpoint( epoch=epoch, intermediate_checkpoint=intermediate_checkpoint, ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() def train(self) -> None: """ @@ -599,10 +742,15 @@ def train(self) -> None: """ # clean up before training begins training.cleanup_before_training() + world_size, rank = training.get_world_size_and_rank() # zero out the gradients before starting training - self._optimizer.zero_grad() + if not self._optimizer_in_bwd: + self._optimizer.zero_grad() + else: + for opt in self._optim_ckpt_wrapper.optim_map.values(): + opt.zero_grad() # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() @@ -612,7 +760,6 @@ def train(self) -> None: self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): - # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True self._sampler.set_epoch(curr_epoch) @@ -635,13 +782,6 @@ def train(self) -> None: ): torch.cuda.memory._record_memory_history() - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - # Optionally wait N steps before enabling fake quant if self._fake_quant_after_n_steps is not None: if self.global_step == 0: @@ -663,20 +803,20 @@ def train(self) -> None: ) self._model.apply(enable_fq) - tokens = tokens.to(self._device) + utils.batch_to_device(batch, self._device) # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step - - utils.batch_to_device(batch, self._device) - current_num_tokens = ( batch["labels"] != self._loss_fn.ignore_index ).sum() num_tokens += current_num_tokens + + # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") - logits = self._model(**batch) + with self.activations_handling_ctx: + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] @@ -689,25 +829,40 @@ def train(self) -> None: logits = logits.reshape(-1, logits.size(-1)) # Compute loss + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients current_loss = self._loss_fn(logits, labels) * current_num_tokens # free logits otherwise it peaks backward memory del logits running_loss += current_loss - current_loss.backward() - # Step with optimizer - if (idx + 1) % self._gradient_accumulation_steps == 0: - # Get total number of tokens across all ranks to normalize gradients + # For optimizer in backward, we need to normalize before calling backward + # This case and gradient accumulation are mutually exclusive + if self._optimizer_in_bwd: torch.distributed.all_reduce(num_tokens) - # This will ensure that the logged loss matches what we're optimizing torch.distributed.all_reduce(running_loss) - # Manually scale the gradients from unnormalized loss by total # of tokens - training.scale_grads(self._model, 1 / num_tokens) + current_loss = current_loss / num_tokens + + current_loss.backward() - self._optimizer.step() - self._optimizer.zero_grad(set_to_none=True) + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if not self._optimizer_in_bwd: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + # Manually scale the gradients from unnormalized loss by total # of tokens + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) # Update the number of steps when the weights are updated self.global_step += 1 @@ -726,15 +881,22 @@ def train(self) -> None: time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, - "lr": self._optimizer.param_groups[0]["lr"], - "tokens_per_second_per_gpu": ( - num_tokens / time_per_step * world_size + "lr": get_lr( + ( + self._optimizer + if not self._optimizer_in_bwd + else self._optim_ckpt_wrapper + ), ), + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), } if self._log_peak_memory_stats: log_dict.update( training.get_memory_stats(device=self._device) ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -784,7 +946,7 @@ def recipe_main(cfg: DictConfig) -> None: """ if not training.is_distributed(): raise RuntimeError( - "Distributed QAT recipe should be run via a distributed launcher." + "Distributed finetune recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index f1f4256411..9c8d0eacd5 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import runpy - import sys from pathlib import Path @@ -113,3 +113,89 @@ def test_loss( torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + [ + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ], + ) + @gpu_test(gpu_count=2) + def test_training_state_on_resume( + self, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + optim_in_bwd, + tmpdir, + monkeypatch, + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + clip_grad_norm=100 \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type] + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{tmpdir}' \ + checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + clip_grad_norm=100 \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = self._fetch_expected_loss_values(model_type)[2:] + + loss_values = get_loss_values_from_metric_logger(log_file) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 819c70fdf0..85df960b22 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -181,7 +181,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\ - checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\ checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA2 \ tokenizer.path=/tmp/test-artifacts/tokenizer.model \ diff --git a/tests/recipes/test_lora_dpo_single_device.py b/tests/recipes/test_lora_dpo_single_device.py index 703ac2e471..53770dafc0 100644 --- a/tests/recipes/test_lora_dpo_single_device.py +++ b/tests/recipes/test_lora_dpo_single_device.py @@ -73,6 +73,8 @@ def test_training_state_on_resume( tune run lora_dpo_single_device \ --config llama2/7B_lora_dpo_single_device \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -102,6 +104,8 @@ def test_training_state_on_resume( tune run lora_dpo_single_device \ --config llama2/7B_lora_dpo_single_device \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -138,6 +142,8 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): tune run lora_dpo_single_device \ --config llama2/7B_lora_dpo_single_device \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ diff --git a/tests/recipes/test_lora_finetune_distributed.py b/tests/recipes/test_lora_finetune_distributed.py index 4943e1559b..268cadfaad 100644 --- a/tests/recipes/test_lora_finetune_distributed.py +++ b/tests/recipes/test_lora_finetune_distributed.py @@ -75,6 +75,8 @@ def test_loss( batch_size={micro_batch_size} \ gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -146,6 +148,8 @@ def test_training_state_on_resume( batch_size=4 \ gradient_accumulation_steps=1 \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -171,6 +175,8 @@ def test_training_state_on_resume( batch_size=4 \ gradient_accumulation_steps=1 \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir={tmpdir} \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -220,6 +226,8 @@ def test_save_and_load_merged_weights( batch_size=4 \ gradient_accumulation_steps=1 \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ model=torchtune.models.lora_small_test_model \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py index d2521e4821..caaf5f5b43 100644 --- a/tests/recipes/test_lora_finetune_single_device.py +++ b/tests/recipes/test_lora_finetune_single_device.py @@ -88,6 +88,8 @@ def test_loss( batch_size={micro_batch_size} \ gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}] \ @@ -146,6 +148,8 @@ def test_loss_qlora( batch_size={micro_batch_size} \ gradient_accumulation_steps={gradient_accumulation_steps} \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=True \ checkpointer=torchtune.training.FullModelMetaCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -206,6 +210,8 @@ def test_training_state_on_resume( batch_size=8 \ gradient_accumulation_steps=1 \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=True \ checkpointer=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ @@ -270,6 +276,8 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch, use_dora): tune run lora_finetune_single_device \ --config llama2/7B_lora_single_device \ output_dir={tmpdir} \ + model.lora_attn_modules=['q_proj','v_proj','k_proj','output_proj'] \ + model.apply_lora_to_mlp=True \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ diff --git a/tests/regression_tests/test_llama2_7b.py b/tests/regression_tests/test_llama2_7b.py index cba0a39032..cedbc39a31 100644 --- a/tests/regression_tests/test_llama2_7b.py +++ b/tests/regression_tests/test_llama2_7b.py @@ -36,6 +36,8 @@ def test_finetune_and_eval(self, tmpdir, capsys, monkeypatch): ft_cmd = f""" tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_lora \ + model.lora_attn_modules=['q_proj','v_proj'] \ + model.apply_lora_to_mlp=False \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py index 328d1c528f..0c6c3e938a 100644 --- a/tests/torchtune/models/llama2/scripts/compare_fused_attention.py +++ b/tests/torchtune/models/llama2/scripts/compare_fused_attention.py @@ -256,7 +256,6 @@ def compare_attn( max_seq_len: int, use_kv_cache: bool, ): - torch.manual_seed(16) inputs = torch.randn(4, 2048, 4096) @@ -269,8 +268,9 @@ def compare_attn( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - n_kv_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, + dtype=inputs.dtype, ) else: kv_cache = None @@ -330,7 +330,6 @@ def compare_attn( if __name__ == "__main__": - # compare mha mha = { "num_heads": 32, diff --git a/tests/torchtune/models/llama2/scripts/compare_lora_attention.py b/tests/torchtune/models/llama2/scripts/compare_lora_attention.py index c6073297da..fb70c5b464 100644 --- a/tests/torchtune/models/llama2/scripts/compare_lora_attention.py +++ b/tests/torchtune/models/llama2/scripts/compare_lora_attention.py @@ -33,7 +33,6 @@ def compare_lora_attention( lora_rank: int, lora_alpha: float, ) -> None: - # make sure we have the right seed for generating outputs # this should match up the seed value set in the corresponding # unit test @@ -68,8 +67,9 @@ def compare_lora_attention( KVCache( batch_size=batch_size, max_seq_len=max_seq_len, - n_kv_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, + dtype=x.dtype, ) if batch_size is not None else None diff --git a/tests/torchtune/modules/model_fusion/test_fusion_models.py b/tests/torchtune/modules/model_fusion/test_deep_fusion.py similarity index 90% rename from tests/torchtune/modules/model_fusion/test_fusion_models.py rename to tests/torchtune/modules/model_fusion/test_deep_fusion.py index 322616276e..79b2f9ab3d 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_models.py +++ b/tests/torchtune/modules/model_fusion/test_deep_fusion.py @@ -22,7 +22,7 @@ class DummyModel(nn.Module): def __init__(self, dim, vocab_size): super().__init__() self.cache_enabled = False - self.embed = nn.Embedding(vocab_size, dim) + self.tok_embeddings = nn.Embedding(vocab_size, dim) self.q = nn.Linear(dim, dim) self.k = nn.Linear(dim, dim) self.v = nn.Linear(dim, dim) @@ -38,14 +38,22 @@ def caches_are_setup(self): def reset_caches(self): self.cache_enabled = False - def forward(self, tokens, mask, encoder_input, encoder_mask, input_pos): - x = self.embed(tokens) + def forward( + self, + tokens, + *, + mask=None, + encoder_input=None, + encoder_mask=None, + input_pos=None, + ): + x = self.tok_embeddings(tokens) if encoder_input is not None: q = self.q(x) - k = self.k(encoder_input) - v = self.v(encoder_input) + k = self.k(encoder_input) if encoder_input is not None else self.k(x) + v = self.v(encoder_input) if encoder_input is not None else self.v(x) x += nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=encoder_mask + q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask ) x = self.output(x) return x @@ -85,7 +93,7 @@ def fused_model(self, encoder, decoder) -> DeepFusionModel: return model @pytest.fixture - def inputs(self, dim, vocab_size): + def inputs(self, vocab_size): batch_size = 2 seq_len = 10 tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) @@ -183,5 +191,5 @@ def test_set_trainable_params(self, fused_model, encoder, decoder): "decoder.k.bias", "decoder.v.weight", "decoder.v.bias", - "decoder.embed.weight", + "decoder.tok_embeddings.weight", } diff --git a/tests/torchtune/modules/model_fusion/test_early_fusion.py b/tests/torchtune/modules/model_fusion/test_early_fusion.py new file mode 100644 index 0000000000..d7ff407289 --- /dev/null +++ b/tests/torchtune/modules/model_fusion/test_early_fusion.py @@ -0,0 +1,336 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict + +import pytest + +import torch +from tests.test_utils import assert_expected, fixed_init_model +from torch import nn +from torchtune.modules.model_fusion import EarlyFusionModel, register_fusion_module +from torchtune.training.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(1) + + +class DummyModel(nn.Module): + def __init__(self, dim, vocab_size): + super().__init__() + self.cache_enabled = False + self.tok_embeddings = nn.Embedding(vocab_size, dim) + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.output = nn.Linear(dim, vocab_size) + register_fusion_module(self.output) + + def setup_caches(self, batch_size, dtype, *args, **kwargs): + self.cache_enabled = True + + def caches_are_setup(self): + return self.cache_enabled + + def reset_caches(self): + self.cache_enabled = False + + def forward( + self, + tokens, + *, + mask=None, + encoder_input=None, + encoder_mask=None, + input_pos=None, + ): + x = self.tok_embeddings(tokens) + if encoder_input is not None: + q = self.q(x) + k = self.k(encoder_input) if encoder_input is not None else self.k(x) + v = self.v(encoder_input) if encoder_input is not None else self.v(x) + x += nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=encoder_mask if encoder_mask is not None else mask + ) + x = self.output(x) + return x + + +class TestEarlyFusionModel: + @pytest.fixture + def vocab_size(self) -> int: + return 100 + + @pytest.fixture + def dim(self) -> int: + return 64 + + @pytest.fixture + def batch_size(self) -> int: + return 2 + + @pytest.fixture + def seq_len(self) -> int: + return 10 + + @pytest.fixture + def decoder(self, dim, vocab_size) -> nn.Module: + decoder = DummyModel(dim, vocab_size) + fixed_init_model(decoder, max_val=0.1) + return decoder + + @pytest.fixture + def fused_model(self, vocab_size, dim, decoder) -> EarlyFusionModel: + red = nn.Embedding(vocab_size, dim) + fixed_init_model(red) + green = nn.Embedding(vocab_size, dim) + fixed_init_model(green) + blue = nn.Embedding(vocab_size, dim) + fixed_init_model(blue) + + model = EarlyFusionModel( + encoders={"red": red, "green": green, "blue": blue}, + decoder=decoder, + # These are IDs that are out of vocab in the decoder + encoder_tokens={ + "red": vocab_size, + "green": vocab_size + 1, + "blue": vocab_size + 2, + }, + decoder_trainable=True, + encoders_trainable={"red": False, "green": True, "blue": False}, + fusion_trainable=False, + ) + return model + + @pytest.fixture + def inputs(self, batch_size, seq_len, vocab_size): + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 + tokens[:, 0] = vocab_size + tokens[:, 3:5] = vocab_size + 1 + tokens[:, 7:] = vocab_size + 2 + encoder_input = { + "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, + "green": { + "input": torch.randint(0, vocab_size, (batch_size, green_seq_len)) + }, + "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, + } + encoder_mask = torch.randint(0, 2, (batch_size, seq_len, seq_len)).bool() + input_pos = torch.Tensor([1]).int() + return tokens, encoder_input, encoder_mask, input_pos + + @pytest.fixture + def state_dict(self, dim, vocab_size): + return OrderedDict( + { + "decoder.q.weight": torch.randn((dim, dim)), + "decoder.q.bias": torch.randn((dim,)), + "decoder.k.weight": torch.randn((dim, dim)), + "decoder.k.bias": torch.randn((dim,)), + "decoder.v.weight": torch.randn((dim, dim)), + "decoder.v.bias": torch.randn((dim,)), + "decoder.output.weight": torch.randn((vocab_size, dim)), + "decoder.output.bias": torch.randn((vocab_size,)), + "decoder.tok_embeddings.weight": torch.randn((vocab_size, dim)), + "encoders.red.weight": torch.randn((vocab_size, dim)), + "encoders.green.weight": torch.randn((vocab_size, dim)), + "encoders.blue.weight": torch.randn((vocab_size, dim)), + } + ) + + @torch.no_grad() + def test_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel works as expected. + """ + tokens, encoder_input, *_ = inputs + batch_size, seq_len = tokens.shape + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.5647), atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_decoder(self, fused_model, inputs, dim): + """ + Test that the forward pass of the EarlyFusionModel works as expected. + """ + tokens, encoder_input, *_ = inputs + batch_size, seq_len = tokens.shape + + # No-op for the decoder + class DummyModule(nn.Module): + def forward(self, x, **kwargs): + return x + + fused_model.decoder = DummyModule() + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, dim) + # Check that each encoder output is placed correctly in the fused output + red = fused_model.encoders["red"](**encoder_input["red"]) + assert_expected(out[:, :1, :], red, atol=1e-3, rtol=1e-3) + green = fused_model.encoders["green"](**encoder_input["green"]) + assert_expected(out[:, 3:5, :], green, atol=1e-3, rtol=1e-3) + blue = fused_model.encoders["blue"](**encoder_input["blue"]) + assert_expected(out[:, 7:, :], blue, atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_encoder(self, fused_model, batch_size, seq_len, vocab_size): + """ + Test the forward pass of the EarlyFusionModel with no encoder input. + """ + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + + actual = fused_model(tokens) + expected = fused_model.decoder(fused_model.tok_embeddings(tokens)) + + assert_expected(actual, expected, atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_forward_no_decoder_uneven_encoder_tokens( + self, fused_model, dim, batch_size, seq_len, vocab_size + ): + """ + If each sample has a different number of encoder tokens in the sequence, test that mask scatter + of embeds still works as expected: + + This is a dog. + My dog is better than yours. + """ + red_seq_len, green_seq_len, blue_seq_len = 1, 2, 3 + # In a real encoder input, it would be padded to max number of media in the batch, so we don't + # make these test inputs uneven. The forward pass should still be able to take the number of embeddings + # it needs and ignore the rest, which would be pad embeddings. + encoder_input = { + "red": {"input": torch.randint(0, vocab_size, (batch_size, red_seq_len))}, + "green": { + "input": torch.randint(0, vocab_size, (batch_size, green_seq_len)) + }, + "blue": {"input": torch.randint(0, vocab_size, (batch_size, blue_seq_len))}, + } + tokens = torch.randint(0, vocab_size, (batch_size, seq_len)) + # For red encoder, only the first sample has a token + tokens[0, 0] = vocab_size + # For green encoder, first sample has 2 tokens, second sample has 1 token + tokens[0, 3:5] = vocab_size + 1 + tokens[1, 4] = vocab_size + 1 + # For blue encoder, first sample has 3 tokens, second sample has 2 tokens + tokens[0, 7:] = vocab_size + 2 + tokens[1, 8:] = vocab_size + 2 + + # No-op for the decoder + class DummyModule(nn.Module): + def forward(self, x, **kwargs): + return x + + fused_model.decoder = DummyModule() + + out = fused_model( + tokens, + encoder_input=encoder_input, + ) + + assert out.shape == (batch_size, seq_len, dim) + # Check that each encoder output is placed correctly in the fused output + red = fused_model.encoders["red"](**encoder_input["red"]) + assert_expected(out[0, 0, :], red[0, 0, :], atol=1e-3, rtol=1e-3) + green = fused_model.encoders["green"](**encoder_input["green"]) + assert_expected(out[0, 3:5, :], green[0, :, :], atol=1e-3, rtol=1e-3) + assert_expected(out[1, 4, :], green[1, 0, :], atol=1e-3, rtol=1e-3) + blue = fused_model.encoders["blue"](**encoder_input["blue"]) + assert_expected(out[0, 7:, :], blue[0, :, :], atol=1e-3, rtol=1e-3) + assert_expected(out[1, 8:, :], blue[1, :2, :], atol=1e-3, rtol=1e-3) + + @torch.no_grad() + def test_decoder_forward(self, fused_model, inputs, vocab_size): + """ + Test that the forward pass of the EarlyFusionModel works during decoding. + """ + tokens, encoder_input, encoder_mask, input_pos = inputs + tokens = tokens[:, input_pos] + encoder_mask = encoder_mask[:, input_pos] + batch_size, seq_len = tokens.shape + out = fused_model( + tokens, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + assert out.shape == (batch_size, seq_len, vocab_size) + assert_expected(out.mean(), torch.tensor(0.2383), atol=1e-3, rtol=1e-3) + + def test_setup_cache(self, fused_model): + """ + Test that the cache methods works as expected. + """ + fused_model.setup_caches(2, torch.float32) + assert fused_model.caches_are_setup() + fused_model.reset_caches() + assert not fused_model.caches_are_setup() + + def test_set_trainable_params(self, fused_model): + """ + Test that the trainable parameters are set correctly. + """ + trainable_params = { + n for n, p in fused_model.named_parameters() if p.requires_grad + } + assert trainable_params == { + "decoder.q.weight", + "decoder.q.bias", + "decoder.k.weight", + "decoder.k.bias", + "decoder.v.weight", + "decoder.v.bias", + "tok_embeddings.weight", + "encoders.green.weight", + } + + def test_mismatched_encoder_tokens(self, decoder): + with pytest.raises(ValueError): + _ = EarlyFusionModel( + encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + encoder_tokens={"encoder": 0, "encoder3": 1}, + encoders_trainable=False, + ) + + def test_mismatched_encoder_trainable(self, decoder): + with pytest.raises(ValueError): + _ = EarlyFusionModel( + encoders={"encoder": nn.Identity(), "encoder2": nn.Identity()}, + decoder=decoder, + encoder_tokens={"encoder": 0, "encoder2": 1}, + encoders_trainable={"encoder": True, "encoder3": False}, + ) + + def test_mismatched_encoder_input(self, fused_model, inputs): + tokens, _, _, _ = inputs + with pytest.raises(ValueError): + _ = fused_model( + tokens, + encoder_input={"encoder": {"input": torch.tensor([1])}}, + ) + + def test_state_dict_hooks(self, fused_model, state_dict): + fused_model.load_state_dict(state_dict) + actual = fused_model.state_dict() + expected = state_dict + assert_expected(actual, expected) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_embed.py b/tests/torchtune/modules/model_fusion/test_fusion_embed.py deleted file mode 100644 index 35ef5c0e87..0000000000 --- a/tests/torchtune/modules/model_fusion/test_fusion_embed.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pytest - -import torch -from tests.test_utils import assert_expected, fixed_init_model -from torchtune.modules.model_fusion import FusionEmbedding -from torchtune.training.seed import set_seed - - -@pytest.fixture(autouse=True) -def random(): - set_seed(1) - - -class TestFusionEmbedding: - """ - Class for testing our FusionEmbedding. - """ - - @pytest.fixture - def dim(self) -> int: - return 2 - - @pytest.fixture - def vocab_size(self) -> int: - return 10 - - @pytest.fixture - def fusion_vocab_size(self) -> int: - return 5 - - @pytest.fixture - def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding: - embeds = FusionEmbedding( - vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim - ) - fixed_init_model(embeds.embedding, min_val=0, max_val=0.5) - fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1) - return embeds - - @torch.no_grad() - def test_forward(self, embed, vocab_size, fusion_vocab_size, dim): - """ - Test that the forward pass of the FusionEmbedding works as expected. - """ - tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10)) - out = embed(tokens) - - assert out.shape == (2, 10, dim) - assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3) - - # Only new tokens, embeddings should be > 0.5 - tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10)) - out = embed(tokens) - - assert out.min() > 0.5 - - # Only old tokens, embeddings should be < 0.5 - tokens = torch.randint(0, vocab_size, (2, 10)) - out = embed(tokens) - - assert out.max() < 0.5 - - def test_fusion_params(self, embed): - """ - Test that the currect fusion params are returned. - """ - fusion_params = set(embed.fusion_params()) - - assert fusion_params == {"fusion_embedding.weight"} - - def test_get_and_load_state_dict(self, embed): - """ - Test that the state dict hooks work in removing the "layer" variable - """ - state_dict = embed.state_dict() - state_keys = set(state_dict.keys()) - - assert state_keys == { - "weight", - "fusion_embedding.weight", - } - - # Check that the state_dict can be loaded back into the model - embed.load_state_dict(state_dict) diff --git a/tests/torchtune/modules/model_fusion/test_fusion_layer.py b/tests/torchtune/modules/model_fusion/test_fusion_layers.py similarity index 67% rename from tests/torchtune/modules/model_fusion/test_fusion_layer.py rename to tests/torchtune/modules/model_fusion/test_fusion_layers.py index a2fc0715eb..da8fdb4b1f 100644 --- a/tests/torchtune/modules/model_fusion/test_fusion_layer.py +++ b/tests/torchtune/modules/model_fusion/test_fusion_layers.py @@ -9,7 +9,7 @@ import torch from tests.test_utils import assert_expected, fixed_init_model from torch import nn -from torchtune.modules.model_fusion import FusionLayer +from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer from torchtune.training.seed import set_seed @@ -60,6 +60,79 @@ def forward(self, x): return self.linear(x) +class TestFusionEmbedding: + """ + Class for testing our FusionEmbedding. + """ + + @pytest.fixture + def dim(self) -> int: + return 2 + + @pytest.fixture + def vocab_size(self) -> int: + return 10 + + @pytest.fixture + def fusion_vocab_size(self) -> int: + return 5 + + @pytest.fixture + def embed(self, dim, vocab_size, fusion_vocab_size) -> FusionEmbedding: + embeds = FusionEmbedding( + vocab_size=vocab_size, fusion_vocab_size=fusion_vocab_size, embed_dim=dim + ) + fixed_init_model(embeds.embedding, min_val=0, max_val=0.5) + fixed_init_model(embeds.fusion_embedding, min_val=0.51, max_val=1) + return embeds + + @torch.no_grad() + def test_forward(self, embed, vocab_size, fusion_vocab_size, dim): + """ + Test that the forward pass of the FusionEmbedding works as expected. + """ + tokens = torch.randint(0, vocab_size + fusion_vocab_size, (2, 10)) + out = embed(tokens) + + assert out.shape == (2, 10, dim) + assert_expected(out.mean(), torch.tensor(0.3409), atol=1e-3, rtol=1e-3) + + # Only new tokens, embeddings should be > 0.5 + tokens = torch.randint(vocab_size, vocab_size + fusion_vocab_size, (2, 10)) + out = embed(tokens) + + assert out.min() > 0.5 + + # Only old tokens, embeddings should be < 0.5 + tokens = torch.randint(0, vocab_size, (2, 10)) + out = embed(tokens) + + assert out.max() < 0.5 + + def test_fusion_params(self, embed): + """ + Test that the currect fusion params are returned. + """ + fusion_params = set(embed.fusion_params()) + + assert fusion_params == {"fusion_embedding.weight"} + + def test_get_and_load_state_dict(self, embed): + """ + Test that the state dict hooks work in removing the "layer" variable + """ + state_dict = embed.state_dict() + state_keys = set(state_dict.keys()) + + assert state_keys == { + "weight", + "fusion_embedding.weight", + } + + # Check that the state_dict can be loaded back into the model + embed.load_state_dict(state_dict) + + class TestFusionLayer: """ Class for testing our FusionLayer wrapper. diff --git a/tests/torchtune/modules/peft/test_utils.py b/tests/torchtune/modules/peft/test_utils.py index 032cd88ec4..de90150195 100644 --- a/tests/torchtune/modules/peft/test_utils.py +++ b/tests/torchtune/modules/peft/test_utils.py @@ -16,6 +16,7 @@ disable_adapter, DoRALinear, get_adapter_params, + get_adapter_state_dict, get_merged_lora_ckpt, LoRALinear, set_trainable_params, @@ -38,30 +39,30 @@ class DummyAdapterModule(nn.Module, AdapterModule): def __init__(self, in_dim, out_dim): super().__init__() - self.adapter = nn.Linear(in_dim, out_dim, bias=False) + self.lora = nn.Linear(in_dim, out_dim, bias=False) self.linear = nn.Linear(in_dim, out_dim) def adapter_params(self): - return ["adapter.weight"] + return ["lora.weight"] def forward(self, x): - return self.adapter(x) + self.non_adapter(x) + return self.lora(x) + self.non_adapter(x) class DummyAdapterParentModel(nn.Module, AdapterModule): def __init__(self, in_dim, out_dim): super().__init__() self.dummy_adapter_module = DummyAdapterModule(in_dim, out_dim) - self.parent_adapter = nn.Linear(in_dim, out_dim) + self.parent_lora = nn.Linear(in_dim, out_dim) self.parent_base_model = nn.Linear(in_dim, out_dim) def adapter_params(self): - return ["parent_adapter.weight", "parent_adapter.bias"] + return ["parent_lora.weight", "parent_lora.bias"] def forward(self, x): return ( self.dummy_adapter_module(x) - + self.parent_adapter(x) + + self.parent_lora(x) + self.parent_base_model(x) ) @@ -79,9 +80,9 @@ def dummy_model_expected_adapter_keys(): for i in range(N_LAYERS): keys.extend( [ - f"{i}.parent_adapter.weight", - f"{i}.parent_adapter.bias", - f"{i}.dummy_adapter_module.adapter.weight", + f"{i}.parent_lora.weight", + f"{i}.parent_lora.bias", + f"{i}.dummy_adapter_module.lora.weight", ] ) return keys @@ -204,6 +205,20 @@ def test_get_adapter_params(self, request, model_name, expected_keys): expected = request.getfixturevalue(expected_keys) assert set(expected) == set(adapter_params.keys()) + @pytest.mark.parametrize( + "model_name, expected_keys", + [ + ("dummy_adapter_parent_model", "dummy_model_expected_adapter_keys"), + ("lora_llama2_model", "lora_llama2_expected_adapter_keys"), + ("dora_llama2_model", "dora_llama2_expected_adapter_keys"), + ], + ) + def test_get_adapter_state_dict(self, request, model_name, expected_keys): + model = request.getfixturevalue(model_name) + adapter_state_dict = get_adapter_state_dict(model.state_dict()) + expected = request.getfixturevalue(expected_keys) + assert set(expected) == set(adapter_state_dict.keys()) + @pytest.mark.parametrize( "model_name, expected_trainable_keys, expected_frozen_keys", [ diff --git a/tests/torchtune/modules/test_attention.py b/tests/torchtune/modules/test_attention.py index 872f6684de..0d9dcb5434 100644 --- a/tests/torchtune/modules/test_attention.py +++ b/tests/torchtune/modules/test_attention.py @@ -123,7 +123,7 @@ def gqa_kv_cache( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - num_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=torch.float32, ) @@ -178,7 +178,7 @@ def mha_kv_cache( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - num_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=torch.float32, ) @@ -233,7 +233,7 @@ def mqa_kv_cache( kv_cache = KVCache( batch_size=4, max_seq_len=max_seq_len, - num_heads=num_heads, + num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=torch.float32, ) @@ -267,7 +267,6 @@ def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None def test_forward_gqa_kv_cache( self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa ) -> None: - _, _, _, max_seq_len = attn_params_gqa _, seq_len, _ = input.shape @@ -293,7 +292,6 @@ def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None def test_forward_mha_kv_cache( self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha ) -> None: - _, _, _, max_seq_len = attn_params_mha _, seq_len, _ = input.shape diff --git a/tests/torchtune/training/test_activation_offloading.py b/tests/torchtune/training/test_activation_offloading.py index 5d4c968e96..286a949e77 100644 --- a/tests/torchtune/training/test_activation_offloading.py +++ b/tests/torchtune/training/test_activation_offloading.py @@ -10,6 +10,8 @@ from torch import nn from torchtune.training import OffloadActivations +NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000 # 2e9 is ~1s worth of GPU cycles + @gpu_test(gpu_count=1) @pytest.mark.parametrize("use_streams", [True, False]) @@ -46,7 +48,8 @@ def test_offloading_is_same_as_without(use_streams) -> None: def test_offloading_works_with_view_outputs() -> None: """ This test is quite contrived but tests against a very obscure situation where - any of the outputs of a backward node are a view of the unpacked tensor. + any of the outputs of a backward node are a view of the unpacked tensor. (See + the first line item under Note: [Track views of the unpacked]). We want to ensure that if an unpacked tensor may be used later that we do not free it too early. @@ -98,7 +101,7 @@ def forward(ctx, activation): @staticmethod def backward(ctx, viewed_activation): - torch.cuda._sleep(2000000000) # 2e9 is ~1s worth of GPU cycles + torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC) return viewed_activation == 1 class InspectEarlierActivation(torch.autograd.Function): @@ -129,3 +132,96 @@ def fwd(t): # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd ctx.fwd_stash = {} loss_c.backward() + + +@gpu_test(gpu_count=1) +def test_offloading_works_with_view_ac_cached_buffers() -> None: + """ + Similar to test_offloading_works_with_view_outputs, but for when AC stashes + a view of the unpacked tensor. See the second line item under Note: [Track + views of the unpacked]. + + For details on how the following custom autograd function was contrived, + please see the image attached to the PR description in #1936. The visual + is more helpful than me trying to write a blob of text here. + """ + + class A(torch.autograd.Function): + @staticmethod + def forward(ctx, ones): + ctx.save_for_backward(ones * 5) # corruptedly saving 5s + return ones + + @staticmethod + def backward(ctx, activation_is_ones): + fives = ctx.saved_tensors[0] + assert torch.all(activation_is_ones) + return activation_is_ones + + class B(torch.autograd.Function): + @staticmethod + def forward(ctx, ones): + ctx.save_for_backward(ones.clone()) + return ones.clone() # important, a view of 1s will be saved in C + + @staticmethod + def backward(ctx, activation_is_ones): + saved_tensor = ctx.saved_tensors[0] + return activation_is_ones.clone() + + class C(torch.autograd.Function): + @staticmethod + def forward(ctx, ones): + ctx.save_for_backward(ones.t().t()) + return ones.clone() + + @staticmethod + def backward(ctx, grad): + saved_tensor = ctx.saved_tensors[0] + return saved_tensor == 1 + + class D(torch.autograd.Function): + @staticmethod + def forward(ctx, ones): + ctx.save_for_backward(torch.rand_like(ones)) + return torch.rand_like(ones) + + @staticmethod + def backward(ctx, grad): + saved_tensor = ctx.saved_tensors[0] + torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC) + return torch.rand_like(grad) + + class E(torch.autograd.Function): + @staticmethod + def forward(ctx, ones): + ctx.save_for_backward(torch.rand_like(ones)) + return torch.rand_like(ones) + + @staticmethod + def backward(ctx, grad): + # It doesn't matter what E saves, but it needs to save something + # just to trigger AC recompute to fill in this tensor. + saved_tensor = ctx.saved_tensors[0] + return torch.rand_like(grad) + + def checkpointed_region(b): + c = C.apply(b) + d = D.apply(c) + return E.apply(d) + + def fwd(t): + a = A.apply(t) + b = B.apply(a) + e = torch.utils.checkpoint.checkpoint( + checkpointed_region, b, use_reentrant=False + ) + return e.sum() + + tensor = torch.ones(256, 1024, device="cuda", requires_grad=True) + ctx = OffloadActivations(use_streams=True) + with ctx: + loss = fwd(tensor) + # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd + ctx.fwd_stash = {} + loss.backward() diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index 1f4b92b4de..830a2ab4a8 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -312,8 +312,8 @@ def test_lora_state_dict(self): fsdp_optim_to_save.zero_grad() expected_model_sd = base_model.state_dict() expected_optim_sd = base_optim.state_dict() - model_full_sd = training.get_full_model_state_dict( - fsdp_model_to_save, is_rank_zero + model_full_sd = training.gather_cpu_state_dict( + fsdp_model_to_save.state_dict(), is_rank_zero ) optim_full_sd = training.get_full_optimizer_state_dict( fsdp_optim_to_save, @@ -467,8 +467,8 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool): fsdp_model_to_save(inp) expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()} - model_full_sd = training.get_full_model_state_dict( - fsdp_model_to_save, is_rank_zero + model_full_sd = training.gather_cpu_state_dict( + fsdp_model_to_save.state_dict(), is_rank_zero ) if is_rank_zero: self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys())) diff --git a/tests/torchtune/utils/test_device.py b/tests/torchtune/utils/test_device.py index b481090668..b96eb5ae3b 100644 --- a/tests/torchtune/utils/test_device.py +++ b/tests/torchtune/utils/test_device.py @@ -14,9 +14,12 @@ import torch from torchtune.utils._device import ( _get_device_type_from_env, - _setup_cuda_device, + _setup_device, batch_to_device, + DeviceSupport, get_device, + get_device_support, + get_torch_device_namespace, ) @@ -69,7 +72,10 @@ def test_get_gpu_device(self) -> None: if device_idx > 0: with pytest.raises( RuntimeError, - match=f"Device specified is cuda:0 but was assigned cuda:{device_idx}", + match=( + f"You can't specify a device index when using distributed training. " + f"Device specified is cuda:0 but local rank is:{device_idx}" + ), ): device = get_device("cuda:0") @@ -83,7 +89,24 @@ def test_get_gpu_device(self) -> None: # Test that we fall back to 0 if LOCAL_RANK is not specified device = torch.device(_get_device_type_from_env()) - device = _setup_cuda_device(device) + device = _setup_device(device) assert device.type == "cuda" assert device.index == 0 assert device.index == torch.cuda.current_device() + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + @patch("torch.cuda.is_available", return_value=True) + def test_cuda_available(self, mock_cuda): + # Test if CUDA is available, get_device_support should return DeviceSupport.CUDA + device_support = get_device_support() + assert device_support == DeviceSupport.CUDA + assert device_support.device_type == "cuda" + assert device_support.device_name == "GPU" + assert device_support.communication_backend == "nccl" + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + @patch("torch.cuda.is_available", return_value=True) + def test_get_torch_device_for_cuda(self, mock_cuda): + # Test if get_torch_device returns the correct torch.cuda module + torch_device = get_torch_device_namespace() + assert torch_device == torch.cuda diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index cdb1d45f01..c40e89184b 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -109,6 +109,9 @@ class Recipe: Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"), Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"), Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"), + Config(name="gemma2/2B_full", file_path="gemma2/2B_full.yaml"), + Config(name="gemma2/9B_full", file_path="gemma2/9B_full.yaml"), + Config(name="gemma2/27B_full", file_path="gemma2/27B_full.yaml"), Config(name="phi3/mini_full", file_path="phi3/mini_full.yaml"), Config(name="qwen2/7B_full", file_path="qwen2/7B_full.yaml"), Config(name="qwen2/0.5B_full", file_path="qwen2/0.5B_full.yaml"), @@ -216,6 +219,30 @@ class Recipe: name="gemma/7B_qlora_single_device", file_path="gemma/7B_qlora_single_device.yaml", ), + Config( + name="gemma2/2B_lora_single_device", + file_path="gemma2/2B_lora_single_device.yaml", + ), + Config( + name="gemma2/2B_qlora_single_device", + file_path="gemma2/2B_qlora_single_device.yaml", + ), + Config( + name="gemma2/9B_lora_single_device", + file_path="gemma2/9B_lora_single_device.yaml", + ), + Config( + name="gemma2/9B_qlora_single_device", + file_path="gemma2/9B_qlora_single_device.yaml", + ), + Config( + name="gemma2/27B_lora_single_device", + file_path="gemma2/27B_lora_single_device.yaml", + ), + Config( + name="gemma2/27B_qlora_single_device", + file_path="gemma2/27B_qlora_single_device.yaml", + ), Config( name="phi3/mini_lora_single_device", file_path="phi3/mini_lora_single_device.yaml", @@ -329,6 +356,9 @@ class Recipe: Config(name="mistral/7B_lora", file_path="mistral/7B_lora.yaml"), Config(name="gemma/2B_lora", file_path="gemma/2B_lora.yaml"), Config(name="gemma/7B_lora", file_path="gemma/7B_lora.yaml"), + Config(name="gemma2/2B_lora", file_path="gemma2/2B_lora.yaml"), + Config(name="gemma2/9B_lora", file_path="gemma2/9B_lora.yaml"), + Config(name="gemma2/27B_lora", file_path="gemma2/27B_lora.yaml"), Config(name="phi3/mini_lora", file_path="phi3/mini_lora.yaml"), Config(name="qwen2/7B_lora", file_path="qwen2/7B_lora.yaml"), Config(name="qwen2/0.5B_lora", file_path="qwen2/0.5B_lora.yaml"), diff --git a/torchtune/config/_utils.py b/torchtune/config/_utils.py index a5d1291802..93c19571c5 100644 --- a/torchtune/config/_utils.py +++ b/torchtune/config/_utils.py @@ -173,6 +173,11 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictC # key string to reflect this if k in yaml_kwargs and _has_component(yaml_kwargs[k]): k += "._component_" + + # None passed via CLI will be parsed as string, but we really want OmegaConf null + if v == "None": + v = "!!null" + # TODO: this is a hack but otherwise we can't pass strings with leading zeroes # to define the checkpoint file format. We manually override OmegaConf behavior # by prepending the value with !!str to force a string type diff --git a/torchtune/generation/_generation.py b/torchtune/generation/_generation.py index c2d60a7373..bb4b1ff0b0 100644 --- a/torchtune/generation/_generation.py +++ b/torchtune/generation/_generation.py @@ -67,7 +67,7 @@ def generate_next_token( model: TransformerDecoder, input_pos: torch.Tensor, x: torch.Tensor, - q: torch.Tensor, + q: Optional[torch.Tensor] = None, *, mask: Optional[torch.Tensor] = None, temperature: float = 1.0, @@ -82,7 +82,7 @@ def generate_next_token( with shape [bsz x seq_length]. x (torch.Tensor): tensor with the token IDs associated with the given prompt, with shape [bsz x seq_length]. - q (torch.Tensor): randomly sampled tensor for softmax sampling trick. + q (Optional[torch.Tensor]): randomly sampled tensor for softmax sampling trick. See https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/generate.py#L40 mask (Optional[torch.Tensor]): attention mask with shape [bsz x seq_length x seq_length], default None. @@ -302,9 +302,11 @@ def generate( # tensors are of identical shape to the prompt curr_masks = masks[:, :prompt_length, :prompt_length] - q = torch.empty( - (bsz, model.tok_embeddings.num_embeddings), device=prompt.device - ).exponential_(1, generator=rng) + q = None + if rng is not None: + q = torch.empty( + (bsz, model.tok_embeddings.num_embeddings), device=prompt.device + ).exponential_(1, generator=rng) tokens, generated_logits = generate_next_token( model, input_pos=input_pos[:, :prompt_length].squeeze(), @@ -360,9 +362,11 @@ def generate( curr_input_pos = input_pos[:, : curr_pos + 1] curr_masks = masks[:, : curr_pos + 1, : curr_pos + 1] - q = torch.empty( - (bsz, model.tok_embeddings.num_embeddings), device=prompt.device - ).exponential_(1, generator=rng) + q = None + if rng is not None: + q = torch.empty( + (bsz, model.tok_embeddings.num_embeddings), device=prompt.device + ).exponential_(1, generator=rng) tokens, logits = custom_generate_next_token( model, input_pos=curr_input_pos, diff --git a/torchtune/models/clip/_transform.py b/torchtune/models/clip/_transform.py index a9b60624ff..f0f7f2c3c5 100644 --- a/torchtune/models/clip/_transform.py +++ b/torchtune/models/clip/_transform.py @@ -159,10 +159,9 @@ def __call__( assert isinstance(image, Image.Image), "Input image must be a PIL image." # Make image torch.tensor((3, H, W), dtype=dtype), 0<=values<=1 - if hasattr(image, "mode") and image.mode == "RGBA": + if image.mode != "RGB": image = image.convert("RGB") image = F.to_image(image) - image = F.grayscale_to_rgb_image(image) image = F.to_dtype(image, dtype=self.dtype, scale=True) # Find the best canvas to fit the image without distortion diff --git a/torchtune/models/gemma/__init__.py b/torchtune/models/gemma/__init__.py index 48e4e84b10..f762de86b6 100644 --- a/torchtune/models/gemma/__init__.py +++ b/torchtune/models/gemma/__init__.py @@ -27,6 +27,4 @@ "lora_gemma_7b", "qlora_gemma_2b", "qlora_gemma_7b", - "gemma_hf_to_tune", - "gemma_tune_to_hf", ] diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py index e7ab9b224c..ba5b666c98 100644 --- a/torchtune/models/gemma/_component_builders.py +++ b/torchtune/models/gemma/_component_builders.py @@ -46,7 +46,6 @@ def gemma( attn_dropout: float = 0.0, norm_eps: float = 1e-6, rope_base: int = 10_000, - norm_embeddings: bool = True, ) -> TransformerDecoder: """ Build the decoder associated with the gemma model. This includes: @@ -72,8 +71,6 @@ def gemma( Default: 0.0 norm_eps (float): epsilon in RMS norms Default: 1e-6 rope_base (int): base for the rotary positional embeddings. Default: 10_000 - norm_embeddings (bool): whether to apply layer norm before the self-attention - and mlp layers. Default: True Returns: TransformerDecoder: Instantiation of gemma model. @@ -146,7 +143,6 @@ def lora_gemma( attn_dropout: float = 0.0, norm_eps: float = 1e-6, rope_base: int = 10_000, - norm_embeddings: bool = True, # LoRA args lora_rank: int, lora_alpha: float, @@ -177,8 +173,6 @@ def lora_gemma( Default: 0.0 norm_eps (float): epsilon in RMS norms Default: 1e-6 rope_base (int): base for the rotary positional embeddings. Default: 10_000 - norm_embeddings (bool): whether to apply layer norm before the self-attention - and mlp layers. Default: True lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation lora_dropout (float): LoRA dropout probability. Default: 0.0 diff --git a/torchtune/models/gemma2/__init__.py b/torchtune/models/gemma2/__init__.py new file mode 100644 index 0000000000..9fe11db7ab --- /dev/null +++ b/torchtune/models/gemma2/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from ..gemma._model_builders import gemma_tokenizer +from ..gemma._tokenizer import GemmaTokenizer # noqa +from ._component_builders import gemma2, lora_gemma2 # noqa +from ._model_builders import ( # noqa + gemma2_27b, + gemma2_2b, + gemma2_9b, + lora_gemma2_27b, + lora_gemma2_2b, + lora_gemma2_9b, + qlora_gemma2_27b, + qlora_gemma2_2b, + qlora_gemma2_9b, +) + +__all__ = [ + "GemmaTokenizer", + "gemma2", + "gemma2_2b", + "gemma2_9b", + "gemma2_27b", + "gemma_tokenizer", + "lora_gemma2", + "lora_gemma2_2b", + "lora_gemma2_9b", + "lora_gemma2_27b", + "qlora_gemma2_2b", + "qlora_gemma2_9b", + "qlora_gemma2_27b", +] diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py new file mode 100644 index 0000000000..1b7bf38447 --- /dev/null +++ b/torchtune/models/gemma2/_attention.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torchtune.modules.attention_utils import _MaskType +from torchtune.modules.kv_cache import KVCache + +logger = logging.getLogger(__name__) + + +class Gemma2Attention(nn.Module): + """ + Adapated from official Google Pytorch Implementation: + https://github.com/google/gemma_pytorch/blob/80881c2e6e797ef1913a4a705d4b40394791cc58/gemma/model.py#L213 + to match torchtune style. + A new attention had to be added since nn.functional.scaled_dot_product_attention does allow soft capping + Args: + embed_dim (int): embedding dimension for the model + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``, + for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``. + head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``. + q_proj (nn.Module): projection layer for query. + k_proj (nn.Module): projection layer for key. + v_proj (nn.Module): projection layer for value. + output_proj (nn.Module): projection layer for output. + pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings. + q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied + before updating from kv_cache. This means it will only support token wide normalization and not + batch or sequence wide normalization. + k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is. + kv_cache (Optional[KVCache]): KVCache object used to cache key and value + max_seq_len (int): maximum sequence length supported by the model. + This is needed to compute the RoPE Cache. Default: 4096. + is_causal (bool): sets the default mask to causal when no mask is provided + attn_dropout (float): dropout value passed onto the + scaled_dot_product_attention function. This argument is ignored if the + self.training is False. Default value is 0.0. + sliding_window_size (Optional[int]): size of the sliding window if None no sliding window is applied + softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed + query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead + Raises: + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa + """ + + def __init__( + self, + *, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + q_proj: nn.Module, + k_proj: nn.Module, + v_proj: nn.Module, + output_proj: nn.Module, + pos_embeddings: Optional[nn.Module] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + kv_cache: Optional[KVCache] = None, + max_seq_len: int = 4096, + is_causal: bool = True, + attn_dropout: float = 0.0, + sliding_window_size: Optional[int] = None, + softcapping: Optional[float] = 50.0, + query_pre_attn_scalar: Optional[int] = None, + ) -> None: + super().__init__() + if num_heads % num_kv_heads != 0: + raise ValueError( + f"num_heads ({num_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})" + ) + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim ({embed_dim}) must be divisible by " + f"num_heads ({num_heads})" + ) + + if attn_dropout < 0 or attn_dropout > 1: + raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0") + + if bool(q_norm) ^ bool(k_norm): + raise ValueError("q and k norm must be set together") + + # Set attributes + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.embed_dim = embed_dim + self.attn_dropout = attn_dropout + self.head_dim = head_dim + self.max_seq_len = max_seq_len + self.is_causal = is_causal + + # Set layers + self.kv_cache = kv_cache + self.q_proj = q_proj + self.k_proj = k_proj + self.v_proj = v_proj + self.output_proj = output_proj + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_embeddings = pos_embeddings + + # gemma related parameters + self.sliding_window_size = sliding_window_size + self.softcapping = softcapping + if query_pre_attn_scalar is not None: + self.scaling = query_pre_attn_scalar**-0.5 + else: + self.scaling = self.head_dim**-0.5 + + # this flag indicates whether to update the kv-cache during forward + # passes. when disabled, we can have the cache setup but still + # perform normal forward passes + self.cache_enabled = False + + def setup_cache( + self, batch_size: int, dtype: torch.dtype, max_seq_len: int + ) -> None: + """Setup key value caches for attention calculation. If called + after kv_cache is already setup, this will be skipped. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + max_seq_len (int): maximum sequence length model will be run with. + """ + # Don't overwrite user defined kv_cache from init + if self.kv_cache is not None: + logger.warning( + "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping." + ) + else: + self.kv_cache = KVCache( + batch_size=batch_size, + max_seq_len=max_seq_len, + num_kv_heads=self.num_heads, + head_dim=self.head_dim, + dtype=dtype, + ) + self.cache_enabled = True + + def reset_cache(self): + """Reset the key value caches.""" + if self.kv_cache is None: + raise RuntimeError( + "Key value caches are not setup. Call ``setup_caches()`` first." + ) + self.kv_cache.reset() + + def forward( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + *, + mask: Optional[_MaskType] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input tensor with shape [b x s_x x d] for the query + y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input + for k and v. For self attention, x=y. Optional only with kv_cache enabled. + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. Either: + + A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``, + or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers. + A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means + token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask + is used by default. + + A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence + created via `create_block_mask `_. We use + :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks. + Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape [b x s]. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Raises: + ValueError: If no ``y`` input and ``kv_cache`` is not enabled. + + Returns: + torch.Tensor: output tensor with attention applied + + Notation used for tensor shapes: + - b: batch size + - s_x: sequence length for x + - s_y: sequence length for y + - n_h: num heads + - n_kv: num kv heads + - d: embed dim + - h_d: head dim + """ + # until flex attention implementation exists, we do not accept block masks + if mask is not None and (not isinstance(mask, torch.Tensor)): + raise NotImplementedError( + "Block masks are not implemeted yet, use packed=False." + ) + + # x has shape [b, s_x, d] + # y has shape [b, s_y, d] + b, s_x, _ = x.shape + s_y = y.shape[1] if y is not None else 0 + + # q has shape [b, s_x, num_heads * head_dim] + q = self.q_proj(x) + + # number of queries per key/value + q_per_kv = self.num_heads // self.num_kv_heads + q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim) + + # Apply positional embeddings + if self.pos_embeddings is not None: + q = self.pos_embeddings(q, input_pos=input_pos) + + # [b, n_h, s_x, h_d] + q = q.transpose(1, 2) + + # Normalize q + if self.q_norm is not None: + q = self.q_norm(q) + + if y is None: + if self.kv_cache is None: + raise ValueError( + "Must provide y input or use kv_cache to enable streaming decoding" + ) + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + # Update k and v shape, positional embeddings, and normalization + + # k has shape [b, s_y, num_kv_heads * head_dim] + # v has shape [b, s_y, num_kv_heads * head_dim] + k = self.k_proj(y) + v = self.v_proj(y) + + # Apply positional embeddings + # k: [b, s_y, n_kv, h_d] + k = k.view(b, s_y, -1, self.head_dim) + if self.pos_embeddings is not None: + k = self.pos_embeddings(k, input_pos=input_pos) + + # View + expand + reshape bring num_kv_heads to num_heads for k and v + # to match q. + + # k: [b, s_y, n_kv, 1, h_d] + # v: [b, s_y, n_kv, 1, h_d] + k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + + # If needed, expand the key and value tensors to have the same shape + # as the query tensor by copying values across the relevant dim + if self.num_heads != self.num_kv_heads: + k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) + + # [b, s, n_h, h_d] + k = k.reshape(b, s_y, -1, self.head_dim) + v = v.reshape(b, s_y, -1, self.head_dim) + + # [b, n_h, s, h_d] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Normalize k + if self.k_norm is not None: + k = self.k_norm(k) + + # Update key-value cache + if self.kv_cache is not None and self.cache_enabled: + k, v = self.kv_cache.update(k, v) + + q.mul_(self.scaling) + output = torch.matmul( + q, k.transpose(2, 3) + ) # [batch_size, n_local_heads, input_len, head_dim] + + # if mask is None: default to causal mask + if mask is None: + mask = torch.tril( + torch.ones( + size=(s_x, s_x), + dtype=torch.bool, + ).to(x.device) + ) + + # update masks bias to be 0 for visible tokens and -2.3819763e38 otherwise + # this is similar to what torch sdpa is doing: + # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html + if mask.dtype == torch.bool: + mask = torch.where(mask.logical_not(), -2.3819763e38, 0) + + if self.sliding_window_size is not None: + all_ones = torch.ones_like(mask) + + sliding_mask = torch.triu( + all_ones, -1 * self.sliding_window_size + 1 + ) * torch.tril(all_ones, self.sliding_window_size - 1) + mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) + + if mask.dim() == 3: + # This is the case for block masks where attention is different per sample + # we want mask to be broadcastable with output so we aim for (bs, 1, s_x, s_y) + mask = mask.unsqueeze(1) + + if self.softcapping is not None: + output = output / self.softcapping + output = torch.tanh(output) + output = output * self.softcapping + + output = output + mask + output = F.softmax(output.float(), dim=-1).type_as(q) + + # [batch_size, n_local_heads, input_len, head_dim] + output = torch.matmul(output, v) + + # reshape the output to be the same shape as the input + output = output.transpose(1, 2).contiguous().view(b, s_x, -1) + return self.output_proj(output) diff --git a/torchtune/models/gemma2/_component_builders.py b/torchtune/models/gemma2/_component_builders.py new file mode 100644 index 0000000000..0ddef36857 --- /dev/null +++ b/torchtune/models/gemma2/_component_builders.py @@ -0,0 +1,413 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torch import nn +import torch +from typing import List +from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks +from typing import List, Optional + +from torchtune.modules import ( + FrozenNF4Linear, + RotaryPositionalEmbeddings, + TransformerSelfAttentionLayer, +) + +from torchtune.models.gemma2._attention import Gemma2Attention +from torchtune.models.gemma.rms_norm import GemmaRMSNorm +from torchtune.modules import TransformerDecoder, TiedLinear +from torchtune.models.gemma.gemma_norm_embedding import GemmaNormEmbeddings +from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear +from torchtune.models.gemma._component_builders import gemma_mlp, lora_gemma_mlp + +""" +Component builders for the Gemma2 2B, 9B models and popular variants such as LoRA. + +torchtune provides composable building blocks. Builder functions help +stitch these building blocks into higher-level components. This design has +two benefits: +- The building blocks themselves are very flexible. For example, ``MultiHeadAttention`` +can take either nn.Linear or nn.LoRALinear for ``q_proj``. +- Builder functions expose a set of configurable params which keep the constructors of +the building blocks simple. +""" + +class TanhSoftCapping(nn.Module): + def __init__( + self, + capping_value: float, + ) -> None: + super().__init__() + self.capping_value = capping_value + + def forward(self, attn_weights): + attn_weights = attn_weights / self.capping_value + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * self.capping_value + return attn_weights + +class Gemma2FinalNorm(nn.Module): + """ + Combines RMSNorm and SoftCapping + """ + def __init__( + self, + capping_value: float, + embed_dim: int, + eps: float + ) -> None: + super().__init__() + self.capping_value = capping_value + self.rms_norm = GemmaRMSNorm(embed_dim, eps=eps) + self.logit_capping = TanhSoftCapping(capping_value) + + def forward(self, x): + x = self.rms_norm(x) + x = self.logit_capping(x) + return x + + +def gemma2( + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-6, + rope_base: int = 10_000, + hidden_capping_value: float = 50., + final_capping_value: float = 30., + sliding_window_size: int = 4096, + query_pre_attn_scalar: Optional[int] = None, +) -> TransformerDecoder: + """ + Build the decoder associated with the gemma2 model. This includes: + - Token embeddings + - num_layers number of TransformerSelfAttentionLayer blocks + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + + Args: + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + head_dim (int): dimension of head + num_kv_heads (int): number of key and value heads. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms Default: 1e-6 + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + + Returns: + TransformerDecoder: Instantiation of gemma model. + """ + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + layers = torch.nn.ModuleList() + + for layer_idx in range(num_layers): + + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + + self_att = Gemma2Attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False), + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + # perform sliding window on half of the layers only + sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, + softcapping=hidden_capping_value, + query_pre_attn_scalar=query_pre_attn_scalar + ) + + layer = TransformerSelfAttentionLayer( + attn=self_att, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) + output_proj = TiedLinear(tok_embeddings) + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + output=output_proj, + head_dim=head_dim, + norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps), + ) + return model + + + +def lora_gemma2( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + *, + # gemma args + vocab_size: int, + num_layers: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + embed_dim: int, + intermediate_dim: int, + max_seq_len: int, + attn_dropout: float = 0.0, + norm_eps: float = 1e-6, + rope_base: int = 10_000, + hidden_capping_value: float = 50., + final_capping_value: float = 30., + sliding_window_size: int = 4096, + query_pre_attn_scalar: Optional[int] = None, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Return a version of Gemma with LoRA applied based on the passed in configuration. + Note: output projection lora is not supported because it is tied to token embeddings + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + head_dim (int): dimension of head + num_kv_heads (int): number of key and value heads. + embed_dim (int): embedding dimension for self-attention + intermediate_dim (int): intermediate dimension for MLP + max_seq_len (int): maximum sequence length the model will be run with, + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + norm_eps (float): epsilon in RMS norms Default: 1e-6 + rope_base (int): base for the rotary positional embeddings. Default: 10_000 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Gemma model with LoRA applied to + a subset of the attention projections in each layer. + """ + + tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim) + output_proj = TiedLinear(tok_embeddings) + + layers = torch.nn.ModuleList() + + for layer_idx in range(num_layers): + if apply_lora_to_mlp: + mlp = lora_gemma_mlp( + dim=embed_dim, + hidden_dim=intermediate_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + else: + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) + self_att = lora_gemma2_self_attention( + lora_modules=lora_attn_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + rope_base=rope_base, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + # perform sliding window on half of the layers only + sliding_window_size=sliding_window_size if (layer_idx % 2)==0 else None, + softcapping=hidden_capping_value, + query_pre_attn_scalar=query_pre_attn_scalar, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora = use_dora, + quantize_base = quantize_base, + ) + + layer = TransformerSelfAttentionLayer( + attn=self_att, + mlp=mlp, + sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps), + sa_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps), + ) + layers.append(layer) + + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + output=output_proj, + head_dim=head_dim, + norm=Gemma2FinalNorm(final_capping_value, embed_dim, eps=norm_eps) + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to higher precision, and offload to CPU on the fly + # so as to not increase peak memory + # TODO this is clowny, figure out a better way to get what precision the rest + # of the model is in + _register_reparametrize_state_dict_hooks(model, dtype=tok_embeddings.weight.dtype) + + return model + + +def lora_gemma2_self_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # MultiHeadAttention args + embed_dim: int, + num_heads: int, + head_dim: int, + num_kv_heads: int, + max_seq_len: int, + attn_dropout: float = 0.0, + rope_base: int = 10_000, + sliding_window_size: Optional[int] = None, + softcapping: Optional[float] = 50., + query_pre_attn_scalar: Optional[int], + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, + +) -> Gemma2Attention: + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + adapter_cls = DoRALinear if use_dora else LoRALinear + + q_proj = ( + adapter_cls( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + adapter_cls( + num_heads * head_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(num_heads * head_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False) + ) + ) + + rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + + self_att = Gemma2Attention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=rope, + kv_cache=None, + max_seq_len=max_seq_len, + attn_dropout=attn_dropout, + sliding_window_size=sliding_window_size, + softcapping=softcapping, + query_pre_attn_scalar=query_pre_attn_scalar + ) + return self_att \ No newline at end of file diff --git a/torchtune/models/gemma2/_convert_weights.py b/torchtune/models/gemma2/_convert_weights.py new file mode 100644 index 0000000000..fa4df0e469 --- /dev/null +++ b/torchtune/models/gemma2/_convert_weights.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch + +from torchtune.models.convert_weights import get_mapped_key + +""" +Gemma 2 and Gemma original implementations share different normalization but with +the same name, so it is mandatory to differentiate their state dict in order to map +correctly the different weights. +They are essentially the same except for "model.layers.{}.post_attention_layernorm.weight" key. +See discussion here: https://github.com/pytorch/torchtune/pull/1835#discussion_r1803410251 +""" + +_GEMMA2_FROM_HF = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attn.output_proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.mlp.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.mlp.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.mlp.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.sa_norm.scale", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.sa_scale.scale", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.mlp_norm.scale", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.mlp_scale.scale", + "model.norm.weight": "norm.rms_norm.scale", + "lm_head.weight": "output.weight", +} + + +def gemma2_hf_to_tune( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +) -> Dict[str, torch.Tensor]: + """ + Convert a state dict from HF's format to torchtune's format. State dicts + from multiple checkpoint files should be consolidated into a single state dict + before calling this function. + + Eg of HF-format state dict can be found in the ``meta-llama/Llama-2-7b-hf`` + repo in HF (https://huggingface.co/meta-llama/Llama-2-7b-hf). + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in HF's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of the head. If not provided, it will be calculated + as dim // num_heads. + + Returns: + Dict[str, torch.Tensor]: State dict in torchtune's format. + """ + converted_state_dict = {} + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, 2, head_dim // 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + if "rotary_emb.inv_freq" not in key: # Skip loading the position embeddings + new_key = get_mapped_key(key, _GEMMA2_FROM_HF) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + + converted_state_dict[new_key] = value + return converted_state_dict + + +def gemma2_tune_to_hf( + state_dict: Dict[str, torch.Tensor], + num_heads: int = 32, + num_kv_heads: int = 32, + dim: int = 4096, + head_dim: int = None, +): + """ + Convert a state dict from torchtune's format to HF's format. This function + doesn't handle any sharding or splitting of state dicts. It follows the + state_dict IN -> state_dict OUT pattern. + + Args: + state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format. + num_heads (int): Number of heads in the model. + num_kv_heads (int): Number of heads in the key/value projection layers. + dim (int): Dimension of the model. + head_dim (int): Dimension of model attention heads. Default None. + + Returns: + Dict[str, torch.Tensor]: State dict in HF's format. + """ + converted_state_dict = {} + inverted_mapping_dict = {v: k for k, v in _GEMMA2_FROM_HF.items()} + + if head_dim is None: + head_dim = dim // num_heads + + def _permute(t, n_heads): + return ( + t.view(n_heads, head_dim // 2, 2, dim) + .transpose(1, 2) + .reshape((head_dim * n_heads), dim) + ) + + for key, value in state_dict.items(): + new_key = get_mapped_key(key, inverted_mapping_dict) + if "q_proj" in key: + value = _permute(value, num_heads) + elif "k_proj" in key: + value = _permute(value, num_kv_heads) + converted_state_dict[new_key] = value + + return converted_state_dict diff --git a/torchtune/models/gemma2/_model_builders.py b/torchtune/models/gemma2/_model_builders.py new file mode 100644 index 0000000000..a07021c518 --- /dev/null +++ b/torchtune/models/gemma2/_model_builders.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from torchtune.models.gemma2._component_builders import gemma2, lora_gemma2 +from torchtune.modules import TransformerDecoder + +from torchtune.modules.peft import LORA_ATTN_MODULES +from functools import partial + +""" +Model builders build specific instantiations using component builders. For example +the ``gemma_2b`` model builder uses the ``gemma2`` component builder. +""" + + +def gemma2_2b() -> TransformerDecoder: + """ + Builder for creating a Gemma2 2B model initialized w/ the default 2b parameter values + from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + + Returns: + TransformerDecoder: Instantiation of Gemma2 2B model + """ + return gemma2( + vocab_size=256_000, + num_layers=26, + num_heads=8, + head_dim=256, + num_kv_heads=4, + embed_dim=2304, + intermediate_dim=9216, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + ) + + +def lora_gemma2_2b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Gemma2 2B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_2b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Gemma2 2B model with LoRA applied + """ + return lora_gemma2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=26, + num_heads=8, + head_dim=256, + num_kv_heads=4, + embed_dim=2304, + intermediate_dim=9216, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + +qlora_gemma2_2b = partial(lora_gemma2_2b, quantize_base=True) + +qlora_gemma2_2b.__doc__ = """ +Builder for creating a Gemma2 model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemm2a_2b` for full API arguments. +""" + + + +def gemma2_9b() -> TransformerDecoder: + """ + Builder for creating a Gemma2 9B model initialized w/ the default 9b parameter values + from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + + Returns: + TransformerDecoder: Instantiation of Gemma 9B model + """ + return gemma2( + vocab_size=256_000, + num_layers=42, + num_heads=16, + head_dim=256, + num_kv_heads=8, + embed_dim=3584, + intermediate_dim=14336, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + ) + + +def lora_gemma2_9b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Gemma 9B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Gemma2 9B model with LoRA applied + """ + return lora_gemma2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=42, + num_heads=16, + head_dim=256, + num_kv_heads=8, + embed_dim=3584, + intermediate_dim=14336, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + +qlora_gemma2_9b = partial(lora_gemma2_9b, quantize_base=True) + +qlora_gemma2_9b.__doc__ = """ +Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemma2_9b` for full API arguments. +""" + +def gemma2_27b() -> TransformerDecoder: + """ + Builder for creating a Gemma2 27B model initialized w/ the default 27b parameter values + from: https://github.com/google/gemma_pytorch/blob/main/gemma/config.py + + Returns: + TransformerDecoder: Instantiation of Gemma2 27B model + """ + return gemma2( + vocab_size=256_000, + num_layers=46, + num_heads=32, + head_dim=128, + num_kv_heads=16, + embed_dim=4608, + intermediate_dim=36864, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + query_pre_attn_scalar=144, + ) + + +def lora_gemma2_27b( + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Builder for creating a Gemma2 27B model with LoRA enabled. + + The Gemma defaults are the same as in :func:`~torchtune.models.gemma.gemma_7b`, + while LoRA default params are based on + https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43. + + Args: + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0 + use_dora (bool): Decompose the LoRA weight into magnitude and direction, as + introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353). + quantize_base (bool): Whether to quantize base model weights + + Returns: + TransformerDecoder: Instantiation of Gemma2 27B model with LoRA applied + """ + return lora_gemma2( + lora_attn_modules=lora_attn_modules, + apply_lora_to_mlp=apply_lora_to_mlp, + vocab_size=256_000, + num_layers=46, + num_heads=32, + head_dim=128, + num_kv_heads=16, + embed_dim=4608, + intermediate_dim=36864, + max_seq_len=8192, + attn_dropout=0.0, + norm_eps=1e-6, + hidden_capping_value=30.0, + final_capping_value=50.0, + sliding_window_size=4096, + query_pre_attn_scalar=144, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + +qlora_gemma2_27b = partial(lora_gemma2_27b, quantize_base=True) + +qlora_gemma2_27b.__doc__ = """ +Builder for creating a Gemma model with QLoRA enabled. Base model weights in linear layers +that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314. +Please see `lora_gemma2_27b` for full API arguments. +""" diff --git a/torchtune/models/qwen2_5/_model_builders.py b/torchtune/models/qwen2_5/_model_builders.py index 4474958862..7d39802375 100644 --- a/torchtune/models/qwen2_5/_model_builders.py +++ b/torchtune/models/qwen2_5/_model_builders.py @@ -29,7 +29,7 @@ def qwen2_5_0_5b() -> TransformerDecoder: TransformerDecoder: Instantiation of Qwen2.5 0.5B model Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) """ return qwen2( vocab_size=151936, @@ -59,7 +59,7 @@ def qwen2_5_1_5b_base() -> TransformerDecoder: except 0.5B and 3B. Make sure to select the correct model builder for the weights. Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`). """ return qwen2( vocab_size=151936, @@ -89,7 +89,7 @@ def qwen2_5_1_5b_instruct() -> TransformerDecoder: except 0.5B and 3B. Make sure to select the correct model builder for the weights. Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) """ return qwen2( vocab_size=151936, @@ -115,7 +115,7 @@ def qwen2_5_3b() -> TransformerDecoder: TransformerDecoder: Instantiation of Qwen2.5 3B model Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) """ return qwen2( vocab_size=151936, @@ -419,7 +419,7 @@ def lora_qwen2_5_0_5b( TransformerDecoder: Instantiation of Qwen2.5 0.5B model with LoRA applied Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) """ return lora_qwen2( lora_attn_modules=lora_attn_modules, @@ -476,7 +476,7 @@ def lora_qwen2_5_1_5b_base( TransformerDecoder: Instantiation of Qwen2.5 1.5B model with LoRA applied Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) Note: The base and instruct versions have slightly different architectures for all Qwen2.5 model sizes @@ -537,7 +537,7 @@ def lora_qwen2_5_1_5b_instruct( TransformerDecoder: Instantiation of Qwen2.5 1.5B model with LoRA applied Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) Note: The base and instruct versions have slightly different architectures for all Qwen2.5 model sizes @@ -598,7 +598,7 @@ def lora_qwen2_5_3b( TransformerDecoder: Instantiation of Qwen2.5 3B model with LoRA applied Note: - Qwen2.5 0.5B-3B model builders will enable `tie_word_embeddings` by default. + Qwen2.5 0.5B-3B model builders will enable ``tie_word_embeddings`` by default (see :func:`~torchtune.models.qwen2.qwen2`) """ return lora_qwen2( lora_attn_modules=lora_attn_modules, diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 879f0679cf..b74c70113e 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -164,7 +164,7 @@ def setup_cache( self.kv_cache = KVCache( batch_size=batch_size, max_seq_len=max_seq_len, - num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, head_dim=self.head_dim, dtype=dtype, ) @@ -258,47 +258,37 @@ def forward( else: # Update k and v shape, positional embeddings, and normalization - # k has shape [b, s_y, num_kv_heads * head_dim] - # v has shape [b, s_y, num_kv_heads * head_dim] + # k,v shape [b, s_y, num_kv_heads * head_dim] k = self.k_proj(y) v = self.v_proj(y) # Apply positional embeddings - # k: [b, s_y, n_kv, h_d] + # k,v shape: [b, s_y, n_kv, h_d] k = k.view(b, s_y, -1, self.head_dim) + v = v.view(b, s_y, -1, self.head_dim) if self.pos_embeddings is not None: k = self.pos_embeddings(k, input_pos=input_pos) - # View + expand + reshape bring num_kv_heads to num_heads for k and v - # to match q. + # k,v shape: [b, n_kv, s_y, h_d] + k = k.transpose(1, 2) + v = v.transpose(1, 2) - # k: [b, s_y, n_kv, 1, h_d] - # v: [b, s_y, n_kv, 1, h_d] - k = k.view(b, s_y, self.num_kv_heads, 1, self.head_dim) - v = v.view(b, s_y, self.num_kv_heads, 1, self.head_dim) + # Update key-value cache + if self.kv_cache is not None and self.cache_enabled: + k, v = self.kv_cache.update(k, v) # If needed, expand the key and value tensors to have the same shape # as the query tensor by copying values across the relevant dim + # k,v shape: [b, n_h, s, h_d] if self.num_heads != self.num_kv_heads: - k = k.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - v = v.expand(b, s_y, self.num_kv_heads, q_per_kv, self.head_dim) - - # [b, s, n_h, h_d] - k = k.reshape(b, s_y, -1, self.head_dim) - v = v.reshape(b, s_y, -1, self.head_dim) - - # [b, n_h, s, h_d] - k = k.transpose(1, 2) - v = v.transpose(1, 2) + expand_shape = (-1, -1, q_per_kv, -1, -1) + k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2) + v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2) # Normalize k if self.k_norm is not None: k = self.k_norm(k) - # Update key-value cache - if self.kv_cache is not None and self.cache_enabled: - k, v = self.kv_cache.update(k, v) - output = self._attention_call( q, k, diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index facd9703ca..e96491c22a 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -17,9 +17,7 @@ class KVCache(nn.Module): Args: batch_size (int): batch size model will be run with max_seq_len (int): maximum sequence length model will be run with - num_heads (int): number of heads. We take num_heads instead of num_kv_heads because - the cache is created after we've expanded the key and value tensors to have the - same shape as the query tensor. See attention.py for more details + num_kv_heads (int): number of key/value heads. head_dim (int): per-attention head embedding dimension dtype (torch.dtype): dtype for the caches """ @@ -28,12 +26,12 @@ def __init__( self, batch_size: int, max_seq_len: int, - num_heads: int, + num_kv_heads: int, head_dim: int, dtype: torch.dtype, ) -> None: super().__init__() - cache_shape = (batch_size, num_heads, max_seq_len, head_dim) + cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim) self.register_buffer( "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False ) @@ -66,7 +64,7 @@ def update( already been filled, use ``.reset()``, which will reset the cache to the zero-th position. Example: - >>> cache = KVCache(batch_size=2, max_seq_len=16, num_heads=4, head_dim=32, dtype=torch.bfloat16) + >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16) >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32)) >>> cache.update(keys, values) >>> # now positions 0 through 7 are filled diff --git a/torchtune/modules/model_fusion/__init__.py b/torchtune/modules/model_fusion/__init__.py index 7ad788bd57..21a3c1c063 100644 --- a/torchtune/modules/model_fusion/__init__.py +++ b/torchtune/modules/model_fusion/__init__.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._fusion import DeepFusionModel, FusionEmbedding, FusionLayer +from ._deep_fusion import DeepFusionModel +from ._early_fusion import EarlyFusionModel +from ._fusion_layers import FusionEmbedding, FusionLayer from ._fusion_utils import get_fusion_params, register_fusion_module __all__ = [ @@ -13,4 +15,5 @@ "FusionEmbedding", "register_fusion_module", "get_fusion_params", + "EarlyFusionModel", ] diff --git a/torchtune/modules/model_fusion/_deep_fusion.py b/torchtune/modules/model_fusion/_deep_fusion.py new file mode 100644 index 0000000000..6a61c43744 --- /dev/null +++ b/torchtune/modules/model_fusion/_deep_fusion.py @@ -0,0 +1,212 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._fusion_utils import get_fusion_params +from torchtune.modules.peft._utils import set_trainable_params + + +class DeepFusionModel(nn.Module): + """DeepFusion is a type of fused model architecture where a pretrained encoder is combined + with a pretrained decoder (LLM) in the internal decoder layers. This is a popular architecture for multimodal models, with + a full overview available in `The Evolution of Multimodal Model Architectures `_. + A common deep fusion architecture is to fuse the encoder input into the decoder with interspersed cross-attention + layers. This module makes no assumptions on how the encoder and decoder are fused; it simply + passes in the encoder embeddings to the decoder and lets the decoder handle any fusion. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a + single module for checkpointing and finetuning. It is expected that the encoder and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoder to the pre-trained decoder. + + DeepFusionModel currently only supports a single encoder. + + Example: + >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers + >>> embed = FusionEmbedding(...) + >>> layer = FusionLayer( + ... layer=TransformerSelfAttentionLayer(...), + ... fusion_layer=TransformerCrossAttentionLayer(...), + ... ) + >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoder = nn.Sequential(clip_vit_224(), projection_head) + >>> + >>> # DeepFusionModel combines the encoder and decoder + >>> model = DeepFusionModel(decoder, encoder) + >>> + >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) + >>> model.load_state_dict(...) + >>> + >>> # Or load pretrained individual models (fusion_params are not loaded) + >>> model.encoder.load_state_dict(..., strict=False) + >>> model.decoder.load_state_dict(..., strict=False) + >>> + >>> # Forward pass + >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoder (nn.Module): encoder module + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoder_trainable (bool): whether to train or freeze the encoder. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. + + """ + + def __init__( + self, + decoder: TransformerDecoder, + encoder: nn.Module, + *, + decoder_trainable: bool = False, + encoder_trainable: bool = False, + fusion_trainable: bool = True, + ): + super().__init__() + self.decoder = decoder + self.encoder = encoder + + trainable_params = set() + if encoder_trainable: + trainable_params |= { + f"encoder.{n}" for n, p in self.encoder.named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + set_trainable_params(self, trainable_params) + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches( + self, + batch_size: int, + dtype: torch.dtype, + *, + encoder_max_seq_len: Optional[int] = None, + decoder_max_seq_len: Optional[int] = None, + ): + """ + Sets up key-value attention caches for inference for ``self.decoder``. + For each layer in ``self.decoder.layers``: + - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. + - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. + - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + encoder_max_seq_len (Optional[int]): maximum encoder cache sequence length. + decoder_max_seq_len (Optional[int]): maximum decoder cache sequence length. + """ + self.decoder.setup_caches( + batch_size, + dtype, + encoder_max_seq_len=encoder_max_seq_len, + decoder_max_seq_len=decoder_max_seq_len, + ) + + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. + """ + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """ + Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, + without deleting or reallocating cache tensors. + """ + self.decoder.reset_caches() + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + """ + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict]): Optional input for the encoder. + encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between + tokens and encoder embeddings. A True value at position i,j means token i can attend + to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len + """ + # During decoding, encoder_input will only be provided + # for new inputs. Previous encoder outputs are cached + # in the decoder cache. + encoder_embed = None + if encoder_input is not None: + encoder_embed = self.encoder(**encoder_input) + + output = self.decoder( + tokens=tokens, + mask=mask, + encoder_input=encoder_embed, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + return output diff --git a/torchtune/modules/model_fusion/_early_fusion.py b/torchtune/modules/model_fusion/_early_fusion.py new file mode 100644 index 0000000000..d20b2d119f --- /dev/null +++ b/torchtune/modules/model_fusion/_early_fusion.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn +from torchtune.modules import TransformerDecoder +from torchtune.modules.model_fusion._fusion_utils import get_fusion_params +from torchtune.modules.peft._utils import set_trainable_params + + +class EarlyFusionModel(nn.Module): + """EarlyFusion is a type of fused model architecture where pretrained encoder(s) are combined + with a pretrained decoder (LLM) at the model input and not in internal layers. This is a popular architecture + for multimodal models, with a full overview available in `The Evolution of Multimodal Model Architectures + `_. This module works both for decoders in which the encoder tokens are + inside the vocab and outside the vocab. + + This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used + interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoders with the decoder as a + single module for checkpointing and finetuning. It is expected that the encoders and decoder + are already defined with any extra learnable ``fusion_params``: learnable parameters to help + adapt the pre-trained encoders to the pre-trained decoder. + + You can pass in multiple encoders as a dictionary into ``encoders``. + + Note: Once the decoder is wrapped in this module, the decoder's ``tok_embeddings`` module is moved + to the parent EarlyFusionModel's ``tok_embeddings``. You should not forward pass the decoder individually. + Instead, use EarlyFusionModel's forward pass with ``encoder_input=None`` to get decoder-only outputs. + State dicts will automatically be updated on save and load to account for this change. + + Example: + >>> # decoder is a text-only TransformerDecoder (e.g. llama3_8b) with no modifications + >>> decoder = llama3_8b() + >>> + >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head + >>> projection_head = FeedForward(...) + >>> register_fusion_module(projection_head)) + >>> encoders = {"image": nn.Sequential(clip_vit_224(), projection_head)} + >>> + >>> # EarlyFusionModel combines the encoder and decoder + >>> model = EarlyFusionModel(decoder, encoders, encoder_tokens={"image": 128256}) + >>> + >>> # Load full fused checkpoints + >>> model.load_state_dict(...) + >>> + >>> # Forward pass + >>> encoder_input = {"image": {...}} + >>> output = model(tokens, mask=mask, encoder_input=encoder_input, encoder_mask=encoder_mask, input_pos=input_pos) + >>> + >>> # Forward pass decoder only + >>> output = model(tokens, mask=mask, input_pos=input_pos) + + Args: + decoder (TransformerDecoder): decoder module + encoders (Dict[str, nn.Module]): dictionary mapping encoder name as a string to the encoder module. + encoder_tokens (Dict[str, int]): dictionary mapping encoder name to special token ID indicating where + in the text sequence the encoder embedding outputs should be injected. + decoder_trainable (bool): whether to train or freeze the decoder. Default is False. + encoders_trainable (Union[bool, Dict[str, bool]]): whether to train or freeze the encoder. Use a single + boolean to set trainable for all encoders or a dictionary keyed by encoder names to specify trainable + for each encoder individually. Encoder names should match with ``encoders``. Default is False. + fusion_trainable (bool): whether to train the fusion parameters. Default is True. + + Raises: + ValueError: if ``encoders`` and ``encoders_trainable`` keys do not match + """ + + def __init__( + self, + decoder: TransformerDecoder, + encoders: Dict[str, nn.Module], + encoder_tokens: Dict[str, int], + decoder_trainable: bool = False, + encoders_trainable: Union[bool, Dict[str, bool]] = False, + fusion_trainable: bool = True, + ): + super().__init__() + if encoders.keys() != encoder_tokens.keys() or ( + not isinstance(encoders_trainable, bool) + and encoders.keys() != encoders_trainable.keys() + ): + raise ValueError( + f"Found mismatched keys in encoders, encoder_tokens, and/or encoders_trainable. Expected {encoders.keys()}" + ) + + self.decoder = decoder + self.encoders = nn.ModuleDict(encoders) + self.encoder_tokens = encoder_tokens + self.encoders_trainable = ( + {k: encoders_trainable for k in self.encoders.keys()} + if isinstance(encoders_trainable, bool) + else encoders_trainable + ) + + # A little surgery in the decoder to give the + # fusion module access to control the embeddings + # The alternative is to pass a special tok_embeddings + # module into TransformerDecoder builder that does the + # merging there + self.tok_embeddings = decoder.tok_embeddings + decoder.tok_embeddings = nn.Identity() + + self._register_state_dict_hook(self._state_dict_hook) + self.register_load_state_dict_pre_hook(self._load_state_dict_hook) + + trainable_params = set() + for encoder, trainable in self.encoders_trainable.items(): + if trainable: + trainable_params |= { + f"encoders.{encoder}.{n}" + for n, p in self.encoders[encoder].named_parameters() + } + if decoder_trainable: + trainable_params |= { + f"decoder.{n}" for n, p in self.decoder.named_parameters() + } + trainable_params |= { + f"tok_embeddings.{n}" for n, p in self.tok_embeddings.named_parameters() + } + if fusion_trainable: + trainable_params |= set(get_fusion_params(self)) + else: + trainable_params -= set(get_fusion_params(self)) + + set_trainable_params(self, trainable_params) + + @staticmethod + def _state_dict_hook(module, state_dict, *args, **kwargs): + """ + Keep tok_embeddings inside of decoder state_dict + + [!Note] This update changes the order of the OrderedDict + """ + for n, p in module.tok_embeddings.named_parameters(): + state_dict[f"decoder.tok_embeddings.{n}"] = p + del state_dict[f"tok_embeddings.{n}"] + + @staticmethod + def _load_state_dict_hook(module, state_dict, *args, **kwargs): + """Undo the change from _state_dict_hook""" + old_keys = list(state_dict.keys()) + for key in old_keys: + if key.startswith("decoder.tok_embeddings"): + state_dict[key[len("decoder.") :]] = state_dict[key] + del state_dict[key] + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.decoder.set_num_output_chunks(num_output_chunks) + + def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: + """Setup key value caches for attention calculation. + + Args: + batch_size (int): batch size for the caches. + dtype (torch.dtype): dtype for the caches. + """ + self.decoder.setup_caches(batch_size, dtype) + + def caches_are_setup(self) -> bool: + """ + Check if the key value caches are setup. This means ``setup_caches`` has been called, and + the relevant attention modules in the model have created their ``KVCache``. + """ + return self.decoder.caches_are_setup() + + def caches_are_enabled(self) -> bool: + """ + Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant + attention modules will be "enabled" and all forward passes will update the caches. This behaviour + can be disabled without altering the state of the KV-caches by "disabling" the KV-caches + using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. + """ + return self.decoder.caches_are_enabled() + + def reset_caches(self): + """Reset the key value caches.""" + self.decoder.reset_caches() + + def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]: + """Embed the text-only tokens with the decoder's tok_embeddings""" + encoder_token_ids = torch.tensor(list(self.encoder_tokens.values())) + # [bsz, seq_len], True indicates the token is not an encoder special token + is_text = ~torch.isin(tokens, encoder_token_ids) + text_tokens = torch.masked_select(tokens, is_text) + # [num_text, embed_dim] + text_embeds = self.tok_embeddings(text_tokens) + return is_text, text_embeds + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[Dict[str, Dict[str, Any]]] = None, + input_pos: Optional[torch.Tensor] = None, + **kwargs: Dict[str, Any], # no need for encoder_mask + ) -> torch.Tensor: + """ + Note: This module assumes that there will be enough encoder inputs (i.e., total number of images in the batch) + for the number of encoder tokens in the batch. + + Args: + tokens (torch.Tensor): input tensor with shape ``[b x s]`` + mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask + with shape ``[b x s x s]``. This is applied after the query-key multiplication and + before the softmax. A value of True in row i and column j means token i attends + to token j. A value of False means token i does not attend to token j. If no + mask is specified, a causal mask is used by default. Default is None. + encoder_input (Optional[Dict[str, Dict[str, Any]]]): Optional input kwargs for the encoders. Must be + keyed by encoder name and match the keys of ``encoders`` + input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids + of each token. During training, this is used to indicate the positions + of each token relative to its sample when packed, shape ``[b x s]``. + During inference, this indicates the position of the current token. + If none, assume the index of the token is its position id. Default is None. + **kwargs (Dict[str, Any]): additional keyword arguments. This is solely used to match the + :class:`~torchtune.modules.TransformerDecoder` forward and does not have any effect. + + Note: At the very first step of inference, when the model is provided with a prompt, + ``input_pos`` would contain the positions of all of the tokens in the prompt + (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the + KV values for each position. + + Returns: + torch.Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ + output tensors defined by ``output_hidden_states`` with the \ + final output tensor appended to the list. + + Raises: + ValueError: if ``encoder_input`` keys do not match ``encoders`` keys + + Notation used for tensor shapes: + - b: batch size + - s: token sequence length + - s_e: encoder sequence length + - v: vocab size + - d: token embed dim + - d_e: encoder embed dim + - m_s: max seq len + """ + if encoder_input is not None and encoder_input.keys() != self.encoders.keys(): + raise ValueError( + f"Found mismatched keys in encoder_input and instantiated encoders. " + f"Got {encoder_input.keys()}, expected {self.encoders.keys()}." + ) + + bsz, seq_len = tokens.shape + # is_text: [bsz, seq_len], text_embeds: [num_text, embed_dim] + is_text, text_embeds = self._decoder_embed(tokens) + embed_dim = text_embeds.shape[-1] + + # Holds the final embedding vector + fused_embeds = torch.empty( + bsz, seq_len, embed_dim, dtype=text_embeds.dtype, device=text_embeds.device + ) + # Place the text-only embeddings + fused_embeds = fused_embeds.masked_scatter(is_text.unsqueeze(-1), text_embeds) + + encoder_input = encoder_input or {} + for encoder, inp in encoder_input.items(): + # [bsz, num_encoder_tokens, embed_dim] + encoder_embeds = self.encoders[encoder](**inp) + # [bsz * num_encoder_tokens, embed_dim] + encoder_embeds = encoder_embeds.view(-1, embed_dim) + # [bsz, seq_len, 1] + encoder_mask = (tokens == self.encoder_tokens[encoder]).unsqueeze(-1) + # At locations where encoder token is found, replace with encoder embedding + fused_embeds = fused_embeds.masked_scatter(encoder_mask, encoder_embeds) + + output = self.decoder(fused_embeds, mask=mask, input_pos=input_pos) + return output diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion_layers.py similarity index 54% rename from torchtune/modules/model_fusion/_fusion.py rename to torchtune/modules/model_fusion/_fusion_layers.py index 907c7a2ed0..7fe3939ec4 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion_layers.py @@ -4,13 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Optional, Union +from typing import Dict, List import torch from torch import nn -from torchtune.modules import TransformerDecoder -from torchtune.modules.model_fusion._fusion_utils import get_fusion_params -from torchtune.modules.peft._utils import set_trainable_params class FusionLayer(nn.Module): @@ -273,9 +270,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # num_fusion_tokens = (input >= vocab_size).sum() fusion_tokens = torch.masked_select(input, ~mask) - vocab_size - # [batch_size x num_tokens x embed_dim] + # [batch_size * num_tokens, embed_dim] embeds = self.embedding(tokens) - # [batch_size x num_fusion_tokens x embed_dim] + # [batch_size * num_fusion_tokens, embed_dim] fusion_embeds = self.fusion_embedding(fusion_tokens) # [batch_size x seq_length x embed_dim] @@ -284,197 +281,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: out = out.masked_scatter(mask, embeds) out = out.masked_scatter(~mask, fusion_embeds) return out - - -class DeepFusionModel(nn.Module): - """DeepFusion is a type of fused model architecture where a pretrained encoder is combined - with a pretrained decoder (LLM). This is a popular architecture for multimodal models, with - a full overview available in `The Evolution of Multimodal Model Architectures `_. - - This module has the same methods and forward signature as :class:`~torchtune.modules.TransformerDecoder` and can be used - interchangeably where :class:`~torchtune.modules.TransformerDecoder` is. It combines the encoder with the decoder as a - single module for checkpointing and finetuning. It is expected that the encoder and decoder - are already defined with any extra learnable ``fusion_params``: learnable parameters to help - adapt the pre-trained encoder to the pre-trained decoder. - - Example: - >>> # decoder is a TransformerDecoder (e.g. llama3_8b) with fused cross attention layers - >>> embed = FusionEmbedding(...) - >>> layer = FusionLayer( - ... layer=TransformerSelfAttentionLayer(...), - ... fusion_layer=TransformerCrossAttentionLayer(...), - ... ) - >>> decoder = TransformerDecoder(tok_embeddings=embed, layers=layer, num_layers=32, ...) - >>> - >>> # encoder is pre-trained encoder (e.g. clip_vit_224) with an added projection head - >>> projection_head = FeedForward(...) - >>> register_fusion_module(projection_head)) - >>> encoder = nn.Sequential(clip_vit_224(), projection_head) - >>> - >>> # DeepFusionModel combines the encoder and decoder - >>> model = DeepFusionModel(decoder, encoder) - >>> - >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint) - >>> model.load_state_dict(...) - >>> - >>> # Or load pretrained individual models (fusion_params are not loaded) - >>> model.encoder.load_state_dict(..., strict=False) - >>> model.decoder.load_state_dict(..., strict=False) - >>> - >>> # Forward pass - >>> output = model(tokens, mask, encoder_input, encoder_mask, input_pos) - - Args: - decoder (TransformerDecoder): decoder module - encoder (nn.Module): encoder module - decoder_trainable (bool): whether to train or freeze the decoder. Default is False. - encoder_trainable (bool): whether to train or freeze the encoder. Default is False. - fusion_trainable (bool): whether to train the fusion parameters. Default is True. - - """ - - def __init__( - self, - decoder: TransformerDecoder, - encoder: nn.Module, - *, - decoder_trainable: bool = False, - encoder_trainable: bool = False, - fusion_trainable: bool = True, - ): - super().__init__() - self.decoder = decoder - self.encoder = encoder - - trainable_params = set() - if encoder_trainable: - trainable_params |= { - f"encoder.{n}" for n, p in self.encoder.named_parameters() - } - if decoder_trainable: - trainable_params |= { - f"decoder.{n}" for n, p in self.decoder.named_parameters() - } - if fusion_trainable: - trainable_params |= set(get_fusion_params(self)) - else: - trainable_params -= set(get_fusion_params(self)) - set_trainable_params(self, trainable_params) - - def set_num_output_chunks(self, num_output_chunks: int) -> None: - """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. - This should be called before the first forward pass, in the recipe.""" - self.decoder.set_num_output_chunks(num_output_chunks) - - def setup_caches( - self, - batch_size: int, - dtype: torch.dtype, - *, - encoder_max_seq_len: int = None, - decoder_max_seq_len: int = None, - ): - """ - Sets up key-value attention caches for inference for ``self.decoder``. - For each layer in ``self.decoder.layers``: - - :class:`torchtune.modules.TransformerSelfAttentionLayer` will use ``decoder_max_seq_len``. - - :class:`torchtune.modules.TransformerCrossAttentionLayer` will use ``encoder_max_seq_len``. - - :class:`torchtune.modules.fusion.FusionLayer` will use both ``decoder_max_seq_len`` and ``encoder_max_seq_len``. - - Args: - batch_size (int): batch size for the caches. - dtype (torch.dtype): dtype for the caches. - encoder_max_seq_len (int): maximum encoder cache sequence length. - decoder_max_seq_len (int): maximum decoder cache sequence length. - """ - self.decoder.setup_caches( - batch_size, - dtype, - encoder_max_seq_len=encoder_max_seq_len, - decoder_max_seq_len=decoder_max_seq_len, - ) - - def caches_are_setup(self) -> bool: - """ - Check if the key value caches are setup. This means ``setup_caches`` has been called, and - the relevant attention modules in the model have created their ``KVCache``. - """ - return self.decoder.caches_are_setup() - - def caches_are_enabled(self) -> bool: - """ - Checks if the key value caches are enabled. Once KV-caches have been setup, the relevant - attention modules will be "enabled" and all forward passes will update the caches. This behaviour - can be disabled without altering the state of the KV-caches by "disabling" the KV-caches - using :func:`~torchtune.modules.common_utils.disable_kv_cache`, upon which ``caches_are_enabled`` would return False. - """ - return self.decoder.caches_are_enabled() - - def reset_caches(self): - """ - Resets KV-cache buffers on relevant attention modules to zero, and reset cache positions to zero, - without deleting or reallocating cache tensors. - """ - self.decoder.reset_caches() - - def forward( - self, - tokens: torch.Tensor, - *, - mask: Optional[torch.Tensor] = None, - encoder_input: Optional[Dict] = None, - encoder_mask: Optional[torch.Tensor] = None, - input_pos: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Args: - tokens (torch.Tensor): input tensor with shape ``[b x s]`` - mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask - with shape ``[b x s x s]``. This is applied after the query-key multiplication and - before the softmax. A value of True in row i and column j means token i attends - to token j. A value of False means token i does not attend to token j. If no - mask is specified, a causal mask is used by default. Default is None. - encoder_input (Optional[Dict]): Optional input for the encoder. - encoder_mask (Optional[torch.Tensor]): Boolean tensor defining a relational matrix between - tokens and encoder embeddings. A True value at position i,j means token i can attend - to embedding j in the decoder. Mask has shape ``[b x s x s_e]``. Default is None. - input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids - of each token. During training, this is used to indicate the positions - of each token relative to its sample when packed, shape ``[b x s]``. - During inference, this indicates the position of the current token. - If none, assume the index of the token is its position id. Default is None. - - Note: At the very first step of inference, when the model is provided with a prompt, - ``input_pos`` would contain the positions of all of the tokens in the prompt - (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the - KV values for each position. - - Returns: - Tensor: output tensor with shape ``[b x s x v]`` or a list of layer \ - output tensors defined by ``output_hidden_states`` with the \ - final output tensor appended to the list. - - Notation used for tensor shapes: - - b: batch size - - s: token sequence length - - s_e: encoder sequence length - - v: vocab size - - d: token embed dim - - d_e: encoder embed dim - - m_s: max seq len - """ - # During decoding, encoder_input will only be provided - # for new inputs. Previous encoder outputs are cached - # in the decoder cache. - encoder_embed = None - if encoder_input is not None: - encoder_embed = self.encoder(**encoder_input) - - output = self.decoder( - tokens=tokens, - mask=mask, - encoder_input=encoder_embed, - encoder_mask=encoder_mask, - input_pos=input_pos, - ) - return output diff --git a/torchtune/modules/model_fusion/_fusion_utils.py b/torchtune/modules/model_fusion/_fusion_utils.py index c22cc03549..e10bfcb3e5 100644 --- a/torchtune/modules/model_fusion/_fusion_utils.py +++ b/torchtune/modules/model_fusion/_fusion_utils.py @@ -65,5 +65,5 @@ def get_fusion_params(model: nn.Module) -> Dict[str, nn.Parameter]: current_fusion_params.remove(n) assert ( current_fusion_params == [] - ), f"Fusion params {current_adapter_params} not converted" + ), f"Fusion params {current_fusion_params} not converted" return fusion_params diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py index 44922aa83d..4d678ea6ab 100644 --- a/torchtune/modules/peft/__init__.py +++ b/torchtune/modules/peft/__init__.py @@ -8,6 +8,7 @@ AdapterModule, disable_adapter, get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -30,6 +31,7 @@ "validate_state_dict_for_lora", "load_dora_magnitudes", "disable_adapter", + "get_adapter_state_dict", "get_merged_lora_ckpt", "get_lora_module_names", ] diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 4768d77619..e0e29bb716 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib -from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set +from typing import Any, Dict, Generator, List, Literal, Optional, Protocol, Set, Union import torch from torch import nn @@ -62,13 +62,15 @@ def get_adapter_params(model: nn.Module) -> Dict[str, nn.Parameter]: return adapter_params -def set_trainable_params(model: nn.Module, adapter_params: Dict[str, Any]) -> None: +def set_trainable_params( + model: nn.Module, adapter_params: Union[Dict[str, Any], Set] +) -> None: """ Set trainable parameters for an nn.Module based on a state dict of adapter parameters. Args: model (nn.Module): Instance of model class containing some adapter params. - adapter_params (Dict[str, Any]): State dict mapping adapter key names to their + adapter_params (Union[Dict[str, Any], Set]): State dict mapping adapter key names to their respective nn.Parameters (i.e. outputs of :func:`~torchtune.modules.peft.get_adapter_params`.) Returns: @@ -107,6 +109,27 @@ def get_lora_module_names( return lora_module_keys +def get_adapter_state_dict( + state_dict: Dict[str, Any], device: Optional[str] = "cpu" +) -> Dict[str, Any]: + """ + Return the subset of the full state_dict from a model that correspond to an adapter. + Assumes that "lora" and "magnitude" are unique names for adapter parameters, and + that the state_dict is not sharded. All returned parameters are moved to CPU. + + Args: + state_dict (Dict[str, Any]): Full model state dict. + device (Optional[str]): device to move adapter parameters to. Default: 'cpu' + + Returns: + Dict[str, Any]: the subset of model's state dict containing + only adapter parameters. + + """ + adapter_key_filter = lambda x: "lora" in x or "magnitude" in x + return {k: v.to(device) for k, v in state_dict.items() if adapter_key_filter(k)} + + def validate_state_dict_for_lora( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool, diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index a1e1cdbd73..06ec1b5312 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -12,8 +12,8 @@ from torchtune.training._distributed import ( contains_fsdp, FSDPPolicyType, + gather_cpu_state_dict, get_full_finetune_fsdp_wrap_policy, - get_full_model_state_dict, get_full_optimizer_state_dict, get_shard_conditions, get_world_size_and_rank, @@ -120,7 +120,7 @@ "FSDPPolicyType", "get_full_finetune_fsdp_wrap_policy", "lora_fsdp_wrap_policy", - "get_full_model_state_dict", + "gather_cpu_state_dict", "get_full_optimizer_state_dict", "load_from_full_model_state_dict", "load_from_full_optimizer_state_dict", diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index bee9adce6d..6880c78f9b 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -289,19 +289,43 @@ def wait_and_del_remaining_references() -> None: # Stash the tensor to keep memory alive until compute stream is complete self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor + # Note: [Track views of the unpacked] + # Why do we get the use count of the unpacked tensor here? We want an + # initial count to compare to later, during the post-hook of the + # backward node, when we need to decide whether we're allowed to free + # the tensor yet. In what obscure cases must we delay freeing the + # tensor (and thus call record_stream)? + # 1. Any of the outputs of the backward node is a view of the unpacked + # tensor. + # 2. In the case that this unpacked tensor will be used in a + # checkpointed region, if one of the recomputed saved tensors ends + # up as a view of the unpacked tensor. + # 3. The user abuses the system somehow and manually relies on the + # unpacked tensor to exist after the backward node has executed. + storage_refcount = torch._C._storage_Use_Count( + maybe_gpu_tensor.untyped_storage()._cdata + ) + def hook(outputs, inputs): # create events for the current node inputs/outputs if they were streamed in if brought_back_from_cpu: - # if any of the outputs is a view of the tensor, meaning the tensor might be used later, - # we cannot presume to delete it after only the current node is done! So we use our frenemy, - # record_stream, to ensure the Tensor stays unmessed with until it's done getting used - # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic - # memory usage, but this case should not happen often. + # See Note: [Track views of the unpacked] + # IF any of the outputs is a view of the tensor, OR if a view of + # the tensor has been saved as a part of checkpoint's recompute + # process, OR the user has abusedly incurred a reference on the + # unpacked tensor, THEN the tensor might be used later and we + # cannot presume to delete it after only the current node is + # done! So we use our frenemy, record_stream, to ensure the + # Tensor stays unmessed with until it's done getting used in the + # compute stream (s0 here). Note that the con here is we introduce + # non-deterministic (thus higher) memory usage, but this case + # should not happen often. unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] - if any( - o.untyped_storage() is unpacked_tensor.untyped_storage() - for o in outputs - if o is not None + if ( + torch._C._storage_Use_Count( + unpacked_tensor.untyped_storage()._cdata + ) + > storage_refcount ): unpacked_tensor.record_stream(self.s0) del self.bwd_tensor_stash[unpack_tensor_id] diff --git a/torchtune/training/_compile.py b/torchtune/training/_compile.py index 668df921c5..922aef0997 100644 --- a/torchtune/training/_compile.py +++ b/torchtune/training/_compile.py @@ -61,7 +61,7 @@ def compile_model( model.compile(backend=backend) -def compile_loss(loss: nn.Module, verbose: bool = True) -> None: +def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module: """ Utility to compile and return loss function. If the loss function is chunked cross-entropy, we only compile the upcast + cross-entropy calculation, not the chunking. For other losses diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index aa68364f88..3511662442 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -17,9 +17,6 @@ from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard from torch.distributed._tensor import distribute_tensor, DTensor from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - _CHECKPOINT_WRAPPED_MODULE, -) from torch.distributed.checkpoint.state_dict import _init_optim_state from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy @@ -300,7 +297,9 @@ def load_from_full_model_state_dict( for param_name, full_tensor in full_sd.items(): sharded_meta_param = meta_sharded_sd.get(param_name) full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device) - if isinstance(sharded_meta_param._local_tensor, NF4Tensor): + if hasattr(sharded_meta_param, "_local_tensor") and isinstance( + sharded_meta_param._local_tensor, NF4Tensor + ): full_tensor = to_nf4(full_tensor) # replicating logic from `_fsdp_param.py`` `_init_sharded_param` # otherwise `distribute_tensor(DTensor(local=NF4))` @@ -346,88 +345,58 @@ def load_from_full_model_state_dict( return model.load_state_dict(sharded_sd, strict=strict, assign=True) -def get_full_model_state_dict( - model: "FSDPModule", # noqa +def gather_cpu_state_dict( + sharded_sd: Dict[str, DTensor], # noqa is_rank_zero: bool, device: Optional[torch.device] = None, - trainable_only: bool = False, ) -> Dict[str, Any]: """ Converting sharded state dict into a full state dict on CPU - Returning non-empty result on rank0 to avoid peaking CPU memory + Returning non-empty result only on rank0 to avoid peaking CPU memory Args: - model (FSDPModule): wrapped module + sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors is_rank_zero (bool): flag to check if the process is on rank 0 device (Optional[torch.device]): device to use for sharded tensors. Default: None - trainable_only (bool): flag to check if only trainable parameters should be returned. Default: False - - Raises: - AssertionError: if the model contains NF4Tensor and the model is not wrapped with FSDP Returns: Dict[str, Any]: State dict on CPU """ - # [Warning] FSDPModel.state_dict converts all Parameter Tensors to DTensors - sharded_sd = model.state_dict() cpu_state_dict = {} - has_nf4 = any( - isinstance(param._local_tensor, NF4Tensor) for param in model.parameters() - ) - if has_nf4: - from torch.distributed._composable.fsdp.fully_shard import FSDPModule - - # Iterating from lowerer modules to higher - # Unsharding lora adapters before unsharding transformer block - for module_name, module in reversed(list(model.named_modules())): - if not isinstance(module, FSDPModule): - continue - module.unshard(async_op=False) - if is_rank_zero: - module_name = module_name.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "") - for local_fqn, param in module.named_parameters(): - local_fqn = local_fqn.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "") - if len(module_name) > 0: - full_fqn = module_name + "." + local_fqn - else: - full_fqn = local_fqn - if trainable_only and not param.requires_grad: - # skip trainable params when trainable_only is True - continue - if full_fqn in cpu_state_dict: - # Iterate over every param in every module bottoms-up - # When lower TransformerBlock gets unsharded, - # we insert (full_fqn, full_tensor) into cpu_state_dict. - # When higher Transformer gets unsharded, we avoid updating - # params from lower TransformerBlockonly again. Instead, only updating - # tok_embeddings etc that belongs to Transformer - continue - if isinstance(param, NF4Tensor): - # upcasting NF4 to original dtype - param = param.to(param.dtype) - if isinstance(param, DTensor): - raise AssertionError( - f"Internal error: expect unsharded {full_fqn} in plain torch.Tensor but got DTensor." - " Might be a bug in get_full_model_state_dict" - ) - cpu_state_dict[full_fqn] = param.cpu() - module.reshard() - else: - for param_name, sharded_param in sharded_sd.items(): - # without this, it may hang forever for +70B models. - torch.distributed.barrier() - if sharded_param.is_cpu: - assert device is not None and device.type == "cuda", ( - f"Expect cuda but got device={device}. " - "Please call get_full_model_state_dict(..., device=self._device)," - " so DTensor can communicate over NCCL." + for param_name, sharded_param in sharded_sd.items(): + if sharded_param.is_cpu: + # Move back to device if offloaded to CPU + sharded_param = sharded_param.to(device) + if isinstance(sharded_param._local_tensor, NF4Tensor): + # NF4Tensor does not support all_gather from DTensor + # so we need to manually all_gather + mesh = sharded_param.device_mesh + nf4_tensor = sharded_param._local_tensor + quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh) + full_quant_params = [] + for quant_param in quant_params: + d0, *dn = quant_param.shape + shape = (d0 * mesh.get_group().size(), *dn) + full_quant_param = torch.empty( + shape, device=quant_param.device, dtype=quant_param.dtype + ) + dist.all_gather_into_tensor( + full_quant_param, quant_param, mesh.get_group(), async_op=False ) - sharded_param = sharded_param.to(device) + full_quant_params.append(full_quant_param) + full_param, _ = nf4_tensor.fsdp_post_all_gather( + full_quant_params, metadata, nf4_tensor.dtype + ) + # upcasting NF4 to original dtype + full_param = full_param.to(full_param.dtype) + else: + # Gather DTensor full_param = sharded_param.full_tensor() - if is_rank_zero: - cpu_state_dict[param_name] = full_param.cpu() - else: - del full_param + if is_rank_zero: + cpu_state_dict[param_name] = full_param.cpu() + else: + del full_param + torch.distributed.barrier() return cpu_state_dict diff --git a/torchtune/training/_grad_scaler.py b/torchtune/training/_grad_scaler.py index aab938bc90..484cd8f372 100644 --- a/torchtune/training/_grad_scaler.py +++ b/torchtune/training/_grad_scaler.py @@ -21,6 +21,11 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None: Outputs: None (grad fields are modified in place) """ + device = None for p in model.parameters(): + # First ensure scaler is on the same device as the model + if not device: + device = p.device + scaler = scaler.to(device) if p.grad is not None: p.grad *= scaler diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 91e74649b9..94a315cafc 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -111,7 +111,7 @@ class FullModelTorchTuneCheckpointer(_CheckpointerInterface): checkpoint_dir (str): Directory containing the checkpoint files checkpoint_files (List[str]): List of checkpoint files to load. Since the checkpointer takes care of sorting by file ID, the order in this list does not matter - model_type (ModelType): Model type of the model for which the checkpointer is being loaded + model_type (str): Model type of the model for which the checkpointer is being loaded output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None @@ -130,7 +130,7 @@ def __init__( self, checkpoint_dir: str, checkpoint_files: List[str], - model_type: ModelType, + model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, @@ -159,7 +159,7 @@ def __init__( ) self._resume_from_checkpoint = resume_from_checkpoint - self._model_type = model_type + self._model_type = ModelType[model_type] self._output_dir = Path(output_dir) # recipe_checkpoint contains the recipe state. This should be available if @@ -277,7 +277,7 @@ def save_checkpoint( # If the recipe state needs to be output, first remove the model state dict if intermediate_checkpoint: - _ = state_dict.pop(training.MODEL_KEY) + _ = state_dict.pop(training.MODEL_KEY, None) _ = state_dict.pop(training.ADAPTER_KEY, None) _ = state_dict.pop(training.ADAPTER_CONFIG, None) output_path = Path.joinpath(self._output_dir, "recipe_state.pt") @@ -322,7 +322,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface): checkpoint_dir (str): Directory containing the checkpoint files checkpoint_files (Union[List[str], Dict[str, str]]): List of checkpoint files to load. Since the checkpointer takes care of sorting by file ID, the order in this list does not matter. TODO: update this - model_type (ModelType): Model type of the model for which the checkpointer is being loaded + model_type (str): Model type of the model for which the checkpointer is being loaded output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None @@ -338,7 +338,7 @@ def __init__( self, checkpoint_dir: str, checkpoint_files: Union[List[str], Dict[str, str]], - model_type: ModelType, + model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, @@ -488,6 +488,16 @@ def load_checkpoint(self) -> Dict[str, Any]: "supported_aspect_ratios", None ), ) + elif self._model_type == ModelType.GEMMA2: + from torchtune.models.gemma2._convert_weights import gemma2_hf_to_tune + + converted_state_dict[training.MODEL_KEY] = gemma2_hf_to_tune( + merged_state_dict, + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, @@ -578,6 +588,16 @@ def save_checkpoint( "supported_aspect_ratios", None ), ) + elif self._model_type == ModelType.GEMMA2: + from torchtune.models.gemma2._convert_weights import gemma2_tune_to_hf + + state_dict[training.MODEL_KEY] = gemma2_tune_to_hf( + state_dict[training.MODEL_KEY], + num_heads=self._config["num_attention_heads"], + num_kv_heads=self._config["num_key_value_heads"], + dim=self._config["hidden_size"], + head_dim=self._config.get("head_dim", None), + ) else: state_dict[training.MODEL_KEY] = convert_weights.tune_to_hf( state_dict[training.MODEL_KEY], @@ -733,7 +753,7 @@ class FullModelMetaCheckpointer(_CheckpointerInterface): checkpoint_dir (str): Directory containing the checkpoint files checkpoint_files (List[str]): List of checkpoint files to load. Currently this checkpointer only supports loading a single checkpoint file. - model_type (ModelType): Model type of the model for which the checkpointer is being loaded + model_type (str): Model type of the model for which the checkpointer is being loaded output_dir (str): Directory to save the checkpoint files adapter_checkpoint (Optional[str]): Path to the adapter weights. Default is None recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None @@ -749,7 +769,7 @@ def __init__( self, checkpoint_dir: str, checkpoint_files: List[str], - model_type: ModelType, + model_type: str, output_dir: str, adapter_checkpoint: Optional[str] = None, recipe_checkpoint: Optional[str] = None, @@ -899,7 +919,7 @@ def save_checkpoint( # If the recipe state needs to be output, first remove the model state dict # and if it exists, remove the adapter state dict as well if intermediate_checkpoint: - _ = state_dict.pop(training.MODEL_KEY) + _ = state_dict.pop(training.MODEL_KEY, None) _ = state_dict.pop(training.ADAPTER_KEY, None) _ = state_dict.pop(training.ADAPTER_CONFIG, None) output_path = Path.joinpath(self._output_dir, "recipe_state.pt") diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 2d353b007c..2fa7265194 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -45,6 +45,7 @@ class ModelType(Enum): Attributes: GEMMA (str): Gemma family of models. See :func:`~torchtune.models.gemma.gemma` + GEMMA2 (str): Gemma 2 family of models. See :func:`~torchtune.models.gemma2.gemma2` LLAMA2 (str): Llama2 family of models. See :func:`~torchtune.models.llama2.llama2` LLAMA3 (str): Llama3 family of models. See :func:`~torchtune.models.llama3.llama3` LLAMA3_2 (str): Llama3.2 family of models. See :func:`~torchtune.models.llama3_2.llama3_2` @@ -65,6 +66,7 @@ class ModelType(Enum): """ GEMMA: str = "gemma" + GEMMA2: str = "gemma2" LLAMA2: str = "llama2" LLAMA3: str = "llama3" LLAMA3_2: str = "llama3_2" diff --git a/torchtune/training/memory.py b/torchtune/training/memory.py index 597e135e18..6d2981b011 100644 --- a/torchtune/training/memory.py +++ b/torchtune/training/memory.py @@ -16,7 +16,7 @@ ) from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.optim.lr_scheduler import LRScheduler -from torchtune.utils import get_logger +from torchtune.utils import get_device_support, get_logger, get_torch_device_namespace _log: logging.Logger = get_logger() @@ -45,11 +45,11 @@ def set_activation_checkpointing( def cleanup_before_training() -> None: """ - Call gc collect, empty CUDA cache, and reset peak memory stats. + Call gc collect, empty device cache, and reset peak memory stats. """ gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + get_torch_device_namespace().empty_cache() + get_torch_device_namespace().reset_peak_memory_stats() class OptimizerInBackwardWrapper: @@ -260,19 +260,17 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict: Raises: ValueError: If the passed-in device is not CUDA. """ - if device.type != "cuda": - raise ValueError( - f"Logging memory stats is only supported on CUDA devices, got {device}" - ) + if device.type == "cpu": + raise ValueError("Logging memory stats is not supported on CPU devices") - peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / ( + torch_device = get_torch_device_namespace() + peak_memory_active = torch_device.memory_stats().get("active_bytes.all.peak", 0) / ( 1024**3 ) - peak_mem_alloc = torch.cuda.max_memory_allocated(device) / (1024**3) - peak_mem_reserved = torch.cuda.max_memory_reserved(device) / (1024**3) - + peak_mem_alloc = torch_device.max_memory_allocated(device) / (1024**3) + peak_mem_reserved = torch_device.max_memory_reserved(device) / (1024**3) if reset_stats: - torch.cuda.reset_peak_memory_stats(device) + torch_device.reset_peak_memory_stats(device) memory_stats = { "peak_memory_active": peak_memory_active, @@ -292,9 +290,10 @@ def log_memory_stats(stats: Dict[str, float]) -> None: stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory allocated, and peak memory reserved stats. """ + device_support = get_device_support() _log.info( "Memory stats after model init:" - f"\n\tGPU peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB" - f"\n\tGPU peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB" - f"\n\tGPU peak memory active: {stats['peak_memory_active']:.2f} GiB" + f"\n\t{device_support.device_name} peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB" + f"\n\t{device_support.device_name} peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB" + f"\n\t{device_support.device_name} peak memory active: {stats['peak_memory_active']:.2f} GiB" ) diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py index 7cda05caa0..6da300be72 100644 --- a/torchtune/training/precision.py +++ b/torchtune/training/precision.py @@ -10,9 +10,11 @@ import torch from torchtune.utils import get_logger +from torchtune.utils._device import is_npu_available log = get_logger() + PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = { "fp16": torch.float16, "bf16": torch.bfloat16, @@ -50,6 +52,7 @@ def verify_bf16_support() -> bool: - CUDA compute capability >= 8 - NCCL is available and version >= 2.10 - MPS is available and torch was built with MPS + - NPU is available and supports bf16 Returns: bool: True if bf16 is available, False otherwise. @@ -62,7 +65,8 @@ def verify_bf16_support() -> bool: and torch.cuda.nccl.version() >= (2, 10) ) mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built() - return cuda_support or mps_support + npu_support = is_npu_available and torch.npu.is_bf16_supported() + return cuda_support or mps_support or npu_support def get_dtype( diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py index 28e28753fc..59de1b5aa7 100644 --- a/torchtune/utils/__init__.py +++ b/torchtune/utils/__init__.py @@ -4,8 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._device import batch_to_device, get_device -from ._logging import get_logger +from ._device import ( + batch_to_device, + DeviceSupport, + get_device, + get_device_support, + get_torch_device_namespace, +) +from ._logging import get_logger, log_rank_zero from ._version import torch_version_ge @@ -14,4 +20,8 @@ "get_device", "get_logger", "torch_version_ge", + "get_device_support", + "get_torch_device_namespace", + "DeviceSupport", + "log_rank_zero", ] diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index beefb0dbfa..36ca14a358 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -5,11 +5,15 @@ # LICENSE file in the root directory of this source tree. import os +from enum import Enum from typing import Optional import torch from torchtune.utils._import_guard import _SUPPORTS_FLEX_ATTENTION +from torchtune.utils._logging import get_logger + +logger = get_logger("DEBUG") if _SUPPORTS_FLEX_ATTENTION: from torch.nn.attention.flex_attention import BlockMask @@ -17,6 +21,19 @@ BlockMask = torch.Tensor +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_npu_available = is_torch_npu_available() + + def _get_local_rank() -> Optional[int]: """Function that gets the local rank from the environment. @@ -29,8 +46,8 @@ def _get_local_rank() -> Optional[int]: return local_rank -def _setup_cuda_device(device: torch.device) -> torch.device: - """Function that sets the CUDA device and infers the cuda +def _setup_device(device: torch.device) -> torch.device: + """Function that sets the CUDA-like device and infers the device index if not set. Args: @@ -43,29 +60,35 @@ def _setup_cuda_device(device: torch.device) -> torch.device: device """ local_rank = _get_local_rank() or 0 + device_support = get_device_support() + device_type = device_support.device_type + device_name = device_support.device_name + torch_device = get_torch_device_namespace() + if device.index is None: - device = torch.device(type="cuda", index=local_rank) + device = torch.device(type=device_type, index=local_rank) # Ensure index is available before setting device - if device.index >= torch.cuda.device_count(): + if device.index >= torch_device.device_count(): raise RuntimeError( - "The local rank is larger than the number of available GPUs." + f"The local rank is larger than the number of available {device_name}s." ) - - torch.cuda.set_device(device) + torch_device.set_device(device) return device def _get_device_type_from_env() -> str: """Function that gets the torch.device based on the current machine. - This currently only supports CPU, CUDA. + This currently only supports CPU, CUDA, NPU. Returns: device """ if torch.cuda.is_available(): device = "cuda" + elif is_npu_available: + device = "npu" else: device = "cpu" return device @@ -88,12 +111,12 @@ def _validate_device_from_env(device: torch.device) -> None: local_rank = _get_local_rank() # Check if the device index is correct - if device.type == "cuda" and local_rank is not None: + if device.type != "cpu" and local_rank is not None: # Ensure device index matches assigned index when distributed training if device.index != local_rank: raise RuntimeError( - f"You can't specify a device index when using distributed training. \ - Device specified is {device} but was assigned cuda:{local_rank}" + f"You can't specify a device index when using distributed training. " + f"Device specified is {device} but local rank is:{local_rank}" ) # Check if the device is available on this machine @@ -110,10 +133,10 @@ def get_device(device: Optional[str] = None) -> torch.device: distributed settings, and returns a :func:`~torch.device`. If device string is not provided, this function will infer the device based on the environment. - If CUDA is available and being used, this function also sets the CUDA device. + If CUDA-like is available and being used, this function also sets the CUDA-like device. Args: - device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu". + device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu". Example: >>> device = get_device("cuda") @@ -126,8 +149,8 @@ def get_device(device: Optional[str] = None) -> torch.device: if device is None: device = _get_device_type_from_env() device = torch.device(device) - if device.type == "cuda": - device = _setup_cuda_device(device) + if device.type in ["cuda", "npu"]: + device = _setup_device(device) _validate_device_from_env(device) return device @@ -156,3 +179,63 @@ def batch_to_device(batch: dict, device: torch.device) -> None: f"""To use batch_to_device, all elements in the batch must be a dict or Tensor. Got key "{k}" with value of type {type(v)}""" ) + + +class DeviceSupport(Enum): + """ + This is a simple enum for compute devices, + This currently only supports CPU, CUDA, NPU. + The following enumeration defines various device configurations with attributes: + 1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu"). + 2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU"). + 3. `communication_backend` (str): Specifies the backend used for communication on this device (e.g., "gloo", "nccl", "hccl"). + """ + + CPU = ("cpu", "CPU", "gloo") + CUDA = ("cuda", "GPU", "nccl") + NPU = ("npu", "NPU", "hccl") + + def __init__( + self, + device_type: str, + device_name: str, + communication_backend: str, + ): + self.device_type = device_type + self.device_name = device_name + self.communication_backend = communication_backend + + @staticmethod + def from_type(device_type: str): + for member in DeviceSupport: + if member.device_type == device_type: + return member + raise ValueError(f"Unknown device type: {device_type}.") + + +def get_device_support() -> DeviceSupport: + """function that gets the DeviceSupport with compute devices based on the current machine. + + This currently only supports CPU, CUDA, NPU. + + Returns: + device_support: DeviceSupport + """ + device_type = _get_device_type_from_env() + return DeviceSupport.from_type(device_type) + + +def get_torch_device_namespace() -> any: + """Return the corresponding torch attribute based on the device type string. + + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. + """ + device_type = get_device_support().device_type + try: + return getattr(torch, device_type) + except AttributeError: + logger.warning( + f"Device namespace '{device_type}' not found in torch, try to load torch.cuda." + ) + return torch.cuda