diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index c72ba8d4..d326f73c 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -1,8 +1,21 @@ +import operator +from typing import Optional + import torch import triton import triton.language as tl -from liger_kernel.ops.utils import element_mul_kernel +from liger_kernel.ops.utils import compare_version, element_mul_kernel + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh +else: + from triton.language.math import tanh @triton.jit @@ -18,7 +31,9 @@ def liger_cross_entropy_kernel( ignore_index, label_smoothing: tl.constexpr, reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, BLOCK_SIZE: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, ): """ This kernel computes both cross entropy loss and the gradient of the input. @@ -36,7 +51,9 @@ def liger_cross_entropy_kernel( ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). BLOCK_SIZE (int): The block size for Triton operations. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. """ # https://github.com/triton-lang/triton/issues/1058 @@ -68,6 +85,8 @@ def liger_cross_entropy_kernel( ori_X_y = tl.load( X_ptr + y ) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) # Label smoothing is a general case of normal cross entropy # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 @@ -79,6 +98,8 @@ def liger_cross_entropy_kernel( X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow @@ -109,10 +130,27 @@ def liger_cross_entropy_kernel( X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + if reduction == "mean": - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + X_block = tl.where( + X_offsets != y, + (tl.exp(X_block - m) / d - eps) / (n_non_ignore), + (tl.exp(X_block - m) / d - eps - (1 - label_smoothing)) + / (n_non_ignore), + ) + else: - X_block = tl.exp(X_block - m) / d - eps + X_block = tl.where( + X_offsets != y, + (tl.exp(X_block - m) / d - eps), + (tl.exp(X_block - m) / d - eps - (1 - label_smoothing)), + ) + + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) @@ -132,7 +170,7 @@ def liger_cross_entropy_kernel( # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + 
logd)) + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 @@ -144,15 +182,7 @@ def liger_cross_entropy_kernel( if reduction == "mean": loss = loss / n_non_ignore - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` - X_y = tl.load(X_ptr + y) - if reduction == "mean": - X_y += -(1 - label_smoothing) / (n_non_ignore) - else: - X_y += -(1 - label_smoothing) - tl.store(loss_ptr, loss) - tl.store(X_ptr + y, X_y) # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 @@ -161,7 +191,9 @@ def liger_cross_entropy_kernel( MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning -def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): +def cross_entropy_forward( + _input, target, ignore_index, label_smoothing, reduction, softcap +): BT, V = _input.shape n_rows = BT @@ -191,7 +223,9 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction, + softcap=softcap if softcap is not None else 0.0, BLOCK_SIZE=BLOCK_SIZE, + HAS_SOFTCAPPING=True if softcap is not None else False, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps num_warps=32, @@ -233,7 +267,13 @@ class LigerCrossEntropyFunction(torch.autograd.Function): @staticmethod def forward( - ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean" + ctx, + _input: torch.Tensor, + target: torch.Tensor, + ignore_index: int = -100, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, ): """ The forward pass of the Liger Cross Entropy loss. @@ -245,12 +285,13 @@ def forward( ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduction (str): The reduction to apply to the output: "none" | "mean | "sum". + softcap (Optional[float]): The upper threshold for scaling logits to the range (-softcap, +softcap). Returns: tensor: The computed loss. 
""" loss, _input = cross_entropy_forward( - _input, target, ignore_index, label_smoothing, reduction + _input, target, ignore_index, label_smoothing, reduction, softcap ) # TODO: investigation # If we don't detach the _input tensor, the memory will double @@ -278,4 +319,5 @@ def backward(ctx, grad_output): None, None, None, + None, ) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 371a8919..ae012ec7 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -18,6 +18,7 @@ def fused_linear_cross_entropy_forward( ignore_index=-100, label_smoothing=0.0, reduction="mean", + softcap=None, ): dtype = _input.dtype device = _input.device @@ -87,7 +88,9 @@ def fused_linear_cross_entropy_forward( ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction, + softcap=softcap if softcap is not None else 0.0, BLOCK_SIZE=BLOCK_SIZE, + HAS_SOFTCAPPING=True if softcap is not None else False, num_warps=32, ) @@ -197,6 +200,7 @@ def forward( ignore_index=-100, label_smoothing=0.0, reduction="mean", + softcap=None, ): """ Fusing the last linear layer with cross-entropy loss @@ -216,7 +220,14 @@ def forward( reduction: reduction to apply """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, weight, target, bias, ignore_index, label_smoothing, reduction + _input, + weight, + target, + bias, + ignore_index, + label_smoothing, + reduction, + softcap, ) # downcast to dtype and store for backward ctx.save_for_backward( @@ -233,4 +244,4 @@ def backward(ctx, grad_output): grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( grad_output, grad_input, grad_weight, grad_bias ) - return (grad_input, grad_weight, None, grad_bias, None, None, None) + return (grad_input, grad_weight, None, grad_bias, None, None, None, None) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index b2457481..232ace85 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -1,21 +1,41 @@ -from torch.nn import CrossEntropyLoss +from typing import Optional + +import torch from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction -class LigerCrossEntropyLoss(CrossEntropyLoss): - def __init__(self, *args, **kwargs): - super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs) - assert (self.label_smoothing >= 0) and ( - self.label_smoothing <= 1 - ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" - assert self.reduction in { +class LigerCrossEntropyLoss(torch.nn.Module): + def __init__( + self, + ignore_index: int = -100, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + ): + super().__init__() + assert (label_smoothing >= 0) and ( + label_smoothing <= 1 + ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + assert reduction in { "mean", "sum", "none", - }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {self.reduction}" + }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" + assert ( + softcap is None or softcap > 0 + ), f"softcap must greater than 0.0 or None. 
Got: {softcap}" + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + self.softcap = softcap - def forward(self, _input, target): + def forward(self, _input: torch.Tensor, target: torch.Tensor): return LigerCrossEntropyFunction.apply( - _input, target, self.ignore_index, self.label_smoothing, self.reduction + _input, + target, + self.ignore_index, + self.label_smoothing, + self.reduction, + self.softcap, ) diff --git a/src/liger_kernel/transformers/fused_linear_cross_entropy.py b/src/liger_kernel/transformers/fused_linear_cross_entropy.py index 0e333156..fb9a193a 100644 --- a/src/liger_kernel/transformers/fused_linear_cross_entropy.py +++ b/src/liger_kernel/transformers/fused_linear_cross_entropy.py @@ -1,13 +1,36 @@ -from torch.nn import CrossEntropyLoss +from typing import Optional + +import torch from liger_kernel.ops.fused_linear_cross_entropy import ( LigerFusedLinearCrossEntropyFunction, ) -class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss): - def __init__(self, *args, **kwargs): - super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs) +class LigerFusedLinearCrossEntropyLoss(torch.nn.Module): + def __init__( + self, + ignore_index: int = -100, + label_smoothing: float = 0.0, + reduction: str = "mean", + softcap: Optional[float] = None, + ): + super().__init__() + assert (label_smoothing >= 0) and ( + label_smoothing <= 1 + ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" + assert reduction in { + "mean", + "sum", + "none", + }, f"reduction must be one of 'mean', 'sum', or 'none'. Got: {reduction}" + assert ( + softcap is None or softcap > 0 + ), f"softcap must greater than 0.0 or None. Got: {softcap}" + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + self.softcap = softcap def forward(self, lin_weight, _input, target, bias=None): return LigerFusedLinearCrossEntropyFunction.apply( @@ -18,4 +41,5 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.label_smoothing, self.reduction, + self.softcap, ) diff --git a/src/liger_kernel/transformers/model/gemma2.py b/src/liger_kernel/transformers/model/gemma2.py new file mode 100644 index 00000000..fd0aa15e --- /dev/null +++ b/src/liger_kernel/transformers/model/gemma2.py @@ -0,0 +1,137 @@ +from typing import Optional, Tuple, Union + +import torch +from transformers.cache_utils import HybridCache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.gemma2.modeling_gemma2 import ( + _CONFIG_FOR_DOC, + GEMMA2_INPUTS_DOCSTRING, + logger, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) + + +@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: 
Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + + copy paste transformers.models.gemma2.modeling_gemma2 CausalLM with loss replaced with liger fused linear cross entropy + + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-2-9b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + + if self.training and self.config._attn_implementation != "eager": + logger.warning_once( + "It is strongly recommended to train Gemma2 models with the `eager` attention implementation " + f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`." + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + if self.training and (labels is not None): + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss( + softcap=self.config.final_logit_softcapping + ) + logits = None + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states) + if self.config.final_logit_softcapping is not None: + logits = logits / self.config.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.config.final_logit_softcapping + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) 
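Note on the soft-capping math used above: the fused path relies on the same identity as the eager fallback in `lce_forward` — logits are squashed to `softcap * tanh(logits / softcap)` — and since `d/dx [t * tanh(x / t)] = 1 - tanh(x / t)**2`, the kernel's backward pass scales the stored gradient by `(1 - intermediate * intermediate)`. The following is a minimal PyTorch sketch (illustrative only, not part of the patch) that checks this gradient factor against autograd:

```python
import torch


def softcap_reference(x: torch.Tensor, t: float) -> torch.Tensor:
    # Eager-mode soft-capping, matching what lce_forward's non-training branch
    # applies to the logits: squash them into the open interval (-t, +t).
    return t * torch.tanh(x / t)


# d/dx [t * tanh(x / t)] = 1 - tanh(x / t) ** 2, i.e. the
# (1 - intermediate * intermediate) factor in liger_cross_entropy_kernel.
x = torch.randn(8, 16, dtype=torch.float64, requires_grad=True)
t = 30.0

softcap_reference(x, t).backward(torch.ones(8, 16, dtype=torch.float64))

expected = 1.0 - torch.tanh(x / t) ** 2
assert torch.allclose(x.grad, expected)
```

With `softcap=self.config.final_logit_softcapping` passed into `LigerFusedLinearCrossEntropyLoss` above, the monkey-patch changes below can route Gemma2's final-logit soft-capping through this `lce_forward` instead of materializing and capping the full logits tensor.
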
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index be286903..7e8596cd 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -9,6 +9,7 @@ from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward +from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward @@ -228,7 +229,7 @@ def apply_liger_kernel_to_mistral( Apply Liger kernels to replace original implementation in HuggingFace Mistral models Args: - rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + rope (bool): Whether to apply Liger's rotary position embedding. Default is False. cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True. fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy loss. Default is True. @@ -422,7 +423,8 @@ def apply_liger_kernel_to_gemma( def apply_liger_kernel_to_gemma2( rope: bool = True, - cross_entropy: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, rms_norm: bool = True, geglu: bool = True, model: PreTrainedModel = None, @@ -433,12 +435,19 @@ def apply_liger_kernel_to_gemma2( Args: rope (bool): Whether to apply Liger's rotary position embedding. Default is True. - cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
from transformers.models.gemma2 import modeling_gemma2 LigerRMSNormForGemma2 = partial( @@ -455,6 +464,8 @@ def apply_liger_kernel_to_gemma2( modeling_gemma2.Gemma2RMSNorm = LigerRMSNormForGemma2 if cross_entropy: modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + modeling_gemma2.Gemma2ForCausalLM.forward = gemma2_lce_forward if geglu: modeling_gemma2.Gemma2MLP = LigerGEGLUMLP diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 5aa61eaa..78e3fd73 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -149,7 +149,9 @@ ), ), "mini_gemma2": MiniModelConfig( - liger_kernel_patch_func=apply_liger_kernel_to_gemma2, + liger_kernel_patch_func=functools.partial( + apply_liger_kernel_to_gemma2, fused_linear_cross_entropy=False + ), liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma2, model_class=Gemma2ForCausalLM, mini_model_config=Gemma2Config( diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py index 35b751a2..7cdd1244 100644 --- a/test/convergence/test_mini_models_no_logits.py +++ b/test/convergence/test_mini_models_no_logits.py @@ -411,12 +411,8 @@ def run_mini_model( else: kwargs["swiglu"] = True - model_support_flce = "gemma2" not in model_name - if model_support_flce: - kwargs["fused_linear_cross_entropy"] = True - kwargs["cross_entropy"] = False - else: - kwargs["cross_entropy"] = True + kwargs["fused_linear_cross_entropy"] = True + kwargs["cross_entropy"] = False MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 1a970573..24bd7656 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -1,4 +1,4 @@ -from test.utils import set_seed, supports_bfloat16 +from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 import pytest import torch @@ -116,6 +116,30 @@ def _test_correctness_with_label_smoothing_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_softcap_once( + target_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol +): + + torch_ce = CrossEntropyLoss(reduction=reduction) + + _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + # upcasting to match liger's casting strategy + _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + + # downcasting to original dtype + output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, reduction, scalar, dtype, atol, rtol ): @@ -140,7 +164,19 @@ def _test_correctness_not_last_layer_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): +def _test_correctness_functional( + B, + T, + V, + scalar, + ignore_index, + label_smoothing, + reduction, + softcap, + dtype, + atol, + rtol, +): _input = torch.randn(B * T, V, device="cuda", dtype=dtype) * 
scalar @@ -149,8 +185,12 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) - y1 = liger_cross_entropy(x1, target, 0) - y2 = LigerCrossEntropyFunction.apply(x2, target, 0) + y1 = liger_cross_entropy( + x1, target, ignore_index, label_smoothing, reduction, softcap + ) + y2 = LigerCrossEntropyFunction.apply( + x2, target, ignore_index, label_smoothing, reduction, softcap + ) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) @@ -173,7 +213,7 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): (2, 4096, 32000), # llama2, mistral (2, 4096, 32000), # llama2, mistral (1, 4096, 128256), # llama3 - # # weird shapes + # weird shapes (3, 423, 32000), ], ) @@ -241,8 +281,39 @@ def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol): (10.0, torch.float32, 1e-8, 1e-6), ], ) -def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol): - _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol) +@pytest.mark.parametrize( + "ignore_index, label_smoothing, reduction, softcap", + [ + (-100, 0.0, "mean", None), + (42, 0.1, "sum", 30), + ], +) +def test_correctness_functional( + B, + T, + V, + scalar, + ignore_index, + label_smoothing, + reduction, + softcap, + dtype, + atol, + rtol, +): + _test_correctness_functional( + B, + T, + V, + scalar, + ignore_index, + label_smoothing, + reduction, + softcap, + dtype, + atol, + rtol, + ) @pytest.mark.parametrize( @@ -423,6 +494,65 @@ def test_correctness_with_label_smoothing_with_ignore_index_once( ) +@pytest.mark.parametrize( + "B, T, V, softcap", + [ + (2, 4096, 32000, 30.0), # llama2, mistral + (2, 4096, 32000, 30.0), # llama2, mistral + (1, 4096, 128256, 30.0), # llama3 + # weird shapes + (3, 423, 32000, 30.0), + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 0.1, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 10.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (0.1, torch.float32, 1e-8, 1e-6), + (1.0, torch.float32, 1e-8, 1e-6), + (10.0, torch.float32, 1e-8, 1e-6), + ], +) +@pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, + reason="Needs 16GB+ GPU memory.", +) +def test_correctness_with_softcap_once( + B, T, V, softcap, reduction, scalar, dtype, atol, rtol +): + liger_ce = LigerCrossEntropyLoss(softcap=softcap, reduction=reduction) + _test_correctness_with_softcap_once( + liger_ce, B, T, V, softcap, reduction, scalar, dtype, atol, rtol + ) + + @pytest.mark.parametrize( "B, T, V", [