Add gemma lightning example for single L40 GPU (#120)
## Summary

Add gemma lightning example for single L40 GPU


## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
3 people authored Aug 28, 2024
1 parent 3f922a1 commit 8089f1e
Showing 2 changed files with 44 additions and 14 deletions.
9 changes: 7 additions & 2 deletions examples/lightning/README.md
@@ -3,14 +3,19 @@
## How to Run
```bash
pip install -r requirements.txt
python training.py

# For single L40 48GB GPU
python training.py --model google/gemma-2b --num_gpu 1 --strategy ddp

# For 8XA100 40GB
python training.py --model meta-llama/Meta-Llama-3-8B --strategy deepspeed
```

**Notes**
1. The example uses the Llama3 model, which requires a community license agreement and HuggingFace Hub login. If you want to use Llama3 in this example, please make sure you have done the following:
* Accept the community license agreement at https://huggingface.co/meta-llama/Meta-Llama-3-8B
* Run `huggingface-cli login` and enter your HuggingFace token
2. The default hyperparameters and configurations work on single node with 8xA100 80GB GPUs. For running on device with less GPU RAM, please consider reducing the per-GPU batch size and/or enable `CPUOffload` in FSDP.
2. The default hyperparameters and configuration for Gemma work on a single L40 48GB GPU, and the Llama config works on a single node with 8xA100 40GB GPUs. For running on devices with less GPU RAM, please consider reducing the per-GPU batch size and/or enabling `CPUOffload` in FSDP (a sketch follows below).


<!-- Benchmark TBD -->
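Regarding note 2 above: when reducing the per-GPU batch size alone is not enough, enabling CPU offload in Lightning's `FSDPStrategy` is the next knob to try. The following is a minimal sketch, not part of this PR — it reuses the Gemma decoder-layer wrap policy from `training.py`, while the single-device trainer settings and `cpu_offload=True` are illustrative assumptions:

```python
# Sketch: FSDP with CPU offload for a memory-constrained single GPU.
# Wrap policy mirrors training.py; the offload flag and trainer settings are assumptions.
import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDPStrategy
from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer

strategy = FSDPStrategy(
    auto_wrap_policy={GemmaDecoderLayer},  # shard at the decoder-layer boundary
    cpu_offload=True,  # keep sharded params/grads on CPU between uses to save GPU RAM
)

trainer = pl.Trainer(
    accelerator="cuda",
    devices=1,
    strategy=strategy,
    precision="bf16-mixed",
    max_epochs=1,
)
```

CPU offload trades GPU memory for host-device transfer time, so expect slower steps than keeping everything on the GPU.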
49 changes: 37 additions & 12 deletions examples/lightning/training.py
@@ -10,33 +10,38 @@
from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.utils.data import DataLoader
from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from trl import DataCollatorForCompletionOnlyLM

from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import (
apply_liger_kernel_to_gemma,
apply_liger_kernel_to_llama,
)

apply_liger_kernel_to_llama(fused_linear_cross_entropy=True, cross_entropy=False)
apply_liger_kernel_to_gemma(fused_linear_cross_entropy=True, cross_entropy=False)


_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
QUESTION = "<Question>"
CHOICES = "<Choices>"
ANSWER = "<Answer>"


@dataclass
class Args:
model: str = "meta-llama/Meta-Llama-3-8B"
model: str = "google/gemma-2b"
data: str = "cais/mmlu"
output_dir: str = "mmlu_finetuning"
max_length: int = 2048
# deepspeed will OOM with 16
batch_size: int = 8
# for llama3 8B model, deepspeed will OOM with 16 on 8XA100 80G and 8 will OOM on 8XA100 40G
batch_size: int = 4
lr: float = 6e-6
weight_decay: float = 0.05
warmup_ratio: float = 0.1
seed: int = 42
strategy: str = "deepspeed"
strategy: str = "ddp"
num_gpu: int = None
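For reference, the README's `--model`, `--strategy`, and `--num_gpu` flags map one-to-one onto the fields of this dataclass. Below is a sketch of one common way to wire such a dataclass to the command line, `transformers.HfArgumentParser` — an assumption for illustration, not necessarily what `training.py` actually uses:

```python
# Sketch: exposing the Args dataclass as CLI flags via HfArgumentParser (assumed pattern).
from dataclasses import dataclass

import transformers


@dataclass
class Args:
    model: str = "google/gemma-2b"
    strategy: str = "ddp"
    num_gpu: int = None  # None -> fall back to torch.cuda.device_count()
    batch_size: int = 4


parser = transformers.HfArgumentParser(Args)
(args,) = parser.parse_args_into_dataclasses()
# e.g. `python training.py --model meta-llama/Meta-Llama-3-8B --strategy deepspeed`
print(args)
```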


def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
@@ -148,7 +153,12 @@ def __init__(self, tokenizer, args: Args):
super().__init__()
self.args = args
self.tokenizer = tokenizer
response_prompt = tokenizer.encode(f" {ANSWER}", add_special_tokens=False)
self.response_template_str = (
" <Answer>" if "Meta-Llama-3-8B" in self.args.model else "<Answer>"
)
response_prompt = tokenizer.encode(
f"{self.response_template_str}", add_special_tokens=False
)
self.collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
response_template=response_prompt,
@@ -164,7 +174,7 @@ def formatting_func(self, example):
s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
s += f"{QUESTION}{example['question'][i]} "
s += f"{CHOICES}{choices} "
s += f"{ANSWER}{example['answer'][i]}"
s += f"{self.response_template_str}{example['answer'][i]}"
output_texts.append(s)
return output_texts
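Why the response template differs per model: `DataCollatorForCompletionOnlyLM` masks the loss on everything up to and including the template's token ids, and those ids must match how the template is tokenized inside the formatted text. Llama 3's BPE folds the preceding space into the template's first token, hence `" <Answer>"`; for Gemma the bare `"<Answer>"` suffices. A small check sketch (model names taken from this example; the sample prompt and printed outcome are illustrative, not asserted):

```python
# Sketch: verify the collator's response template appears verbatim in a tokenized example.
# If the template ids are not found, DataCollatorForCompletionOnlyLM cannot locate the
# completion and will complain for that example.
from transformers import AutoTokenizer

for model, template in [
    ("meta-llama/Meta-Llama-3-8B", " <Answer>"),
    ("google/gemma-2b", "<Answer>"),
]:
    tok = AutoTokenizer.from_pretrained(model)
    text = "<Question>2+2? <Choices>3;4 <Answer>4"  # made-up sample in this example's format
    text_ids = tok.encode(text, add_special_tokens=False)
    template_ids = tok.encode(template, add_special_tokens=False)
    found = any(
        text_ids[i : i + len(template_ids)] == template_ids
        for i in range(len(text_ids) - len(template_ids) + 1)
    )
    print(model, template_ids, "found:", found)
```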

@@ -229,7 +239,16 @@ def train():
pl.seed_everything(args.seed)
os.makedirs(args.output_dir, exist_ok=True)

layers = {LlamaDecoderLayer}
if "Meta-Llama-3-8B" in args.model:
layers = {LlamaDecoderLayer}
elif "gemma" in args.model:
layers = {GemmaDecoderLayer}
else:
layers = {}
raise Warning(
f"Unimplemented layer wrap policy for {args.model} in this example"
)

if args.strategy == "fsdp":
strategy = FSDPStrategy(
auto_wrap_policy=layers,
@@ -242,16 +261,22 @@
),
forward_prefetch=True,
)
else:
precision = None
elif args.strategy == "deepspeed":
strategy = DeepSpeedStrategy(stage=3)
precision = "bf16-mixed"
else:
strategy = "ddp"
precision = "bf16-true"

trainer = pl.Trainer(
accelerator="cuda",
strategy=strategy,
devices=torch.cuda.device_count(),
devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu,
default_root_dir=args.output_dir,
log_every_n_steps=1,
max_epochs=1,
precision=None if args.strategy == "fsdp" else "bf16-mixed",
precision=precision,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(
