Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure In-place correctness checks work properly #273

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

Tcc0403
Copy link
Contributor

@Tcc0403 Tcc0403 commented Sep 26, 2024

Summary

Fix #272

It's a show case of how to trigger error properly.
I only apply it to cross_entropy for demonstration, can apply to others if we want.

Testing Done

same gist as the issue's

import torch
import torch.nn.functional as F

from liger_kernel.transformers.functional import liger_cross_entropy


def run_inplace_experiment(logits_p, logits_q, cross_entropy_fn):
    _p = logits_p.clone().detach().requires_grad_(True)
    _p.retain_grad()
    softmax = torch.nn.Softmax(dim=-1)
    p = softmax(_p)
    p.retain_grad()
    loss = cross_entropy_fn(p, logits_q)
    loss.backward(retain_graph=True)

    print(f"Cross Entropy Loss: {loss.item()}")
    print(f"Input _p: {_p}")
    print(f"Input logits_q: {logits_q}")
    print(f"Gradients of p (batch item 0): {p.grad[0]}")
    print(f"Gradients of _p (batch item 0): {_p.grad[0]}")


torch.manual_seed(0)
logits_p = torch.randn(8, 8, requires_grad=True, device="cuda")
logits_q = torch.randint(0, 8, (8,), device="cuda", dtype=torch.long)


run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=F.cross_entropy)

print()
print("LIGER:")
run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)

Properly raised the error

❯ python3 inplace_bug.py
Cross Entropy Loss: 2.08567214012146
Input _p: tensor([[-0.9247, -0.4253, -2.6438,  0.1452, -0.1209, -0.5797, -0.6229, -0.3284],
        [-1.0745, -0.3631, -1.6711,  2.2655,  0.3117, -0.1842,  1.2866,  1.1820],
        [-0.1271,  1.2169,  1.4353,  1.0605, -0.4941, -1.4244, -0.7244, -1.2973],
        [ 0.0697, -0.0074,  1.8969,  0.6878, -0.0779, -0.8373,  1.3506, -0.2879],
        [-0.5965, -0.3283, -0.9086, -0.8059, -0.7407, -0.0504,  0.5435,  1.5150],
        [ 0.0141,  0.4532,  1.6349,  0.7124, -0.1806,  1.0252, -1.4622, -0.7554],
        [-0.1836,  0.3824,  0.3918, -0.0830,  0.8971, -1.1123,  0.1116,  0.4863],
        [-0.5499, -0.3231, -0.5469,  0.9049,  0.2837,  0.1210,  0.4730, -1.0823]],
       device='cuda:0', requires_grad=True)
Input logits_q: tensor([4, 6, 7, 2, 2, 6, 5, 5], device='cuda:0')
Gradients of p (batch item 0): tensor([ 0.0149,  0.0157,  0.0140,  0.0174, -0.1086,  0.0154,  0.0153,  0.0159],
       device='cuda:0')
Gradients of _p (batch item 0): tensor([ 0.0017,  0.0029,  0.0003,  0.0055, -0.0182,  0.0024,  0.0023,  0.0032],
       device='cuda:0')

LIGER:
Traceback (most recent call last):
  File "/home/tcc/Liger-Kernel/inplace_bug.py", line 36, in <module>
    run_inplace_experiment(logits_p, logits_q, cross_entropy_fn=liger_cross_entropy)
  File "/home/tcc/Liger-Kernel/inplace_bug.py", line 18, in run_inplace_experiment
    loss.backward(retain_graph=True)
  File "/home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/tcc/Liger-Kernel/.venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [8, 8]], which is output 0 of SoftmaxBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@@ -213,6 +213,8 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reducti
target = target.contiguous()

# Here we use a trick to store X_ptr gradient in X_ptr so we can save memory
# explicitly declare in-place operation is performed
_input.add_(0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain a bit on what this is doing exactly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How torch checks whether an in-place op on a tensor would result in an incorrect gradient calculation, is by implicitly tracking the "version" of a tensor and comparing the saved version and afterward version.

When an in-place operation is performed on a tensor, the version of the tensor is incremented by 1, achieved by an internal function bump() written in C.

Since the bump() function is only called when doing "torch" in-place operation, i.e. in-place operations in "triton kernel" cannot trigger bump(), which makes torch lose track of the version and unable to raise an error.

This approach is a hint for torch, by manually performing a torch's in-place op when we do that tensor dirty in triton kernel.

Reference:
torch in-place-correctness-checks
https://pytorch.org/docs/stable/autograd.html#in-place-correctness-checks
version tracking, bump (there's a note above the function block)
https://github.com/pytorch/pytorch/blob/190e09d8b6a13f789b143f0fbd1325f924550967/c10/core/TensorImpl.h#L382

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this operation have any performance impact? cc @ByronHsu

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

impact is huge for add_(0). will try some other inplace ops.

@Tcc0403
Copy link
Contributor Author

Tcc0403 commented Oct 2, 2024

bench
i did some benchmarks on H100, adding any torch's inplace op increases time cost by roughtly 50%
(original -> with_hint = 23 -> 32 ms for 128k vocab size).

so i guess its not worth it?
full stdout:

**************************************
     BENCHMARKING SPEED for CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "hint",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.7055680155754089,
      1.253440022468567,
      2.433199882507324,
      4.869984149932861,
      10.520671844482422,
      23.0479679107666
    ],
    "y_values_20": [
      0.7003520131111145,
      1.2493120431900024,
      2.4296703338623047,
      4.865350246429443,
      10.509568214416504,
      23.046571731567383
    ],
    "y_values_80": [
      0.7126911878585815,
      1.2599040269851685,
      2.4357247352600098,
      4.873280048370361,
      10.537690162658691,
      23.04932403564453
    ],
    "timestamp": "2024-10-02 23:21:11",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "original",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.41332799196243286,
      0.6783679723739624,
      1.2743680477142334,
      2.535327911376953,
      5.867008209228516,
      13.692416191101074
    ],
    "y_values_20": [
      0.41091200709342957,
      0.6760640144348145,
      1.2711039781570435,
      2.5320703983306885,
      5.845632076263428,
      13.691308975219727
    ],
    "y_values_80": [
      0.41729921102523804,
      0.6832832098007202,
      1.2798080444335938,
      2.539724826812744,
      5.877439975738525,
      13.695513725280762
    ],
    "timestamp": "2024-10-02 23:21:12",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "hint",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      1.1566720008850098,
      2.042720079421997,
      3.7615039348602295,
      7.380864143371582,
      15.551360130310059,
      32.693214416503906
    ],
    "y_values_20": [
      1.1196672916412354,
      2.026726245880127,
      3.7539713382720947,
      7.370649337768555,
      15.547072410583496,
      32.686397552490234
    ],
    "y_values_80": [
      1.1723840236663818,
      2.0592191219329834,
      3.7860095500946045,
      7.387167930603027,
      15.649503707885742,
      32.69830322265625
    ],
    "timestamp": "2024-10-02 23:21:13",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "original",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      0.8923839926719666,
      1.45249605178833,
      2.6448960304260254,
      5.078847885131836,
      10.8754243850708,
      23.530689239501953
    ],
    "y_values_20": [
      0.8872384428977966,
      1.4472639560699463,
      2.63141131401062,
      5.075232028961182,
      10.859647750854492,
      23.527257919311523
    ],
    "y_values_80": [
      0.9067007899284363,
      1.4663935899734497,
      2.6562368869781494,
      5.088294506072998,
      11.039955139160156,
      23.535194396972656
    ],
    "timestamp": "2024-10-02 23:21:14",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  }
]
**************************************
     BENCHMARKING MEMORY for CROSS_ENTROPY
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "hint",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_20": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_80": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "timestamp": "2024-10-02 23:21:15",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  },
  {
    "kernel_name": "cross_entropy",
    "kernel_provider": "original",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "NVIDIA H100 PCIe",
    "x_name": "V",
    "x_label": "vocab size",
    "x_values": [
      4096,
      8192,
      16384,
      32768,
      65536,
      131072
    ],
    "y_values_50": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_20": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "y_values_80": [
      256.32861328125,
      512.32861328125,
      1024.32861328125,
      2048.32861328125,
      4096.32861328125,
      8192.328125
    ],
    "timestamp": "2024-10-02 23:21:15",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"B\": 8, \"T\": 2048}",
    "liger_version": "0.3.1"
  }
]

@lancerts
Copy link
Collaborator

lancerts commented Oct 3, 2024

Yeah, i was looking if we can call bump() from python... 50% cost does not worth ..

@ByronHsu
Copy link
Collaborator

ByronHsu commented Oct 3, 2024

I am wondering why the error does not happen for normal case?

@Tcc0403
Copy link
Contributor Author

Tcc0403 commented Oct 4, 2024

@ByronHsu

I am wondering why the error does not happen for normal case?

I left an explanation in issue

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

In-place operations in triton kernel might result in incorrect gradient calculations
3 participants