
Commit

Merge branch 'main' into ce-z-loss
Tcc0403 authored Oct 27, 2024
2 parents f7083f2 + 6cdc93d commit b89f335
Showing 22 changed files with 683 additions and 130 deletions.
57 changes: 35 additions & 22 deletions README.md
@@ -1,3 +1,5 @@
<a name="readme-top"></a>

# Liger Kernel: Efficient Triton Kernels for LLM Training


@@ -6,6 +8,7 @@
<th style="padding: 10px;" colspan="2">Stable</th>
<th style="padding: 10px;" colspan="2">Nightly</th>
<th style="padding: 10px;">Discord</th>
<th style="padding: 10px;">Gurubase (experimental)</th>
</tr>
<tr>
<td style="padding: 10px;">
@@ -33,18 +36,24 @@
<img src="https://dcbadge.vercel.app/api/server/gpumode?style=flat" alt="Join Our Discord">
</a>
</td>
<td style="padding: 10px;">
<a href="https://gurubase.io/g/liger-kernel">
<img src="https://img.shields.io/badge/Gurubase-Ask%20Liger%20Kernel%20Guru-006BFF" alt="Ask Liger Kernel Guru">
</a>
</td>
</tr>
</table>



<img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing) | [Acknowledgement](#acknowledgement)
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)

<details>
<summary>Latest News 🔥</summary>


- [2024/10/21] We have released the tech report of Liger Kernel on arXiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
- [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
@@ -252,6 +261,7 @@ loss.backward()
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
| JSD | `liger_kernel.transformers.LigerJSD` |
| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
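
A minimal usage sketch of the standalone `LigerJSD` module is shown below. It mirrors the `TorchJSD` reference used in this commit's benchmarks (log-space student and teacher inputs, an optional integer label tensor); the constructor keyword names (`beta`, `ignore_index`) are assumptions here and should be checked against the released API.

```python
import torch

from liger_kernel.transformers import LigerJSD

# Illustrative sizes only: flattened batch*seq positions and vocab size.
BT, V = 1024, 32000

# Assumed constructor kwargs (beta, ignore_index); verify against the LigerJSD docstring.
jsd = LigerJSD(beta=0.5, ignore_index=-100)

# Both arguments are log-probabilities, matching the TorchJSD reference in the benchmarks.
log_q = torch.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1)  # student (input)
log_p = torch.log_softmax(torch.randn(BT, V, device="cuda"), dim=-1)  # teacher (target)
labels = torch.randint(0, V, (BT,), device="cuda")  # optional; positions equal to -100 are ignored

loss = jsd(log_q, log_p, labels)
```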

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -267,6 +277,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for 128k vocab size.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD loss and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. See the usage sketch below.
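
A sketch of the fused variant follows, based on the `LigerLMHeadJSD` wrapper in `benchmark/scripts/benchmark_fused_linear_jsd.py` from this diff: the module takes the hidden states and head weights directly, so the full student/teacher logits are never materialized at once. Shapes and hyperparameters below are illustrative.

```python
import torch

from liger_kernel.transformers import LigerFusedLinearJSD

# Illustrative sizes: hidden size, vocab size, and flattened batch*seq positions.
H, V, BT = 4096, 128256, 8192
dtype, device = torch.bfloat16, "cuda"

student_input = torch.randn(BT, H, dtype=dtype, device=device, requires_grad=True)
teacher_input = torch.randn(BT, H, dtype=dtype, device=device)
student_weight = torch.randn(V, H, dtype=dtype, device=device, requires_grad=True)  # student lm_head.weight
teacher_weight = torch.randn(V, H, dtype=dtype, device=device)                       # teacher lm_head.weight
label = torch.randint(0, V, (BT,), device=device)  # optional; -100 marks ignored positions

# Constructor kwargs as used in the benchmark script of this commit.
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, ignore_index=-100, temperature=1.0)

loss = fused_jsd(student_input, student_weight, teacher_input, teacher_weight, label)
loss.backward()
```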


### Experimental Kernels

@@ -281,21 +293,6 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
> **Note:**
> Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.
## Note on ML Compiler

### Torch Compile

Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel further optimizes the model on top of Torch Compile, reducing memory by more than half; a usage sketch follows the note below.

| Configuration | Throughput (tokens/sec) | Memory Reserved (GB) |
|--------------------------------|----------------------------|-------------------------|
| Torch Compile | 3780 | 66.4 |
| Torch Compile + Liger Kernel | 3702 | 31.0 |

> **Note:**
> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> 2. Tested on torch `2.5.0.dev20240731+cu118`
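
A minimal sketch of combining the two is shown below. It assumes the Hugging Face `transformers` API and Liger's LLaMA patching entry point (`apply_liger_kernel_to_llama`); treat the exact import and model id as illustrative and check the Getting Started section for the supported API.

```python
import torch
from transformers import AutoModelForCausalLM

# Assumed patching entry point from liger_kernel.transformers; verify against Getting Started.
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Hugging Face LLaMA modules with Liger's Triton kernels before loading the model.
apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
).to("cuda")

# Because the Liger kernels are plain Triton, they compose with torch.compile.
model = torch.compile(model)
```
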
## Contributing

[CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
@@ -347,13 +344,29 @@ It also includes components from projects licensed under:

Biblatex entry:
```bib
@software{liger2024,
  title = {Liger-Kernel: Efficient Triton Kernels for LLM Training},
  author = {Hsu, Pin-Lun and Dai, Yun and Kothapalli, Vignesh and Song, Qingquan and Tang, Shao and Zhu, Siyu},
  url = {https://github.com/linkedin/Liger-Kernel},
  year = {2024}
@article{hsu2024ligerkernelefficienttriton,
  title={Liger Kernel: Efficient Triton Kernels for LLM Training},
  author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen},
  year={2024},
  eprint={2410.10989},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2410.10989},
  journal={arXiv preprint arXiv:2410.10989},
}
```

## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)

## Contributors

<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>
52 changes: 37 additions & 15 deletions benchmark/scripts/benchmark_fused_linear_jsd.py
@@ -13,23 +40,40 @@


class TorchJSD(torch.nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
def __init__(
self,
beta: float = 0.5,
ignore_index: int = -100,
dtype: torch.dtype = torch.float,
):
super(TorchJSD, self).__init__()
self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.beta = beta
self.ignore_index = ignore_index
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
log_q: torch.Tensor, # input
log_p: torch.Tensor, # target
label=None,
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
1 - self.beta
) * self.kl(torch.log(m), log_q).sum(dim=-1)

if label is not None:
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
loss = 0.0
else:
loss = (loss / n_non_ignore).sum()
else:
loss = (loss / log_q.shape[0]).sum()
return loss.to(self.dtype)


@@ -48,8 +65,9 @@ def __init__(
V: int,
dtype: torch.dtype,
device: torch.device,
temperature: float = 1.0,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
super().__init__()
self.student_lin = torch.nn.Linear(
@@ -58,16 +76,16 @@ def __init__(
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.jsd = TorchJSD(beta, dtype=dtype)
self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype)
self.temperature = temperature

def forward(self, student_input, teacher_input):
def forward(self, student_input, teacher_input, label=None):
student_logits = self.student_lin(student_input)
teacher_logits = self.teacher_lin(teacher_input)
student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1)
teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1)

return self.jsd(student_prob, teacher_prob)
return self.jsd(student_prob, teacher_prob, label)


class LigerLMHeadJSD(torch.nn.Module):
@@ -77,8 +95,9 @@ def __init__(
V: int,
dtype: torch.dtype,
device: torch.device,
temperature: float = 1.0,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
super().__init__()
self.student_lin = torch.nn.Linear(
@@ -87,14 +106,17 @@ def __init__(
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.fused_jsd = LigerFusedLinearJSD(beta, temperature)
self.fused_jsd = LigerFusedLinearJSD(
jsd_beta=beta, ignore_index=ignore_index, temperature=temperature
)

def forward(self, student_input, teacher_input):
def forward(self, student_input, teacher_input, label=None):
return self.fused_jsd(
student_input,
self.student_lin.weight,
teacher_input,
self.teacher_lin.weight,
label,
)


36 changes: 26 additions & 10 deletions benchmark/scripts/benchmark_jsd.py
@@ -1,5 +1,4 @@
import torch
import torch.nn as nn
import triton
from utils import (
QUANTILES,
@@ -13,24 +12,41 @@
from liger_kernel.transformers.jsd import LigerJSD


class TorchJSD(nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
class TorchJSD(torch.nn.Module):
def __init__(
self,
beta: float = 0.5,
ignore_index: int = -100,
dtype: torch.dtype = torch.float,
):
super(TorchJSD, self).__init__()
self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.beta = beta
self.ignore_index = ignore_index
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
log_q: torch.Tensor, # input
log_p: torch.Tensor, # target
label=None,
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
1 - self.beta
) * self.kl(torch.log(m), log_q).sum(dim=-1)

if label is not None:
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
loss = 0.0
else:
loss = (loss / n_non_ignore).sum()
else:
loss = (loss / log_q.shape[0]).sum()
return loss.to(self.dtype)


8 changes: 4 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -2,7 +2,7 @@
import triton

from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel

# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -20,9 +20,7 @@ def fused_linear_cross_entropy_forward(
label_smoothing=0.0,
reduction="mean",
):
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
)
dtype = _input.dtype
device = _input.device

# inputs have shape: BT x H
@@ -193,6 +191,7 @@ def fused_linear_cross_entropy_backward(

class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
@amp_custom_fwd
def forward(
ctx,
_input,
@@ -240,6 +239,7 @@ def forward(
return loss

@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output):
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
