RuntimeError due to dtype mismatch in fused_linear_cross_entropy_forward #305

Closed
kostum123 opened this issue Oct 12, 2024 · 5 comments · Fixed by #318
Labels
bug Something isn't working

Comments

kostum123 commented Oct 12, 2024

🐛 Describe the bug

I encountered a RuntimeError while running a full fine-tuning experiment using LLaMA-Factory on a model with BFloat16 precision. The error occurs during training, when the fused_linear_cross_entropy_forward operation executes. The traceback shows a data-type mismatch between mat1 and mat2: BFloat16 vs. Float. The models used were Qwen2.5 3B and Llama 3.2 3B.

Error Log

0% 0/1376 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/usr/local/bin/llamafactory-cli", line 8, in <module>
    sys.exit(main())
  File "/content/LLaMA-Factory/src/llamafactory/cli.py", line 111, in main
    run_exp()
  File "/content/LLaMA-Factory/src/llamafactory/train/tuner.py", line 50, in run_exp
    run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
  File "/content/LLaMA-Factory/src/llamafactory/train/sft/workflow.py", line 96, in run_sft
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2052, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3485, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3532, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/content/LLaMA-Factory/Liger-Kernel/src/liger_kernel/transformers/model/qwen2.py", line 108, in lce_forward
    loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/content/LLaMA-Factory/Liger-Kernel/src/liger_kernel/transformers/fused_linear_cross_entropy.py", line 13, in forward
    return LigerFusedLinearCrossEntropyFunction.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/content/LLaMA-Factory/Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py", line 221, in forward
    loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
  File "/content/LLaMA-Factory/Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py", line 122, in fused_linear_cross_entropy_forward
    torch.addmm(
RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

### Reproduce

Steps to Reproduce

1. Use Colab with an A100 40GB.
2. Run a full fine-tuning experiment with LLaMA-Factory on a model in BFloat16 precision.
3. Observe the error during training.

Expected Behavior

The training process should run without a RuntimeError due to a dtype mismatch.
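The underlying failure can be illustrated standalone with `torch.addmm` alone; this is a hand-written sketch (not code from LLaMA-Factory or Liger-Kernel) showing how an autocast-produced BFloat16 activation meeting a Float32 weight triggers exactly this error:

```python
import torch

# Hand-written illustration, not the project's code: inside a custom
# torch.autograd.Function, autocast does not re-dispatch ops, so a bf16
# hidden-state chunk (mat1) can meet an fp32 weight (mat2) in torch.addmm.
hidden = torch.randn(8, 16, dtype=torch.bfloat16)   # autocast output -> mat1
weight = torch.randn(32, 16, dtype=torch.float32)   # fp32 master weight -> mat2
bias = torch.zeros(32, dtype=torch.bfloat16)

torch.addmm(bias, hidden, weight.t())
# RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float
```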

Temporary Fix

Comment out the line causing the error in the `fused_linear_cross_entropy_forward` function in `src/liger_kernel/ops/fused_linear_cross_entropy.py`, line 101: `logits_chunk = logits_chunk.to(dtype)`.

Versions

Main

kostum123 added a commit to kostum123/Liger-Kernel that referenced this issue Oct 12, 2024
Fixes linkedin#305

Fix dtype mismatch in fused_linear_cross_entropy_forward function.

* Cast `logits_chunk` to the data type of `_input_chunk` before performing operations on it.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/linkedin/Liger-Kernel/issues/305?shareId=XXXX-XXXX-XXXX-XXXX).
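The cast described in that commit amounts to roughly the following (a sketch with stand-in tensors and the variable names used in this thread, not the verbatim patch):

```python
import torch

# Stand-in tensors for illustration only (hypothetical shapes and dtypes).
_input_chunk = torch.randn(8, 16, dtype=torch.bfloat16)   # autocast activations
logits_chunk = torch.randn(8, 32, dtype=torch.float32)    # fp32 logits buffer

# Match dtypes before any matmul so mat1 and mat2 agree.
logits_chunk = logits_chunk.to(_input_chunk.dtype)
assert logits_chunk.dtype == _input_chunk.dtype  # both torch.bfloat16
```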
gotzmann commented

I had the same problem when trying to train the lm_head layer of LLaMA.

ByronHsu added the bug (Something isn't working) label on Oct 21, 2024
yundai424 (Collaborator) commented Oct 21, 2024

Following up on my previous comment in the attempted fix PR by @kostum123: I think the issue here is that we're missing torch.amp.custom_fwd/custom_bwd on our custom torch autograd function. @kostum123 @gotzmann, could either of you provide a reproducible example so I can test my fix on it? Thanks a lot!
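For reference, this is what attaching those decorators to a custom autograd function looks like in general; the class below is a hypothetical stand-in for illustration, not the actual change that went into #318:

```python
import torch
from torch.amp import custom_fwd, custom_bwd  # torch.cuda.amp on older PyTorch

# Hypothetical example, not the Liger-Kernel patch. With cast_inputs set,
# custom_fwd casts floating-point CUDA tensor arguments to that dtype and runs
# forward with autocast disabled, so the two matmul operands cannot disagree;
# custom_bwd makes backward run under the same autocast state as forward.
class NaiveLinearCE(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda", cast_inputs=torch.bfloat16)
    def forward(ctx, _input, weight, target):
        logits = _input @ weight.t()  # both operands share one dtype here
        ctx.save_for_backward(_input, weight, target)
        return torch.nn.functional.cross_entropy(logits.float(), target)

    @staticmethod
    @custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        # Backward math omitted; the point is only where the decorators go.
        return None, None, None
```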

yundai424 mentioned this issue on Oct 21, 2024
kostum123 (Author) commented

I have been busy lately, but I will test the new fix you provided and let you know whether it solves the issue. I closed the old PR since it was only a temporary solution. @yundai424


fzyzcjy commented Oct 24, 2024

+1 same issue here for llama3.2 1B + Trainer

[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 3532, in compute_loss
[01:28:45.122]:     outputs = model(**inputs)
[01:28:45.122]:               ^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[01:28:45.122]:     return self._call_impl(*args, **kwargs)
[01:28:45.122]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[01:28:45.122]:     return forward_call(*args, **kwargs)
[01:28:45.122]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/operations.py", line 820, in forward
[01:28:45.122]:     return model_forward(*args, **kwargs)
[01:28:45.122]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/operations.py", line 808, in __call__
[01:28:45.122]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[01:28:45.122]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[01:28:45.122]:     return func(*args, **kwargs)
[01:28:45.122]:            ^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/liger_kernel/transformers/model/llama.py", line 109, in lce_forward
[01:28:45.122]:     loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
[01:28:45.122]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[01:28:45.122]:     return self._call_impl(*args, **kwargs)
[01:28:45.122]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.122]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[01:28:45.123]:     return forward_call(*args, **kwargs)
[01:28:45.123]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.123]:   File "/opt/conda/lib/python3.11/site-packages/liger_kernel/transformers/fused_linear_cross_entropy.py", line 13, in forward
[01:28:45.123]:     return LigerFusedLinearCrossEntropyFunction.apply(
[01:28:45.123]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.123]:   File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
[01:28:45.123]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[01:28:45.123]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.123]:   File "/opt/conda/lib/python3.11/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 221, in forward
[01:28:45.123]:     loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
[01:28:45.123]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[01:28:45.123]:   File "/opt/conda/lib/python3.11/site-packages/liger_kernel/ops/fused_linear_cross_entropy.py", line 122, in fused_linear_cross_entropy_forward
[01:28:45.123]:     torch.addmm(
[01:28:45.123]: RuntimeError: mat1 and mat2 must have the same dtype, but got BFloat16 and Float

ByronHsu (Collaborator) commented

#318

We just merged the change. Can you try installing liger-kernel-nightly to test the fix? If it passes, we will release a new version. Thanks @yundai424 for the fix!
