diff --git a/README.md b/README.md
index 5ebcec9a..1ddedb79 100644
--- a/README.md
+++ b/README.md
@@ -293,21 +293,6 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
 > **Note:**
 > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.
 
-## Note on ML Compiler
-
-### Torch Compile
-
-Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). In the following example, Liger Kernel can further optimize the model on top of Torch Compile, reducing the memory by more than half.
-
-| Configuration                  | Throughput (tokens/sec)    | Memory Reserved (GB)    |
-|--------------------------------|----------------------------|-------------------------|
-| Torch Compile                  | 3780                       | 66.4                    |
-| Torch Compile + Liger Kernel   | 3702                       | 31.0                    |
-
-> **Note:**
-> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
-> 2. Tested on torch `2.5.0.dev20240731+cu118`
-
 ## Contributing
 
 [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index e9a28afb..371a8919 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -2,7 +2,7 @@
 import triton
 
 from liger_kernel.ops.cross_entropy import liger_cross_entropy_kernel
-from liger_kernel.ops.utils import element_mul_kernel
+from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd, element_mul_kernel
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
@@ -19,9 +19,7 @@ def fused_linear_cross_entropy_forward(
     label_smoothing=0.0,
     reduction="mean",
 ):
-    dtype = (
-        torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
-    )
+    dtype = _input.dtype
     device = _input.device
 
     # inputs have shape: BT x H
@@ -189,6 +187,7 @@ def fused_linear_cross_entropy_backward(
 
 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
+    @amp_custom_fwd
     def forward(
         ctx,
         _input,
@@ -228,6 +227,7 @@ def forward(
         return loss
 
     @staticmethod
+    @amp_custom_bwd
     def backward(ctx, grad_output):
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py
index 2c01f3ac..beaa75b9 100644
--- a/src/liger_kernel/ops/utils.py
+++ b/src/liger_kernel/ops/utils.py
@@ -12,6 +12,7 @@
 
 import functools
 import importlib
+import operator
 from typing import Callable
 
 import torch
@@ -63,6 +64,18 @@ def compare_version(package: str, operator: Callable, target: str):
     return operator(pkg_version, Version(target))
 
 
+def get_amp_custom_fwd_bwd() -> Callable:
+    if compare_version("torch", operator.ge, "2.4.0"):
+        return (
+            functools.partial(torch.amp.custom_fwd, device_type="cuda"),
+            functools.partial(torch.amp.custom_bwd, device_type="cuda"),
+        )
+    return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd
+
+
+amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()
+
+
 torch_to_triton_dtype = {
     torch.float32: tl.float32,
     torch.float16: tl.float16,
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 00000000..806fa866
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,8 @@
+import pytest
+import torch
+
+
+@pytest.fixture(autouse=True)
+def clear_cuda_cache():
+    yield
+    torch.cuda.empty_cache()
diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py
index 57e2cf53..7bb32baa 100644
--- a/test/transformers/test_fused_linear_cross_entropy.py
+++ b/test/transformers/test_fused_linear_cross_entropy.py
@@ -203,3 +203,72 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
     y2.backward(grad_output)
 
     assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol)
+
+
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (2, 4, 512, 512),  # The test does not work on some CI GPUs. Issue #160
+        (8, 2048, 4096, 32000),  # llama2, mistral
+        # Comment out to speed up testing
+        (4, 2048, 4096, 128256),  # llama3 8B
+        (4, 1024, 8192, 128256),  # llama3 70B
+        (4, 423, 8192, 32000),  # random shape
+    ],
+)
+@pytest.mark.parametrize(
+    "cast_dtype, atol, rtol",
+    [
+        (torch.bfloat16, 5e-3, 5e-2),
+        (torch.float16, 5e-3, 5e-2),
+    ],
+)
+def test_amp(B, T, H, V, cast_dtype, atol, rtol):
+    device = "cuda"
+    dtype = torch.float32
+    torch_lm_head_ce = TorchLMHeadCE(
+        H=H,
+        V=V,
+        bias=True,
+        label_smoothing=0.0,
+        reduction="mean",
+        dtype=dtype,
+    ).to(device)
+    liger_lm_head_ce = LigerLMHeadCE(
+        H=H,
+        V=V,
+        bias=True,
+        label_smoothing=0.0,
+        reduction="mean",
+        dtype=dtype,
+    ).to(device)
+
+    # init the linear in all CEs with the same weights
+    torch_lm_head_ce.lin.weight.data = liger_lm_head_ce.lin.weight.data = torch.rand(
+        V, H, device=device, dtype=dtype
+    )
+
+    _tensor = torch.randn(B * T, H, device=device, dtype=dtype)
+    _input1 = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    with torch.autocast(device_type="cuda", dtype=cast_dtype):
+        output1 = torch_lm_head_ce(_input1, target)
+        output2 = liger_lm_head_ce(_input2, target)
+
+    assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
+
+    with torch.autocast(device_type="cuda", dtype=cast_dtype):
+        output1.backward()
+        output2.backward()
+
+    assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)
+
+    assert_verbose_allclose(
+        torch_lm_head_ce.lin.weight.grad,
+        liger_lm_head_ce.lin.weight.grad,
+        atol=atol,
+        rtol=rtol,
+    )
diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py
index 321f45ab..2024e054 100644
--- a/test/transformers/test_fused_linear_jsd.py
+++ b/test/transformers/test_fused_linear_jsd.py
@@ -89,12 +89,12 @@ def forward(self, student_input, teacher_input, label=None):
 @pytest.mark.parametrize(
     "B, T, H, V",
     [
-        (2, 4, 2048, 3200),
-        (2, 2048, 4096, 32000),  # llama2, mistral
+        (2, 2, 512, 1600),
+        (2, 4, 1024, 1600),
         # Comment out to speed up testing
         # (4, 2048, 4096, 128256),  # llama3 8B
         # (4, 1024, 8192, 128256),  # llama3 70B
-        (4, 423, 8192, 32000),  # random shape
+        (4, 423, 167, 1423),  # random shape
     ],
 )
 @pytest.mark.parametrize(
diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index 564b85cf..37e12180 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -52,8 +52,8 @@ def forward(
 _SHAPE_PARAMS = (
     "B, T, V",
     [
-        (2, 4096, 32000),  # llama2, mistral
-        (2, 4096, 32000),  # llama2, mistral
+        (2, 1024, 3200),
+        (2, 1024, 3200),
         # weird shape
         (41, 401, 1271),
         pytest.param(
@@ -66,7 +66,7 @@ def forward(
                 reason="This test requires a GPU with at least 36GB of memory",
             ),
         ),
-        (3, 423, 32000),
+        (3, 423, 1600),
     ],
 )
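
Usage note (reviewer sketch, not part of the diff): the snippet below shows how the `amp_custom_fwd` / `amp_custom_bwd` helpers added to `src/liger_kernel/ops/utils.py` are meant to be applied to a custom `torch.autograd.Function`, mirroring the decorators added to `LigerFusedLinearCrossEntropyFunction`. The `ToyMatmul` op, shapes, and dtypes are hypothetical and exist only for illustration; the actual change pairs these decorators with reading `_input.dtype` directly in the fused forward.

```python
# Hypothetical illustration only: ToyMatmul is not part of Liger Kernel.
import torch

from liger_kernel.ops.utils import amp_custom_bwd, amp_custom_fwd


class ToyMatmul(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd  # forward's internal ops run under the caller's autocast state
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        # Under torch.autocast this matmul is autocast like any eager op, so the
        # op can rely on the dtypes it receives instead of querying autocast state.
        return x @ weight.t()

    @staticmethod
    @amp_custom_bwd  # backward runs under the same autocast state as forward
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_output @ weight
        grad_w = grad_output.t() @ x
        # Cast gradients back to the leaf dtypes (a choice of this sketch).
        return grad_x.to(x.dtype), grad_w.to(weight.dtype)


if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda", requires_grad=True)   # fp32 leaf
    w = torch.randn(32, 16, device="cuda", requires_grad=True)  # fp32 leaf
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out = ToyMatmul.apply(x, w)
    out.sum().backward()
    print(out.dtype, x.grad.dtype)  # torch.float16 torch.float32
```

The `get_amp_custom_fwd_bwd()` shim simply selects `torch.amp.custom_fwd`/`custom_bwd` with `device_type="cuda"` on torch >= 2.4.0 and falls back to the older `torch.cuda.amp` variants otherwise, so the decorators behave the same on both sides of that API change.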