
feat: correct casts in RMSNorm to match references #92

Merged 7 commits into linkedin:main on Aug 28, 2024

Conversation

@davidgonmar (Contributor) commented on Aug 26, 2024

Summary

Aims to fix #89.

Details

Performs the casts to float32 in the correct places to match the Gemma and Llama references, in both the forward and backward passes.
Also modifies the RMSNorm tests to use tighter tolerances and adds fp16 tests.
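
For context, here is a condensed sketch of where the two references place their float32 casts (paraphrased from the Hugging Face Llama and Gemma RMSNorm modules; exact code may differ between versions):

```python
# Condensed sketch of the reference forward passes (paraphrased from the
# Hugging Face Llama/Gemma RMSNorm modules; exact code may differ by version).
import torch


def llama_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Llama: normalize in float32, cast back to the input dtype *before*
    # multiplying by the weight.
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)


def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma: normalize *and* apply the (1 + weight) scale in float32, casting
    # back to the input dtype only at the very end.
    input_dtype = x.dtype
    x_f32 = x.to(torch.float32)
    variance = x_f32.pow(2).mean(-1, keepdim=True)
    x_f32 = x_f32 * torch.rsqrt(variance + eps)
    x_f32 = x_f32 * (1.0 + weight.to(torch.float32))
    return x_f32.to(input_dtype)
```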

Testing Done

Ran tests for convergence and RMSNorm.

```
test/convergence/test_mini_models.py ........                            [100%]
========================= 8 passed in 78.70s (0:01:18) =========================

test/transformers/test_rms_norm.py ................................................ [100%]
=========================== 48 passed in 4.62s ===========================
```

  • Hardware Type: NVIDIA L4
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@davidgonmar changed the title from "feat: correct casts in rms to match references" to "feat: correct casts in RMSNorm to match references" on Aug 26, 2024
@davidgonmar (Contributor, Author)

I aim to match the references exactly in terms of dtypes. As you can see, doing so (especially in the backward pass) considerably increases the complexity, and some buffers that were previously stored in lower precision now need to be kept in fp32. I am not sure the tradeoff is worth it, or whether it would be better not to follow the references one-to-one in terms of casting (and of storing the cached inv_rms buffer). Let me know which parts you think are worth keeping close to the reference.
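
To make the trade-off concrete, here is a rough eager-mode sketch (not the actual Triton kernel) of a Llama-style forward/backward with the inverse RMS cached in float32; the gradient follows the standard RMSNorm derivation:

```python
import torch


class NaiveRMSNormFn(torch.autograd.Function):
    """Eager-mode illustration only; the real implementation lives in the
    Triton kernels. Shows the float32 inv_rms buffer mentioned above."""

    @staticmethod
    def forward(ctx, x, weight, eps=1e-6):
        x_f32 = x.to(torch.float32)
        inv_rms = torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + eps)
        y = (x_f32 * inv_rms).to(x.dtype) * weight
        # Saving inv_rms in float32 (instead of x.dtype) is what keeps the
        # backward numerically aligned with the reference, at the cost of a
        # larger cached buffer.
        ctx.save_for_backward(x, weight, inv_rms)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, inv_rms = ctx.saved_tensors
        x_f32 = x.to(torch.float32)
        g = (grad_out * weight).to(torch.float32)
        # d/dx of (x * inv_rms) with inv_rms = (mean(x^2) + eps)^(-1/2)
        dx = inv_rms * (g - x_f32 * inv_rms.pow(2) * (g * x_f32).mean(-1, keepdim=True))
        dw = (grad_out * (x_f32 * inv_rms).to(x.dtype)).reshape(-1, x.shape[-1]).sum(dim=0)
        return dx.to(x.dtype), dw, None
```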

@winglian

Would something like Kahan summation help with the lower-precision accumulation?

@davidgonmar (Contributor, Author) commented on Aug 26, 2024

> Would something like Kahan summation help with the lower-precision accumulation?

Hadn't heard of it before. Just did a quick search, and if I am not mistaken it should be implementable in Triton with the generic reduce function. However, given that the relative difference is usually very low (and the accumulation is done in fp32), it might not be worth the hassle. Definitely an interesting idea, though.
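
For reference, a minimal Python illustration of the idea (a Triton version would go through the generic reduce, as noted above):

```python
def kahan_sum(values):
    """Compensated (Kahan) summation: carry a running correction term so the
    rounding error of each addition is fed back into the next one."""
    total = 0.0
    compensation = 0.0  # estimate of the low-order bits lost so far
    for v in values:
        y = v - compensation
        t = total + y                   # low-order bits of y may be lost here...
        compensation = (t - total) - y  # ...and are recovered here
        total = t
    return total


# Example: naive accumulation of many small values drifts more than this.
print(kahan_sum([0.1] * 1_000_000))
```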

@ByronHsu (Collaborator)

You can verify the perf using the scripts in benchmark/.

@davidgonmar (Contributor, Author) commented on Aug 26, 2024

Benchmarks on NVIDIA L4:

Forward Speed Benchmark

This branch

| N | Liger (ms) | Hugging Face (ms) |
|---|---|---|
| 1024.0 | 0.043008 | 0.153600 |
| 2048.0 | 0.077824 | 0.283648 |
| 4096.0 | 0.146432 | 0.886784 |
| 8192.0 | 0.289792 | 2.652160 |
| 16384.0 | 0.577536 | 5.429328 |
| 32768.0 | 1.250304 | 10.841600 |

Master

| N | Liger (ms) | Hugging Face (ms) |
|---|---|---|
| 1024.0 | 0.043008 | 0.153600 |
| 2048.0 | 0.075776 | 0.283648 |
| 4096.0 | 0.148480 | 0.886784 |
| 8192.0 | 0.289792 | 2.651136 |
| 16384.0 | 0.585728 | 5.428224 |
| 32768.0 | 1.161216 | 10.841584 |

Backward Speed Benchmark

This branch

| N | Liger (ms) | Hugging Face (ms) |
|---|---|---|
| 1024.0 | 0.111616 | 0.396288 |
| 2048.0 | 0.196608 | 1.313792 |
| 4096.0 | 0.444416 | 3.471360 |
| 8192.0 | 1.043616 | 8.110064 |
| 16384.0 | 2.098064 | 16.336800 |
| 32768.0 | 5.965824 | 32.528385 |

Master

| N | Liger (ms) | Hugging Face (ms) |
|---|---|---|
| 1024.0 | 0.096256 | 0.396288 |
| 2048.0 | 0.167936 | 1.311776 |
| 4096.0 | 0.321536 | 3.467776 |
| 8192.0 | 0.834560 | 8.117248 |
| 16384.0 | 1.750016 | 16.328705 |
| 32768.0 | 5.246816 | 32.542721 |

Full Speed Benchmark

This branch

| N | Liger (ms) | Hugging Face (ms) |
|---|---|---|
| 1024.0 | 0.128000 | 0.477184 |
| 2048.0 | 0.232448 | 1.593344 |
| 4096.0 | 0.574464 | 4.336640 |
| 8192.0 | 1.320960 | 10.767360 |
| 16384.0 | 2.674688 | 21.747711 |
| 32768.0 | 7.201792 | 43.365376 |

Master

| N | Liger (ms) | Hugging Face (ms) |
|---|---|---|
| 1024.0 | 0.114688 | 0.477184 |
| 2048.0 | 0.202752 | 1.593232 |
| 4096.0 | 0.456704 | 4.343808 |
| 8192.0 | 1.120256 | 10.764288 |
| 16384.0 | 2.332736 | 21.744576 |
| 32768.0 | 6.358048 | 43.321857 |

Full Memory Benchmark

This branch

| N | Liger (MB) | Hugging Face (MB) |
|---|---|---|
| 1024.0 | 36.023535 | 79.619531 |
| 2048.0 | 72.038770 | 159.231250 |
| 4096.0 | 124.069238 | 318.454687 |
| 8192.0 | 204.130176 | 636.901562 |
| 16384.0 | 368.252051 | 1273.795312 |
| 32768.0 | 688.496289 | 2547.582812 |

Master

| N | Liger (MB) | Hugging Face (MB) |
|---|---|---|
| 1024.0 | 32.017676 | 79.619531 |
| 2048.0 | 64.030957 | 159.231250 |
| 4096.0 | 108.057520 | 318.454687 |
| 8192.0 | 172.110645 | 636.901562 |
| 16384.0 | 304.216895 | 1273.795312 |
| 32768.0 | 560.429883 | 2547.582812 |

@davidgonmar (Contributor, Author)

> You can verify the perf using the scripts in benchmark/.

Just benchmarked it. As expected, memory usage and speed are slightly worse with these changes: the extra memory comes from storing the cached inv_rms in float32, and the slowdown from computing/reducing in float32. On Gemma, memory usage will be even higher, since weight grads are kept in fp32 (and then possibly cast back to lower precision). All of this is needed to match the reference.
There are some options:
There are some options:

  • Do not do this, and keep only the old behaviour with native dtypes (no casts, no fp32 storage), at the cost of precision.
  • Keep this, but by default do not cast and compute/store everything in its native dtype (i.e. the default stays what master does now, and users who need more precision can opt into the precise mode).
  • Keep this, but use the precise mode by default (slower) and let users opt out of the casts.

I'd go with the second one if you are willing to have the complexity in the kernels. What do you think?
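
As a rough illustration of what such a mode flag could look like at the module level (names and defaults here are illustrative assumptions rather than the final Liger API, and Gemma's (1 + weight) offset is omitted for brevity):

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    # Illustrative eager-mode wrapper only; the real casting logic lives in
    # the Triton kernels. "none" keeps native dtypes, the other modes mimic
    # the respective references.
    def __init__(self, hidden_size, eps=1e-6, casting_mode="llama"):
        super().__init__()
        assert casting_mode in ("llama", "gemma", "none")
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        self.casting_mode = casting_mode

    def forward(self, x):
        if self.casting_mode == "none":
            # Cheapest, least precise: everything in the input dtype.
            inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            return self.weight * (x * inv_rms)
        x_f32 = x.float()
        inv_rms = torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + self.eps)
        if self.casting_mode == "llama":
            # Cast back before the weight multiply.
            return self.weight * (x_f32 * inv_rms).to(x.dtype)
        # "gemma": the weight multiply also happens in float32.
        return ((x_f32 * inv_rms) * self.weight.float()).to(x.dtype)
```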

@ByronHsu (Collaborator)

I prefer the third; exactness is a deal breaker.

@ByronHsu mentioned this pull request on Aug 26, 2024
@ByronHsu (Collaborator)

Run all tests again; if they pass we can merge.

@davidgonmar (Contributor, Author)

> Run all tests again; if they pass we can merge.

Done. Also made the float16 tolerances a bit less tight (to the same level as bfloat16), since a very small percentage of elements was occasionally failing.

@davidgonmar marked this pull request as ready for review on August 26, 2024, 22:07
@lancerts (Collaborator)

@davidgonmar can you resolve the conflicts? Thanks

Review threads on src/liger_kernel/ops/rms_norm.py and test/transformers/test_rms_norm.py (all resolved).
@yundai424 (Collaborator)

Code logic LGTM! Some minor nit comments; the remaining issues are Triton 2.3.0 compatibility (the way we use the flag doesn't work there) and the test env var.

@yundai424 (Collaborator)

LGTM! Checking why the CI is failing.

@yundai424 (Collaborator)

There seems to be some issue with the GPU CI setup, so we are temporarily making it optional. @davidgonmar could you help rebase onto the latest main branch and we're good to go!

@davidgonmar (Contributor, Author)

> There seems to be some issue with the GPU CI setup, so we are temporarily making it optional. @davidgonmar could you help rebase onto the latest main branch and we're good to go!

Done!

@yundai424 merged commit e7c2505 into linkedin:main on Aug 28, 2024. 2 checks passed.
Linked issue: Fix Dtype Issue of Gemma RMSNorm (#89)