Commit

Merge branch 'main' into main
ParagEkbote authored Oct 28, 2024
2 parents b41a343 + 6cdc93d commit 7ab2adf
Showing 7 changed files with 100 additions and 25 deletions.
15 changes: 0 additions & 15 deletions README.md
@@ -293,21 +293,6 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
> **Note:**
> Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.
## Note on ML Compiler

### Torch Compile

Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the example below, applying Liger Kernel on top of Torch Compile cuts reserved memory by more than half at comparable throughput (see the usage sketch after the table).

| Configuration | Throughput (tokens/sec) | Memory Reserved (GB) |
|--------------------------------|----------------------------|-------------------------|
| Torch Compile | 3780 | 66.4 |
| Torch Compile + Liger Kernel | 3702 | 31.0 |

> **Note:**
> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> 2. Tested on torch `2.5.0.dev20240731+cu118`
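For readers who want to reproduce the combination above, a minimal usage sketch follows. It assumes the `apply_liger_kernel_to_llama` patching helper from `liger_kernel.transformers` and a LLaMA 3-8B checkpoint; neither appears in this diff, so treat the exact names as illustrative.

```python
# Sketch: stack Liger Kernel and torch.compile (model ID and helper usage are
# assumptions for illustration, not taken from this commit).
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch the Hugging Face LLaMA modules with Liger's Triton kernels before
# instantiating the model, so the patched classes are the ones constructed.
apply_liger_kernel_to_llama()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B", torch_dtype=torch.bfloat16
).to("cuda")

# torch.compile then optimizes the patched model on top, as in the table above.
model = torch.compile(model)
```

Throughput and memory will vary with hardware and batch size; the numbers above were collected under the FSDP setup described in the note.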
## Contributing

[CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
8 changes: 4 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -2,7 +2,7 @@
import triton

from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel

# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 (https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19).
# However, setting the limit to 65536, as in the LayerNorm tutorial, is faster because of less register spilling.
@@ -19,9 +19,7 @@ def fused_linear_cross_entropy_forward(
label_smoothing=0.0,
reduction="mean",
):
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
)
dtype = _input.dtype
device = _input.device

# inputs have shape: BT x H
@@ -189,6 +187,7 @@ def fused_linear_cross_entropy_backward(

class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
@amp_custom_fwd
def forward(
ctx,
_input,
Expand Down Expand Up @@ -228,6 +227,7 @@ def forward(
return loss

@staticmethod
@amp_custom_bwd
def backward(ctx, grad_output):
(grad_input, grad_weight, grad_bias) = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
13 changes: 13 additions & 0 deletions src/liger_kernel/ops/utils.py
@@ -12,6 +12,7 @@

import functools
import importlib
import operator
from typing import Callable

import torch
@@ -63,6 +64,18 @@ def compare_version(package: str, operator: Callable, target: str):
return operator(pkg_version, Version(target))


def get_amp_custom_fwd_bwd() -> tuple[Callable, Callable]:
if compare_version("torch", operator.ge, "2.4.0"):
return (
functools.partial(torch.amp.custom_fwd, device_type="cuda"),
functools.partial(torch.amp.custom_bwd, device_type="cuda"),
)
return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd


amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


torch_to_triton_dtype = {
torch.float32: tl.float32,
torch.float16: tl.float16,
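To make the new helper's intent concrete, here is a minimal sketch (illustrative only, not part of this commit) of how the version-gated decorators are applied to a custom `autograd.Function`, mirroring the `fused_linear_cross_entropy` change above. On torch >= 2.4 the AMP helpers live under `torch.amp` and require a `device_type` argument, while older releases only expose the CUDA-specific `torch.cuda.amp` variants; the gate keeps a single import path for both.

```python
# Sketch: applying the version-gated AMP decorators to a custom autograd Function.
# ToyMatmul is a hypothetical example, not a kernel from this repository.
import torch

from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd


class ToyMatmul(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x, w):
        # x: (N, H), w: (V, H) -> logits: (N, V)
        ctx.save_for_backward(x, w)
        return x @ w.t()

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output):
        # Backward runs under the same autocast state that was active in forward.
        x, w = ctx.saved_tensors
        return grad_output @ w, grad_output.t() @ x
```

Because `custom_bwd` re-applies the forward-time autocast state during backward, the decorated op no longer needs to branch on `torch.is_autocast_enabled()` itself, which appears to be why the forward above can read `_input.dtype` directly.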
8 changes: 8 additions & 0 deletions test/conftest.py
@@ -0,0 +1,8 @@
import pytest
import torch


@pytest.fixture(autouse=True)
def clear_cuda_cache():
yield
torch.cuda.empty_cache()
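As a quick reminder of how this fixture behaves (the test below is hypothetical, not from this commit): with `autouse=True`, pytest wraps every test in the suite with it automatically; the code before `yield` is setup and the code after is teardown, so unused cached GPU memory is released after each test.

```python
# Sketch: an autouse yield fixture brackets every test without being requested.
import pytest
import torch


@pytest.fixture(autouse=True)
def clear_cuda_cache():
    yield                     # the test body runs here
    torch.cuda.empty_cache()  # teardown: release cached, unused GPU memory


def test_big_alloc():  # hypothetical test; clear_cuda_cache runs around it
    x = torch.zeros(1024, 1024, device="cuda")
    assert x.numel() == 1024 * 1024
```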
69 changes: 69 additions & 0 deletions test/transformers/test_fused_linear_cross_entropy.py
@@ -203,3 +203,72 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
y2.backward(grad_output)

assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
"B, T, H, V",
[
(2, 4, 512, 512), # The test does not work on some CI GPUs. Issue #160
(8, 2048, 4096, 32000), # llama2, mistral
# Comment out the larger shapes below to speed up testing
(4, 2048, 4096, 128256), # llama3 8B
(4, 1024, 8192, 128256), # llama3 70B
(4, 423, 8192, 32000), # random shape
],
)
@pytest.mark.parametrize(
"cast_dtype, atol, rtol",
[
(torch.bfloat16, 5e-3, 5e-2),
(torch.float16, 5e-3, 5e-2),
],
)
def test_amp(B, T, H, V, cast_dtype, atol, rtol):
device = "cuda"
dtype = torch.float32
torch_lm_head_ce = TorchLMHeadCE(
H=H,
V=V,
bias=True,
label_smoothing=0.0,
reduction="mean",
dtype=dtype,
).to(device)
liger_lm_head_ce = LigerLMHeadCE(
H=H,
V=V,
bias=True,
label_smoothing=0.0,
reduction="mean",
dtype=dtype,
).to(device)

# init the linear in all CEs with the same weights
torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)

_tensor = torch.randn(B * T, H, device=device, dtype=dtype)
_input1 = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

with torch.autocast(device_type="cuda", dtype=cast_dtype):
output1 = torch_lm_head_ce(_input1, target)
output2 = liger_lm_head_ce(_input2, target)

assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

with torch.autocast(device_type="cuda", dtype=cast_dtype):
output1.backward()
output2.backward()

assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

assert_verbose_allclose(
torch_lm_head_ce.lin.weight.grad,
liger_lm_head_ce.lin.weight.grad,
atol=atol,
rtol=rtol,
)
6 changes: 3 additions & 3 deletions test/transformers/test_fused_linear_jsd.py
@@ -89,12 +89,12 @@ def forward(self, student_input, teacher_input, label=None):
@pytest.mark.parametrize(
"B, T, H, V",
[
(2, 4, 2048, 3200),
(2, 2048, 4096, 32000), # llama2, mistral
(2, 2, 512, 1600),
(2, 4, 1024, 1600),
# Comment out to speed up testing
# (4, 2048, 4096, 128256), # llama3 8B
# (4, 1024, 8192, 128256), # llama3 70B
(4, 423, 8192, 32000), # random shape
(4, 423, 167, 1423), # random shape
],
)
@pytest.mark.parametrize(
6 changes: 3 additions & 3 deletions test/transformers/test_jsd.py
@@ -52,8 +52,8 @@ def forward(
_SHAPE_PARAMS = (
"B, T, V",
[
(2, 4096, 32000), # llama2, mistral
(2, 4096, 32000), # llama2, mistral
(2, 1024, 3200),
(2, 1024, 3200),
# weird shape
(41, 401, 1271),
pytest.param(
@@ -66,7 +66,7 @@ def forward(
reason="This test requires a GPU with at least 36GB of memory",
),
),
(3, 423, 32000),
(3, 423, 1600),
],
)

