Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bias for fused linear cross entropy #144

Merged
merged 6 commits into from
Aug 28, 2024

Conversation

davidgonmar
Copy link
Contributor

@davidgonmar davidgonmar commented Aug 28, 2024

Summary

Adds optional bias param for fused linear cross entropy!
Added bias = {true, false} to the testing space.
Also changed weight/bias generation in tests to uniform rand instead of normal (seems stabler for low precision bfloat16).

Testing Done

Results

test/transformers/test_fused_linear_cross_entropy.py ............                                                                                                                                                                                              [100%]

======================================================================================================================== 12 passed in 31.61s =========================================================================================================================
  • Hardware Type: NVIDIA L4
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@davidgonmar davidgonmar marked this pull request as ready for review August 28, 2024 18:31
Copy link
Collaborator

@ByronHsu ByronHsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm! may you test the perf in benchmark/?

@davidgonmar
Copy link
Contributor Author

davidgonmar commented Aug 28, 2024

Benchmarks on L40

Fused Linear Cross-Entropy Bias Memory Benchmark

BT Liger Hugging Face
4096.0 4135.732178 5988.779150
8192.0 4361.197803 9056.422900
16384.0 4797.391553 15190.085400
32768.0 5669.779053 27457.410400

Fused Linear Cross-Entropy Memory Benchmark

BT Liger Hugging Face
4096.0 4142.146875 5988.882227
8192.0 4360.243750 9055.713477
16384.0 4796.437500 15189.375977
32768.0 5668.825000 27456.700977

Fused Linear Cross-Entropy Fwd Speed Benchmark

BT Liger Hugging Face
4096.0 220.761093 22.916607
8192.0 257.770508 46.697983
16384.0 370.574341 94.603264
32768.0 640.113647 190.658554

Fused Linear Cross-Entropy Bias Fwd Speed Benchmark

BT Liger Hugging Face
4096.0 222.580734 20.622849
8192.0 266.255371 48.650703
16384.0 375.865356 92.753922
32768.0 639.733765 196.055038

Fused Linear Cross-Entropy Full Speed Benchmark

BT Liger Hugging Face
4096.0 225.852417 71.114754
8192.0 263.044098 136.209412
16384.0 369.807373 289.680389
32768.0 639.706177 586.910706

Fused Linear Cross-Entropy Bias Full Speed Benchmark

BT Liger Hugging Face
4096.0 227.635193 71.284737
8192.0 271.443970 144.418823
16384.0 385.181702 294.692871
32768.0 658.874390 588.967957

@ByronHsu
Copy link
Collaborator

make all
python -m pytest --disable-warnings test/ --ignore=test/convergence
============================================================================================================================================= test session starts ==============================================================================================================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/resources/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 147 items                                                                                                                                                                                                                                                                                            

test/transformers/test_auto_model.py .                                                                                                                                                                                                                                                                   [  0%]
test/transformers/test_cross_entropy.py ..........................................................                                                                                                                                                                                                       [ 40%]
test/transformers/test_fused_linear_cross_entropy.py ............                                                                                                                                                                                                                                        [ 48%]
test/transformers/test_geglu.py ........                                                                                                                                                                                                                                                                 [ 53%]
test/transformers/test_monkey_patch.py .....                                                                                                                                                                                                                                                             [ 57%]
test/transformers/test_rms_norm.py ................................                                                                                                                                                                                                                                      [ 78%]
test/transformers/test_rope.py ............                                                                                                                                                                                                                                                              [ 87%]
test/transformers/test_swiglu.py ................                                                                                                                                                                                                                                                        [ 97%]
test/transformers/test_trainer_integration.py .                                                                                                                                                                                                                                                          [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                                                                                                                               [100%]

======================================================================================================================================== 147 passed in 72.19s (0:01:12) ========================================================================================================================================
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
============================================================================================================================================= test session starts ==============================================================================================================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/resources/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 28 items                                                                                                                                                                                                                                                                                             

test/convergence/test_mini_models.py ..............                                                                                                                                                                                                                                                      [ 50%]
test/convergence/test_mini_models_no_logits.py ..............                                                                                                                                                                                                                                            [100%]

======================================================================================================================================== 28 passed in 161.75s (0:02:41) ========================================================================================================================================
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 1 files
All done! ✨ 🍰 ✨
58 files left unchanged.

Copy link
Collaborator

@ByronHsu ByronHsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks king! Love your clean code and rapid execution!

@ByronHsu ByronHsu merged commit 01010eb into linkedin:main Aug 28, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants