Commit

Merge branch 'main' into fix-dtype-mismatch
lancerts authored Oct 14, 2024
2 parents d4504c4 + 3146916, commit c267e59
Showing 6 changed files with 24 additions and 5 deletions.
README.md (14 additions, 0 deletions)

@@ -1,3 +1,5 @@
+<a name="readme-top"></a>
+
 # Liger Kernel: Efficient Triton Kernels for LLM Training
 
 
@@ -357,3 +359,15 @@ Biblatex entry:
 
 ## Star History
 [![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)
+
+## Contributors
+
+<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
+<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
+</a>
+
+<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
+<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
+↑ Back to Top ↑
+</a>
+</p>
src/liger_kernel/transformers/model/llama.py (2 additions, 1 deletion)

@@ -120,8 +120,9 @@ def lce_forward(
             logits = torch.cat(logits, dim=-1)
         else:
             logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
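The same change repeats in each model file in this commit: the unconditional `logits = logits.float()` after `lm_head` is removed, and the upcast now runs only inside the `if labels is not None:` branch, so inference-only calls keep the logits in the model dtype. A minimal, self-contained sketch of that pattern (the function name and shapes are illustrative, not the Liger Kernel API):

```python
import torch
import torch.nn.functional as F

def logits_and_loss(hidden_states, lm_head, labels=None):
    """Illustrative sketch: upcast logits to float32 only when a loss is computed."""
    logits = lm_head(hidden_states)  # stays in the model dtype (e.g. bf16) for inference
    loss = None
    if labels is not None:
        # Upcast to float if we need to compute the loss to avoid potential precision issues
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous()  # tokens < n predict n
        shift_labels = labels[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    return logits, loss

# Without labels the logits keep the lm_head dtype; with labels they are upcast.
lm_head = torch.nn.Linear(64, 128, bias=False, dtype=torch.bfloat16)
hidden = torch.randn(2, 8, 64, dtype=torch.bfloat16)
labels = torch.randint(0, 128, (2, 8))
print(logits_and_loss(hidden, lm_head)[0].dtype)          # torch.bfloat16
print(logits_and_loss(hidden, lm_head, labels)[0].dtype)  # torch.float32
```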
src/liger_kernel/transformers/model/mixtral.py (2 additions, 1 deletion)

@@ -103,7 +103,6 @@ def lce_forward(
 
     hidden_states = outputs[0]
     logits = self.lm_head(hidden_states)
-    logits = logits.float()
 
     loss = None
     if self.training and (labels is not None):
@@ -116,6 +115,8 @@ def lce_forward(
         lce = LigerFusedLinearCrossEntropyLoss()
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     elif labels is not None:
+        # Upcast to float if we need to compute the loss to avoid potential precision issues
+        logits = logits.float()
         # Shift so that tokens < n predict n
         shift_logits = logits[..., :-1, :].contiguous()
         shift_labels = labels[..., 1:].contiguous()
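For contrast, the training branch shown above never materializes the full logits tensor: `LigerFusedLinearCrossEntropyLoss` fuses the `lm_head` projection with the cross-entropy loss, taking the weight and the shifted, flattened hidden states directly, which is why the logits dtype only matters on the non-fused path. A hedged sketch of that call, with illustrative shapes, assuming the import path used by these model files and a CUDA device for the Triton kernel:

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Illustrative shapes; a real forward pass supplies these from the model.
hidden_size, vocab_size, tokens = 4096, 32000, 2048
device = "cuda"  # the fused kernel runs on GPU

lm_head_weight = torch.randn(vocab_size, hidden_size, dtype=torch.bfloat16, device=device, requires_grad=True)
shift_hidden_states = torch.randn(tokens, hidden_size, dtype=torch.bfloat16, device=device, requires_grad=True)
shift_labels = torch.randint(0, vocab_size, (tokens,), device=device)

lce = LigerFusedLinearCrossEntropyLoss()
# Same call shape as in the diffs: weight, flattened hidden states, flattened labels.
loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
loss.backward()  # gradients flow to both the weight and the hidden states
```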
src/liger_kernel/transformers/model/phi3.py (2 additions, 1 deletion)

@@ -108,10 +108,11 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
 
         loss = None
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
src/liger_kernel/transformers/model/qwen2.py (2 additions, 1 deletion)

@@ -109,8 +109,9 @@ def lce_forward(
 
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
src/liger_kernel/transformers/model/qwen2_vl.py (2 additions, 1 deletion)

@@ -150,8 +150,9 @@ def lce_forward(
         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
     else:
         logits = self.lm_head(hidden_states)
-        logits = logits.float()
         if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
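A hypothetical end-to-end check of the behavior these patches target (not from the repository's test suite; the checkpoint name is a placeholder, and it assumes the default patching options swap in `lce_forward`): with the Liger patch applied and no labels passed, the returned logits should keep the model dtype rather than being upcast to float32.

```python
import torch
from transformers import AutoModelForCausalLM
from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()  # with default options this swaps in Liger's lce_forward

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # placeholder; any Llama-style checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    out = model(input_ids)  # no labels: no loss, and no float32 upcast of the logits

assert out.loss is None
assert out.logits.dtype == torch.bfloat16
```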
