diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py index 7bb32baa..d9ea5ae9 100644 --- a/test/transformers/test_fused_linear_cross_entropy.py +++ b/test/transformers/test_fused_linear_cross_entropy.py @@ -100,9 +100,20 @@ def forward(self, x, y): ], ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("label_smoothing", [0, 0.1]) +@pytest.mark.parametrize("label_smoothing, ignore_index", [(0.0, -100), (0.1, 42)]) def test_correctness( - B, T, H, V, scalar, dtype, bias, label_smoothing, reduction, atol, rtol + B, + T, + H, + V, + scalar, + dtype, + bias, + label_smoothing, + ignore_index, + reduction, + atol, + rtol, ): device = "cuda" torch_lm_head_ce = TorchLMHeadCE( @@ -110,6 +121,7 @@ def test_correctness( V=V, bias=bias, label_smoothing=label_smoothing, + ignore_index=ignore_index, reduction=reduction, dtype=dtype, ).to(device) @@ -118,6 +130,7 @@ def test_correctness( V=V, bias=bias, label_smoothing=label_smoothing, + ignore_index=ignore_index, reduction=reduction, dtype=dtype, ).to(device) @@ -137,6 +150,14 @@ def test_correctness( _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index output1 = torch_lm_head_ce(_input1, target) output2 = liger_lm_head_ce(_input2, target)