[AMD] [ROCm] Pick num_warps based on platform #326

Draft
tjtanaa wants to merge 3 commits into main

Conversation


@tjtanaa commented Oct 27, 2024

Summary

This PR enables the kernels to run on AMD GPUs through an initial change to num_warps.
This change was proposed by @Edenzzzz and @DocShotgun in issue #266.

Testing Done

  • Hardware Type: AMD Instinct MI300X
  • run make test to ensure correctness
    • Some tests failed due to numerical precision issues.
    • Something is also odd: if I run only the failing test by itself with pytest test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100], it passes. However, it fails when other tests run before it.
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
<details>
<summary>Failure Test Logs (click to expand/collapse)</summary>

```bash
============================================================= FAILURES =============================================================
________________________ test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] _________________________
B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0, dtype = torch.float32, atol = 1e-08, rtol = 1e-06

    @pytest.mark.parametrize(
        "B, T, V, ignore_index",
        [
            (2, 4096, 32000, -100),  # llama2, mistral
            (2, 4096, 32000, 2),  # llama2, mistral
            (1, 4096, 128256, -300),  # llama3
            # weird shapes
            (3, 423, 32000, -123),
        ],
    )
    @pytest.mark.parametrize("reduction", ["sum", "mean"])
    @pytest.mark.parametrize(
        "scalar, dtype, atol, rtol",
        [
            pytest.param(
                0.1,
                torch.bfloat16,
                1e-8,
                5e-2,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            pytest.param(
                1.0,
                torch.bfloat16,
                1e-8,
                5e-2,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            pytest.param(
                10.0,
                torch.bfloat16,
                1e-8,
                5e-2,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            (0.1, torch.float32, 1e-8, 1e-6),
            (1.0, torch.float32, 1e-8, 1e-6),
            (10.0, torch.float32, 1e-8, 1e-6),
        ],
    )
    @pytest.mark.skipif(
        torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
        reason="Needs 16GB+ GPU memory.",
    )
    def test_correctness_with_ignore_index(
        B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
    ):
        liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
>       _test_correctness_with_ignore_index_once(
            liger_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
        )

test/transformers/test_cross_entropy.py:302: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

target_ce = LigerCrossEntropyLoss(), B = 2, T = 4096, V = 32000, ignore_index = -100, reduction = 'sum', scalar = 10.0
dtype = torch.float32, atol = 1e-08, rtol = 1e-06

    def _test_correctness_with_ignore_index_once(
        target_ce, B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
    ):
    
        torch_ce = CrossEntropyLoss(ignore_index=ignore_index, reduction=reduction)
    
        _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
        _input = _tensor.detach().clone().requires_grad_(True)
        _input2 = _tensor.detach().clone().requires_grad_(True)
    
        target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
    
        # Assign some random number of elements as ignore_index
        num_elements_to_assign = torch.randint(
            1, B * T // 2, (1,)
        ).item()  # Random number of elements to set to ignore_index
        indices_to_assign = torch.randperm(B * T)[
            :num_elements_to_assign
        ]  # Randomly select indices
        target[indices_to_assign] = ignore_index
    
        output = torch_ce(_input, target)
        output2 = target_ce(_input2, target)
    
        assert torch.allclose(output, output2, atol=atol, rtol=rtol)
    
        output.backward()
        output2.backward()
>       assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06)
E        +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
E        +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[  6.0503,   3.7258,  -0.3530,  ...,  11.8853,  20.5071,  -9.9739],\n        [ 15.2597,  -0.5924,   6.6471,  ...,  -9.3584,   3.0466,  -2.5966],\n        [-17.9122,  31.2363,  -1.4114,  ...,  -5.5268,  17.4033,  -3.3372],\n        ...,\n        [  4.3242,  -7.8904,  10.2973,  ..., -17.3829,  -1.2789,   6.6447],\n        [-10.9055,  10.4553,  -5.2270,  ..., -12.5100,   5.0782,  11.1050],\n        [ -5.8922,  15.0620,   5.5783,  ...,  -5.3107,   6.2329, -13.0452]],\n       device='cuda:0', requires_grad=True).grad
E        +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0', requires_grad=True).grad

test/transformers/test_cross_entropy.py:61: AssertionError
_________________________________ test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] _________________________________

B = 1, T = 4096, V = 128256, beta = 0.1, dtype = torch.float32, atol = 1e-08, rtol = 1e-06

    @pytest.mark.parametrize(*_SHAPE_PARAMS)
    @pytest.mark.parametrize(*_DTYPE_PARAMS)
    @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
    def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
        liger_jsd = LigerJSD(beta=beta)
>       _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)

test/transformers/test_jsd.py:269: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once
    assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(
            torch.isposinf(tensor1), torch.isposinf(tensor2)
        )
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(
            torch.isneginf(tensor1), torch.isneginf(tensor2)
        )
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(
                    f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
                )
            if num_mismatched > max_print:
                mismatch_details.append(
                    f"... and {num_mismatched - max_print} more mismatched elements."
                )
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 1
E           Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767

test/utils.py:106: AssertionError
_________________________________ test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] _________________________________

B = 1, T = 4096, V = 128256, beta = 0.9, dtype = torch.float32, atol = 1e-08, rtol = 1e-06

    @pytest.mark.parametrize(*_SHAPE_PARAMS)
    @pytest.mark.parametrize(*_DTYPE_PARAMS)
    @pytest.mark.parametrize("beta", [0.1, 0.5, 0.9])
    def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol):
        liger_jsd = LigerJSD(beta=beta)
>       _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol)

test/transformers/test_jsd.py:269: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
test/transformers/test_jsd.py:157: in _test_correctness_with_beta_once
    assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

tensor1 = tensor(0.0805, device='cuda:0', grad_fn=<SumBackward0>)
tensor2 = tensor(0.0805, device='cuda:0', grad_fn=<LigerJSDFunctionBackward>), rtol = 1e-06, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
    
        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.
    
        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")
    
        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)
    
        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)
    
        # Find tolerance mismatched elements
        tol_mismatched = diff > tolerance
    
        # Find nan mismatched elements
        nan_mismatched = torch.logical_xor(torch.isnan(tensor1), torch.isnan(tensor2))
    
        # Find +inf mismatched elements
        posinf_mismatched = torch.logical_xor(
            torch.isposinf(tensor1), torch.isposinf(tensor2)
        )
        # Find -inf mismatched elements
        neginf_mismatched = torch.logical_xor(
            torch.isneginf(tensor1), torch.isneginf(tensor2)
        )
    
        # Find all mismatched elements
        mismatched = torch.logical_or(
            torch.logical_or(tol_mismatched, nan_mismatched),
            torch.logical_or(posinf_mismatched, neginf_mismatched),
        )
    
        mismatched_indices = torch.nonzero(mismatched)
    
        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()
    
        # Check if all elements are close
        all_close = num_mismatched == 0
    
        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched >= 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(
                    f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
                )
            if num_mismatched > max_print:
                mismatch_details.append(
                    f"... and {num_mismatched - max_print} more mismatched elements."
                )
    
>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 1
E           Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344

test/utils.py:106: AssertionError
___________________________________ test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] ___________________________________

B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06

    @pytest.mark.parametrize(*_SHAPE_PARAMS)
    @pytest.mark.parametrize("log_target", [True, False])
    @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
    @pytest.mark.parametrize(*_DTYPE_PARAMS)
    def test_correctness(B, T, V, log_target, reduction, dtype, atol, rtol):
        liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
>       _test_correctness_once(
            liger_kldiv, B, T, V, dtype, atol, rtol, reduction, log_target
        )

test/transformers/test_kl_div.py:97: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none'
log_target = False, is_last_layer = True, device = 'cuda'

    def _test_correctness_once(
        target_kldiv,
        B,
        T,
        V,
        dtype,
        atol,
        rtol,
        reduction,
        log_target,
        is_last_layer=True,
        device="cuda",
    ):
        torch.manual_seed(0)
        torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)
    
        input = torch.randn(
            B * T, V, device=device, dtype=dtype, requires_grad=True
        ).log_softmax(dim=-1)
    
        x1 = input.detach().clone().requires_grad_(True)
        x2 = input.detach().clone().requires_grad_(True)
    
        with torch.no_grad():
            target = torch.randn(B * T, V, device=device).softmax(dim=-1)
    
        output = torch_kldiv(x1, target)
        output2 = target_kldiv(x2, target)
>       assert torch.allclose(output, output2, atol=atol, rtol=rtol)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
E        +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose

test/transformers/test_kl_div.py:75: AssertionError
______________________________ test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] _______________________________

B = 32, T = 4096, V = 1024, log_target = False, reduction = 'none', dtype = torch.float32, atol = 1e-08, rtol = 1e-06

    @pytest.mark.parametrize(*_SHAPE_PARAMS)
    @pytest.mark.parametrize("log_target", [True, False])
    @pytest.mark.parametrize("reduction", ["batchmean", "sum", "mean", "none"])
    @pytest.mark.parametrize(*_DTYPE_PARAMS)
    def test_correctness_not_last(B, T, V, log_target, reduction, dtype, atol, rtol):
        liger_kldiv = LigerKLDIVLoss(reduction=reduction, log_target=log_target)
>       _test_correctness_once(
            liger_kldiv,
            B,
            T,
            V,
            dtype,
            atol,
            rtol,
            reduction,
            log_target,
            is_last_layer=False,
        )

test/transformers/test_kl_div.py:108: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

target_kldiv = LigerKLDIVLoss(), B = 32, T = 4096, V = 1024, dtype = torch.float32, atol = 1e-08, rtol = 1e-06, reduction = 'none'
log_target = False, is_last_layer = False, device = 'cuda'

    def _test_correctness_once(
        target_kldiv,
        B,
        T,
        V,
        dtype,
        atol,
        rtol,
        reduction,
        log_target,
        is_last_layer=True,
        device="cuda",
    ):
        torch.manual_seed(0)
        torch_kldiv = KLDivLoss(reduction=reduction, log_target=log_target)
    
        input = torch.randn(
            B * T, V, device=device, dtype=dtype, requires_grad=True
        ).log_softmax(dim=-1)
    
        x1 = input.detach().clone().requires_grad_(True)
        x2 = input.detach().clone().requires_grad_(True)
    
        with torch.no_grad():
            target = torch.randn(B * T, V, device=device).softmax(dim=-1)
    
        output = torch_kldiv(x1, target)
        output2 = target_kldiv(x2, target)
>       assert torch.allclose(output, output2, atol=atol, rtol=rtol)
E       AssertionError: assert False
E        +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
E        +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose

test/transformers/test_kl_div.py:75: AssertionError
_________________________________________________ test_import_custom_cache_manager _________________________________________________

    def test_import_custom_cache_manager():
        from triton.runtime.cache import get_cache_manager
    
        from liger_kernel.triton import apply_liger_triton_cache_manager
    
        apply_liger_triton_cache_manager()
>       cache_manager = get_cache_manager(key="test_hash")

test/triton/test_triton_monkey_patch.py:17: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:277: in get_cache_manager
    return __cache_cls(_base64(key))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

key = 'test_hash'

    def _base64(key):
        # Assume key is a hex string.
>       return base64.urlsafe_b64encode(bytes.fromhex(key)).decode("utf-8").rstrip("=")
E       ValueError: non-hexadecimal number found in fromhex() arg at position 0

/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/cache.py:261: ValueError
===================================================== short test summary info ======================================================
FAILED test/transformers/test_cross_entropy.py::test_correctness_with_ignore_index[10.0-dtype5-1e-08-1e-06-sum-2-4096-32000--100] - AssertionError: assert False
 +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0'), atol=1e-08, rtol=1e-06)
 +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
 +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3721e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[  6.0503,   3.7258,  -0.3530,  ...,  11.8853,  20.5071,  -9.9739],\n        [ 15.2597,  -0.5924,   6.6471,  ...,  -9.3584,   3.0466,  -2.5966],\n        [-17.9122,  31.2363,  -1.4114,  ...,  -5.5268,  17.4033,  -3.3372],\n        ...,\n        [  4.3242,  -7.8904,  10.2973,  ..., -17.3829,  -1.2789,   6.6447],\n        [-10.9055,  10.4553,  -5.2270,  ..., -12.5100,   5.0782,  11.1050],\n        [ -5.8922,  15.0620,   5.5783,  ...,  -5.3107,   6.2329, -13.0452]],\n       device='cuda:0', requires_grad=True).grad
 +    and   tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0') = tensor([[4.0225e-16, 3.9353e-17, 6.6616e-19,  ..., 1.3759e-13, 7.6381e-10,\n         4.4185e-23],\n        [2.9569e-12, 3.8580e-19, 5.3756e-16,  ..., 6.0166e-23, 1.4681e-17,\n         5.1994e-20],\n        [4.7900e-26, 1.0599e-04, 7.0237e-19,  ..., 1.1461e-20, 1.0415e-10,\n         1.0237e-19],\n        ...,\n        [6.9540e-17, 3.4471e-22, 2.7309e-14,  ..., 2.5999e-26, 2.5635e-19,\n         7.0793e-16],\n        [6.3722e-23, 1.2054e-13, 1.8638e-20,  ..., 1.2807e-23, 5.5705e-16,\n         2.3085e-13],\n        [1.9623e-20, 2.4720e-11, 1.8808e-15,  ..., 3.5100e-20, 3.6195e-15,\n         1.5356e-23]], device='cuda:0', requires_grad=True).grad
FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.1-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1
Mismatch at index (): tensor1[()] = 0.08054989576339722, tensor2[()] = 0.08054977655410767
FAILED test/transformers/test_jsd.py::test_correctness_with_beta[0.9-dtype1-1e-08-1e-06-1-4096-128256] - AssertionError: Number of mismatched elements: 1
Mismatch at index (): tensor1[()] = 0.08054172992706299, tensor2[()] = 0.08054161071777344
FAILED test/transformers/test_kl_div.py::test_correctness[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False
 +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
 +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
FAILED test/transformers/test_kl_div.py::test_correctness_not_last[dtype1-1e-08-1e-06-none-False-32-4096-1024] - AssertionError: assert False
 +  where False = <built-in method allclose of type object at 0x7035c99e82c0>(tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0', grad_fn=<SubBackward0>), tensor([[ 3.8871e-04,  1.5342e-03,  9.7731e-04,  ...,  1.5857e-04,\n          2.0651e-05, -2.0225e-04],\n        [ 3.0436e-04,  1.4040e-03, -1.4338e-04,  ..., -9.6487e-04,\n          3.6957e-04, -1.7970e-04],\n        [ 1.3870e-02,  1.8989e-03, -2.3409e-04,  ..., -9.2741e-05,\n         -2.1325e-03, -3.6861e-04],\n        ...,\n        [ 1.6965e-04,  7.5081e-04,  1.7243e-03,  ..., -3.3345e-04,\n          2.9291e-04,  4.6570e-03],\n        [-8.5313e-04,  5.1247e-04,  2.9434e-03,  ..., -1.6669e-04,\n          6.3304e-04,  8.2082e-04],\n        [-1.0297e-03, -5.9040e-05, -4.5201e-04,  ...,  1.1601e-03,\n          1.0437e-03,  2.4179e-04]], device='cuda:0',\n       grad_fn=<LigerKLDivLossFunctionBackward>), atol=1e-08, rtol=1e-06)
 +    where <built-in method allclose of type object at 0x7035c99e82c0> = torch.allclose
FAILED test/triton/test_triton_monkey_patch.py::test_import_custom_cache_manager - ValueError: non-hexadecimal number found in fromhex() arg at position 0
================================ 6 failed, 1012 passed, 8 skipped, 72 warnings in 630.02s (0:10:30) ================================
make: *** [Makefile:8: test] Error 1
```

</details>

@tjtanaa mentioned this pull request Oct 27, 2024
@@ -194,7 +194,7 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
         BLOCK_SIZE=BLOCK_SIZE,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
-        num_warps=32,
+        num_warps=32 if not is_hip() else 16,
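
For context, a minimal sketch of the platform check this change relies on, assuming is_hip() keys off torch.version.hip (which is None on CUDA builds):

```python
import torch


def is_hip() -> bool:
    # Assumption: HIP/ROCm builds of PyTorch expose a non-None torch.version.hip.
    return torch.version.hip is not None


# Warp count selection as in the diff above: 32 on NVIDIA, 16 on AMD.
num_warps = 32 if not is_hip() else 16
```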

@Edenzzzz commented Oct 27, 2024


Not a hardware expert, but can we perhaps benchmark num_warps=8 a bit, as in vLLM, or autotune over [32, 16] for NVIDIA and [16, 8] for AMD (in case of register spilling, etc.)?
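
A minimal sketch of what that suggestion could look like, assuming a hypothetical is_hip() helper and the warp counts proposed above:

```python
import torch
import triton


def is_hip() -> bool:
    # Hypothetical platform check: torch.version.hip is None on CUDA builds.
    return torch.version.hip is not None


def num_warps_configs():
    # Suggested sweep: [32, 16] on NVIDIA, [16, 8] on AMD (smaller counts to
    # reduce the risk of register spilling).
    warp_choices = [16, 8] if is_hip() else [32, 16]
    return [triton.Config({}, num_warps=w) for w in warp_choices]


# Usage sketch: let the autotuner pick num_warps per input shape, e.g.
# @triton.autotune(configs=num_warps_configs(), key=["n_cols"])
# @triton.jit
# def liger_cross_entropy_kernel(...): ...
```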

@tjtanaa (Author)


I have added an extensive parameter-search-space sweep for layer norm, yet it still cannot outperform Hugging Face at smaller hidden sizes.

Added Code


import triton


def get_amd_triton_config_list():
    # Sweep the ROCm-specific compile options together with num_stages/num_warps.
    waves_per_eu = [0, 1, 2]
    matrix_instr_nonkdim = [16, 32]
    num_stages = [0, 1, 2]
    num_warps = [4, 8, 16]

    config_list = []

    for wpe in waves_per_eu:
        for kdim in matrix_instr_nonkdim:
            for ns in num_stages:
                for nw in num_warps:
                    config_list.append(
                        triton.Config(
                            {
                                "waves_per_eu": wpe,
                                "matrix_instr_nonkdim": kdim,
                            },
                            num_stages=ns,
                            num_warps=nw,
                        )
                    )
    return config_list


# Applied to the layer norm kernels (_layer_norm_forward_kernel / _layer_norm_backward_kernel):
@triton.autotune(
    configs=get_amd_triton_config_list(),
    key=["BLOCK_SIZE"],
)
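
For reference, a self-contained toy example of how such an autotune decorator attaches to a Triton kernel and how the launch site changes (hypothetical kernel; this sketch sweeps num_warps only so it runs on any backend, while the ROCm-specific waves_per_eu / matrix_instr_nonkdim options above only take effect on AMD):

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({}, num_warps=w) for w in (4, 8, 16)],
    key=["n_cols"],  # re-tune whenever the row width changes
)
@triton.jit
def _row_scale_kernel(X, Y, scale, n_cols, BLOCK_SIZE: tl.constexpr):
    # Toy per-row kernel: Y[row] = X[row] * scale, one program per row.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    x = tl.load(X + row * n_cols + offs, mask=mask, other=0.0)
    tl.store(Y + row * n_cols + offs, x * scale, mask=mask)


x = torch.randn(8, 4096, device="cuda")
y = torch.empty_like(x)
# Note: num_warps is no longer passed at the call site; the autotuner injects
# the best config it found for this key.
_row_scale_kernel[(x.shape[0],)](x, y, 2.0, x.shape[1], BLOCK_SIZE=4096)
```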

Benchmark results

OPTIMIZE_EPILOGUE=1 TRITON_PRINT_AUTOTUNING=1 python scripts/benchmark_layer_norm.py
**************************************
     BENCHMARKING SPEED for LAYER_NORM
**************************************
Triton autotuning for function _layer_norm_forward_kernel finished after 12.43s; best config selected: waves_per_eu: 1, matrix_instr_nonkdim: 16, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None;
Triton autotuning for function _layer_norm_forward_kernel finished after 12.24s; best config selected: waves_per_eu: 0, matrix_instr_nonkdim: 32, num_warps: 4, num_ctas: 1, num_stages: 2, maxnreg: None;
Triton autotuning for function _layer_norm_forward_kernel finished after 12.22s; best config selected: waves_per_eu: 2, matrix_instr_nonkdim: 32, num_warps: 8, num_ctas: 1, num_stages: 0, maxnreg: None;
Triton autotuning for function _layer_norm_forward_kernel finished after 12.17s; best config selected: waves_per_eu: 0, matrix_instr_nonkdim: 16, num_warps: 8, num_ctas: 1, num_stages: 0, maxnreg: None;
Triton autotuning for function _layer_norm_forward_kernel finished after 12.54s; best config selected: waves_per_eu: 1, matrix_instr_nonkdim: 16, num_warps: 8, num_ctas: 1, num_stages: 1, maxnreg: None;
Triton autotuning for function _layer_norm_backward_kernel finished after 12.19s; best config selected: waves_per_eu: 2, matrix_instr_nonkdim: 32, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None;
Triton autotuning for function _layer_norm_backward_kernel finished after 12.11s; best config selected: waves_per_eu: 2, matrix_instr_nonkdim: 32, num_warps: 8, num_ctas: 1, num_stages: 0, maxnreg: None;
Triton autotuning for function _layer_norm_backward_kernel finished after 12.11s; best config selected: waves_per_eu: 0, matrix_instr_nonkdim: 32, num_warps: 16, num_ctas: 1, num_stages: 0, maxnreg: None;
Triton autotuning for function _layer_norm_backward_kernel finished after 12.51s; best config selected: waves_per_eu: 1, matrix_instr_nonkdim: 32, num_warps: 8, num_ctas: 1, num_stages: 0, maxnreg: None;
Triton autotuning for function _layer_norm_backward_kernel finished after 12.70s; best config selected: waves_per_eu: 1, matrix_instr_nonkdim: 16, num_warps: 8, num_ctas: 1, num_stages: 0, maxnreg: None;
********** Benchmark Data **********
[
  {
    "kernel_name": "layer_norm",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "AMD Instinct MI300X",
    "x_name": "N",
    "x_label": "hidden size",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      0.05548600107431412,
      0.0498029999434948,
      0.060717999935150146,
      0.09918499737977982,
      0.17343400418758392
    ],
    "y_values_20": [
      0.05348199978470802,
      0.047047000378370285,
      0.05381200090050697,
      0.09738200157880783,
      0.16993799805641174
    ],
    "y_values_80": [
      0.058132000267505646,
      0.05293000116944313,
      0.08238700032234192,
      0.10034800320863724,
      0.17636018991470337
    ],
    "timestamp": "2024-10-28 04:43:07",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"M\": 4096, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "layer_norm",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "AMD Instinct MI300X",
    "x_name": "N",
    "x_label": "hidden size",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      0.02472599968314171,
      0.03308499976992607,
      0.05716999992728233,
      0.11405900120735168,
      0.22450999915599823
    ],
    "y_values_20": [
      0.023934999480843544,
      0.0322830006480217,
      0.05523499846458435,
      0.11289700120687485,
      0.22289039194583893
    ],
    "y_values_80": [
      0.026359200477600098,
      0.06141600012779236,
      0.05879399925470352,
      0.11538200080394745,
      0.22627399861812592
    ],
    "timestamp": "2024-10-28 04:43:10",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"M\": 4096, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "layer_norm",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "AMD Instinct MI300X",
    "x_name": "N",
    "x_label": "hidden size",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      0.3861970007419586,
      0.9399949908256531,
      0.9476320147514343,
      1.0064010620117188,
      1.017171025276184
    ],
    "y_values_20": [
      0.3667530119419098,
      0.8674409985542297,
      0.6628599762916565,
      0.855912983417511,
      0.8749900460243225
    ],
    "y_values_80": [
      0.422760009765625,
      0.9504649639129639,
      0.9593693614006042,
      1.0176535844802856,
      1.035987138748169
    ],
    "timestamp": "2024-10-28 04:44:14",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 4096, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "layer_norm",
    "kernel_provider": "huggingface",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "AMD Instinct MI300X",
    "x_name": "N",
    "x_label": "hidden size",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      0.3276045024394989,
      0.3255690038204193,
      0.34742000699043274,
      0.4774940013885498,
      0.9882450103759766
    ],
    "y_values_20": [
      0.32131001353263855,
      0.32023200392723083,
      0.34053999185562134,
      0.4757609963417053,
      0.9819509983062744
    ],
    "y_values_80": [
      0.33827799558639526,
      0.3349609971046448,
      0.3595069944858551,
      0.4794589877128601,
      0.9944999814033508
    ],
    "timestamp": "2024-10-28 04:44:17",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 4096, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.3.1"
  }
]
**************************************
     BENCHMARKING MEMORY for LAYER_NORM
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "layer_norm",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "AMD Instinct MI300X",
    "x_name": "N",
    "x_label": "hidden size",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      82.4375,
      164.84375,
      329.65625,
      659.28125,
      1320.53125
    ],
    "y_values_20": [
      82.4375,
      164.84375,
      329.65625,
      659.28125,
      1320.53125
    ],
    "y_values_80": [
      82.4375,
      164.84375,
      329.65625,
      659.28125,
      1320.53125
    ],
    "timestamp": "2024-10-28 04:44:17",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 4096, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "layer_norm",
    "kernel_provider": "huggingface",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "AMD Instinct MI300X",
    "x_name": "N",
    "x_label": "hidden size",
    "x_values": [
      1024,
      2048,
      4096,
      8192,
      16384
    ],
    "y_values_50": [
      80.5625,
      161.09375,
      322.15625,
      644.28125,
      1288.53125
    ],
    "y_values_20": [
      80.5625,
      161.09375,
      322.15625,
      644.28125,
      1288.53125
    ],
    "y_values_80": [
      80.5625,
      161.09375,
      322.15625,
      644.28125,
      1288.53125
    ],
    "timestamp": "2024-10-28 04:44:17",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"M\": 4096, \"dtype\": \"torch.float32\", \"eps\": 1e-06}",
    "liger_version": "0.3.1"
  }
]

@tjtanaa (Author)


There is no noticeable performance gain from autotuning, so I would suggest keeping things simple: setting num_warps to 16 is sufficient for now.
