Support FusedLinearCrossEntropy for Gemma2 #320

Open
wants to merge 11 commits into base: main
72 changes: 57 additions & 15 deletions src/liger_kernel/ops/cross_entropy.py
@@ -1,8 +1,21 @@
import operator
from typing import Optional

import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import element_mul_kernel
from liger_kernel.ops.utils import compare_version, element_mul_kernel

if compare_version("triton", operator.ge, "3.0.0"):
try:
# typical import path with dispatch available
from triton.language.extra.libdevice import tanh
except ModuleNotFoundError:
# for working with NGC containers
from triton.language.extra.cuda.libdevice import tanh
else:
from triton.language.math import tanh


@triton.jit
@@ -18,7 +31,9 @@ def liger_cross_entropy_kernel(
ignore_index,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
BLOCK_SIZE: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
@@ -36,7 +51,9 @@ def liger_cross_entropy_kernel(
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The string for the reduction to apply
softcap (float): The positive cap used for logit soft-capping; logits are rescaled as softcap * tanh(logits / softcap), which bounds them to (-softcap, +softcap).
BLOCK_SIZE (int): The block size for Triton operations.
HAS_SOFTCAPPING (bool): Whether logit soft-capping is applied.
"""

# https://github.com/triton-lang/triton/issues/1058
@@ -68,6 +85,8 @@ def liger_cross_entropy_kernel(
ori_X_y = tl.load(
X_ptr + y
) # we need to store the original value of X_y for the loss calculation
if HAS_SOFTCAPPING:
ori_X_y = softcap * tanh(ori_X_y / softcap)

# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
@@ -79,6 +98,8 @@ def liger_cross_entropy_kernel(
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if HAS_SOFTCAPPING:
X_block = softcap * tanh(X_block / softcap)
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
@@ -109,10 +130,27 @@ def liger_cross_entropy_kernel(
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
if HAS_SOFTCAPPING:
intermediate = tanh(X_block / softcap)
X_block = softcap * intermediate

if reduction == "mean":
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
X_block = tl.where(
X_offsets != y,
(tl.exp(X_block - m) / d - eps) / (n_non_ignore),
(tl.exp(X_block - m) / d - eps - (1 - label_smoothing))
/ (n_non_ignore),
)

else:
X_block = tl.exp(X_block - m) / d - eps
X_block = tl.where(
X_offsets != y,
(tl.exp(X_block - m) / d - eps),
(tl.exp(X_block - m) / d - eps - (1 - label_smoothing)),
)

if HAS_SOFTCAPPING:
X_block = X_block * (1 - intermediate * intermediate)

tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

@@ -132,7 +170,7 @@ def liger_cross_entropy_kernel(
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
@@ -144,15 +182,7 @@ def liger_cross_entropy_kernel(
if reduction == "mean":
loss = loss / n_non_ignore

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
X_y = tl.load(X_ptr + y)
if reduction == "mean":
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)

tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)


# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
@@ -161,7 +191,9 @@ def liger_cross_entropy_kernel(
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
def cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction, softcap
):
BT, V = _input.shape
n_rows = BT

@@ -191,7 +223,9 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
ignore_index=ignore_index,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
BLOCK_SIZE=BLOCK_SIZE,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32,
@@ -233,7 +267,13 @@ class LigerCrossEntropyFunction(torch.autograd.Function):

@staticmethod
def forward(
ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
ctx,
_input: torch.Tensor,
target: torch.Tensor,
ignore_index: int = -100,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -245,12 +285,13 @@ def forward(
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
softcap (Optional[float]): The positive cap used for logit soft-capping; logits are rescaled as softcap * tanh(logits / softcap), which bounds them to (-softcap, +softcap).

Returns:
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing, reduction
_input, target, ignore_index, label_smoothing, reduction, softcap
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
@@ -278,4 +319,5 @@ def backward(ctx, grad_output):
None,
None,
None,
None,
)
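
For reference, the soft-capping added here follows the Gemma2-style formulation: logits are squashed as softcap * tanh(logits / softcap) before the log-softmax, and the backward pass picks up the extra chain-rule factor 1 - tanh(logits / softcap)^2, which is the `(1 - intermediate * intermediate)` term in the kernel above. Below is a minimal eager-mode sketch of the behavior the fused kernel is expected to match; the helper name and tensor shapes are illustrative and not part of this PR.

```python
import torch
import torch.nn.functional as F


def softcapped_cross_entropy_reference(
    logits, target, softcap, ignore_index=-100, reduction="mean", label_smoothing=0.0
):
    # Squash logits into (-softcap, +softcap) before the usual cross entropy.
    capped = softcap * torch.tanh(logits / softcap)
    return F.cross_entropy(
        capped,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
        label_smoothing=label_smoothing,
    )


# Illustrative check on random data.
logits = torch.randn(4, 128, requires_grad=True)
target = torch.randint(0, 128, (4,))
loss = softcapped_cross_entropy_reference(logits, target, softcap=30.0)
loss.backward()
```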
15 changes: 13 additions & 2 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -18,6 +18,7 @@ def fused_linear_cross_entropy_forward(
ignore_index=-100,
label_smoothing=0.0,
reduction="mean",
softcap=None,
):
dtype = _input.dtype
device = _input.device
@@ -87,7 +88,9 @@ def fused_linear_cross_entropy_forward(
ignore_index=ignore_index,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
BLOCK_SIZE=BLOCK_SIZE,
HAS_SOFTCAPPING=True if softcap is not None else False,
num_warps=32,
)

@@ -197,6 +200,7 @@ def forward(
ignore_index=-100,
label_smoothing=0.0,
reduction="mean",
softcap=None,
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -216,7 +220,14 @@ def forward(
reduction: reduction to apply
"""
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
_input, weight, target, bias, ignore_index, label_smoothing, reduction
_input,
weight,
target,
bias,
ignore_index,
label_smoothing,
reduction,
softcap,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -233,4 +244,4 @@ def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
)
return (grad_input, grad_weight, None, grad_bias, None, None, None)
return (grad_input, grad_weight, None, grad_bias, None, None, None, None)
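
The extra trailing `None` in `backward` corresponds to the new `softcap` argument: `torch.autograd.Function.backward` must return one value per input passed to `forward`, and inputs that need no gradient (here, the non-tensor `softcap`) get `None`. A toy sketch of that pattern, not taken from this PR:

```python
import torch


class ScaledSoftcap(torch.autograd.Function):
    """Toy Function: backward returns one slot per forward input."""

    @staticmethod
    def forward(ctx, x, softcap):
        t = torch.tanh(x / softcap)
        ctx.save_for_backward(t)
        return softcap * t

    @staticmethod
    def backward(ctx, grad_out):
        (t,) = ctx.saved_tensors
        # d/dx [softcap * tanh(x / softcap)] = 1 - tanh(x / softcap) ** 2
        grad_x = grad_out * (1 - t * t)
        # The non-tensor `softcap` input still needs a slot in the return tuple.
        return grad_x, None


x = torch.randn(3, requires_grad=True)
ScaledSoftcap.apply(x, 30.0).sum().backward()
```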
42 changes: 31 additions & 11 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -1,21 +1,41 @@
from torch.nn import CrossEntropyLoss
from typing import Optional

import torch

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


class LigerCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
assert self.reduction in {
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
ignore_index: int = -100,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
):
super().__init__()
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}"
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
assert (
softcap is None or softcap > 0
), f"softcap must greater than 0.0 or None. Got: {softcap}"
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
self.softcap = softcap

def forward(self, _input, target):
def forward(self, _input: torch.Tensor, target: torch.Tensor):
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing, self.reduction
_input,
target,
self.ignore_index,
self.label_smoothing,
self.reduction,
self.softcap,
)
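
With this change the module takes its options directly instead of inheriting them from `nn.CrossEntropyLoss`. A usage sketch with the softcap enabled; 30.0 matches Gemma2's `final_logit_softcapping`, and the shapes and CUDA placement are illustrative (the Triton kernel requires a GPU):

```python
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(softcap=30.0)  # softcap=None keeps the previous behavior

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

loss = loss_fn(logits, target)
loss.backward()
```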
32 changes: 28 additions & 4 deletions src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -1,13 +1,36 @@
from torch.nn import CrossEntropyLoss
from typing import Optional

import torch

from liger_kernel.ops.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyFunction,
)


class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs)
class LigerFusedLinearCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
ignore_index: int = -100,
label_smoothing: float = 0.0,
reduction: str = "mean",
softcap: Optional[float] = None,
):
super().__init__()
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert reduction in {
"mean",
"sum",
"none",
}, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}"
assert (
softcap is None or softcap > 0
), f"softcap must greater than 0.0 or None. Got: {softcap}"
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
self.softcap = softcap

def forward(self, lin_weight, _input, target, bias=None):
return LigerFusedLinearCrossEntropyFunction.apply(
@@ -18,4 +41,5 @@ def forward(self, lin_weight, _input, target, bias=None):
self.ignore_index,
self.label_smoothing,
self.reduction,
self.softcap,
)
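
The fused module works the same way, except that the lm_head weight is passed in place of materialized logits and the soft-capping happens inside the fused kernel. A usage sketch with illustrative shapes and dtype; per the `forward` signature above, the weight comes before the hidden states:

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

loss_fn = LigerFusedLinearCrossEntropyLoss(softcap=30.0)

# Hidden states (B*T, H), lm_head weight (V, H), flattened targets (B*T,).
hidden = torch.randn(8, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(32000, 2048, device="cuda", dtype=torch.bfloat16, requires_grad=True)
target = torch.randint(0, 32000, (8,), device="cuda")

# Note the argument order: the weight comes before the hidden states.
loss = loss_fn(lm_head_weight, hidden, target)
loss.backward()
```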