From 04d5a0e1d442439c65170cc67b112eba42dc37ee Mon Sep 17 00:00:00 2001
From: Matthew Hoffman
Date: Mon, 14 Oct 2024 14:08:03 -0700
Subject: [PATCH 1/2] Move `logits.float()` call (#308)

## Summary

The analogous `logits.float()` calls were moved in the Hugging Face modeling source code to sit inside the `if labels is not None` block, so that logits are upcast only when they feed a loss calculation. This avoids a memory spike during inference when the model is running in lower precision. A minimal runnable sketch of the pattern follows the reference links below.

* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/llama/modeling_llama.py#L1211-L1212
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/mixtral/modeling_mixtral.py#L1329-L1330
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/phi3/modeling_phi3.py#L1303-L1304
* https://github.com/huggingface/transformers/blob/37ea04013b34b39c01b51aeaacd8d56f2c62a7eb/src/transformers/models/qwen2/modeling_qwen2.py#L1206-L1207

Some of your models already have this change:

* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/mistral.py#L114-L116
* https://github.com/linkedin/Liger-Kernel/blob/ff6650bbcef5d31b7522694cbeb73a21169460e9/src/liger_kernel/transformers/model/gemma.py#L114-L116

See also:

* https://github.com/huggingface/transformers/issues/30860
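To make the behavior change concrete, here is a minimal, self-contained sketch of the pattern. It is not the actual `lce_forward` from this repo: the toy `head_forward` helper, the hidden size, sequence length, and vocabulary size are made up for illustration, and `F.cross_entropy` stands in for the loss used in the modeling code. Only the placement of the `.float()` upcast mirrors the change in the diffs below.

```python
import torch
import torch.nn.functional as F

def head_forward(hidden_states, lm_head, labels=None):
    logits = lm_head(hidden_states)
    loss = None
    if labels is not None:
        # Upcast to float only when a loss is computed (the pattern this patch adopts)
        logits = logits.float()
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
    return logits, loss

# Toy bf16 "model": a bare lm_head over an 8-token sequence.
lm_head = torch.nn.Linear(32, 100, bias=False, dtype=torch.bfloat16)
hidden_states = torch.randn(1, 8, 32, dtype=torch.bfloat16)

logits, _ = head_forward(hidden_states, lm_head)              # inference path
print(logits.dtype)                                           # torch.bfloat16 (no upcast)

labels = torch.randint(0, 100, (1, 8))
logits, loss = head_forward(hidden_states, lm_head, labels)   # loss path
print(logits.dtype, loss.dtype)                               # torch.float32 torch.float32
```

On the inference call the logits stay in `bfloat16`, so no fp32 copy of the full `[batch, seq_len, vocab]` tensor is ever materialized; previously the upcast ran unconditionally right after the `lm_head` projection.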
## Testing Done

- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---
 src/liger_kernel/transformers/model/llama.py    | 3 ++-
 src/liger_kernel/transformers/model/mixtral.py  | 3 ++-
 src/liger_kernel/transformers/model/phi3.py     | 3 ++-
 src/liger_kernel/transformers/model/qwen2.py    | 3 ++-
 src/liger_kernel/transformers/model/qwen2_vl.py | 3 ++-
 5 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/liger_kernel/transformers/model/llama.py b/src/liger_kernel/transformers/model/llama.py
index 9cf6ed44..d0a5daee 100644
--- a/src/liger_kernel/transformers/model/llama.py
+++ b/src/liger_kernel/transformers/model/llama.py
@@ -120,8 +120,9 @@ def lce_forward(
         logits = torch.cat(logits, dim=-1)
     else:
         logits = self.lm_head(hidden_states)
-    logits = logits.float()
     if labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py
index f449284c..ce022b0d 100644
--- a/src/liger_kernel/transformers/model/mixtral.py
+++ b/src/liger_kernel/transformers/model/mixtral.py
@@ -103,7 +103,6 @@ def lce_forward(
 
     hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
-    logits = logits.float()
 
     loss = None
     if self.training and (labels is not None):
@@ -116,6 +115,8 @@ def lce_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     elif labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py
index 4cb7ec0e..bd08eeb7 100644
--- a/src/liger_kernel/transformers/model/phi3.py
+++ b/src/liger_kernel/transformers/model/phi3.py
@@ -108,10 +108,11 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
     loss = None
     if labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
diff --git a/src/liger_kernel/transformers/model/qwen2.py b/src/liger_kernel/transformers/model/qwen2.py
index b8e9957e..f317d418 100644
--- a/src/liger_kernel/transformers/model/qwen2.py
+++ b/src/liger_kernel/transformers/model/qwen2.py
@@ -109,8 +109,9 @@ def lce_forward(
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
     if labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
diff --git a/src/liger_kernel/transformers/model/qwen2_vl.py b/src/liger_kernel/transformers/model/qwen2_vl.py
index eb5709f6..6f56000c 100644
--- a/src/liger_kernel/transformers/model/qwen2_vl.py
+++ b/src/liger_kernel/transformers/model/qwen2_vl.py
@@ -150,8 +150,9 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
     if labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()

From 31469169eb286b792d96f2e92d4fcff47538d01d Mon Sep 17 00:00:00 2001
From: barbarian360 <94866865+barbarian360@users.noreply.github.com>
Date: Tue, 15 Oct 2024 02:38:57 +0530
Subject: [PATCH 2/2] Added contributors and back to top (#304)

## Summary

Added a contributors section to the README and a back-to-top button.
---
 README.md | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/README.md b/README.md
index cb42e445..bbd8d03d 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+
+
 # Liger Kernel: Efficient Triton Kernels for LLM Training
@@ -357,3 +359,15 @@ Biblatex entry:
 ## Star History
 
 [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
+
+## Contributors
+
+
+  contributors
+
+
+
+
+  ↑ Back to Top ↑
+
+
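As a closing note on the first patch's motivation (the inference-time memory spike), here is a rough back-of-envelope for the fp32 logits copy that the moved upcast avoids. The batch size, sequence length, and vocabulary size are assumptions chosen for illustration (a Llama-3-scale vocabulary and an 8K-token sequence), not measurements from this repository.

```python
# Size of the logits tensor for one assumed 8K-token sequence and a 128256-entry vocabulary.
batch, seq_len, vocab = 1, 8192, 128256
elements = batch * seq_len * vocab

gib = 1024 ** 3
bf16_gib = elements * 2 / gib   # logits as produced by a bf16 lm_head
fp32_gib = elements * 4 / gib   # the extra copy created by logits.float()

print(f"bf16 logits:              {bf16_gib:.2f} GiB")              # ~1.96 GiB
print(f"fp32 copy after .float(): {fp32_gib:.2f} GiB")              # ~3.91 GiB
print(f"peak while upcasting:     {bf16_gib + fp32_gib:.2f} GiB")   # ~5.87 GiB
```

While `logits.float()` runs, the bf16 tensor and its fp32 copy are alive at the same time, so peak logits memory is roughly 3x the bf16-only figure; skipping the upcast on the no-labels path removes that spike entirely.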