[feat] FusedLinearCrossEntropy support for Gemma2 #127
Comments
#take @yundai424 I would like to try implementing this. I'm thinking of this approach:
Can you assign it to me if this sounds okay?
@troy1729 Sounds reasonable to me. Assigned; feel free to kick off the implementation and ping us to discuss or review any issues. Thank you!
Hi @qingquansong, I've made the changes but still need to add the tests, so I've kept the PR as a draft.
Hey @troy1729, thanks for the question (no silly question) and the fast kick-off! I think
In sum, my suggestion would be: implement only the tanh option for now, and follow the GeGLU backward to see how the tanh gradient is computed with the chain rule, so you can derive the equation and implement it here.
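Here is a worked version of that chain-rule step. It assumes the softcap form in the linked Gemma2 code: divide the logit by the cap, apply `tanh`, then multiply by the cap. In the formulas, `z_i` is a raw logit, `c` is the cap (`final_logit_softcapping` in the Hugging Face Gemma2 config), `y` is the target index, and `p` is the softmax of the capped logits `s`:

```math
\begin{aligned}
s_i &= c \tanh\left(\frac{z_i}{c}\right), \qquad p = \operatorname{softmax}(s), \qquad L = -\log p_y \\
\frac{\partial s_i}{\partial z_i} &= c \cdot \left(1 - \tanh^2\left(\frac{z_i}{c}\right)\right) \cdot \frac{1}{c} = 1 - \tanh^2\left(\frac{z_i}{c}\right) = 1 - \left(\frac{s_i}{c}\right)^2 \\
\frac{\partial L}{\partial z_i} &= \left(p_i - \mathbb{1}[i = y]\right)\left(1 - \tanh^2\left(\frac{z_i}{c}\right)\right)
\end{aligned}
```

Because `1 - tanh^2(z/c)` equals `1 - (s/c)^2`, a backward pass could recover the factor from the capped logits alone. That may be convenient if the raw logits are overwritten in place.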
I believe I've correctly implemented softcapping in the cross-entropy function, as well as FLCE support for Gemma2. But since Gemma2 currently can't pass the test even without FLCE, do I need to find a way to pass the relevant convergence test (test_mini_models_no_logits.py)? cc @yundai424
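Before wiring a softcapped cross-entropy into a kernel, one way to sanity-check it is to compare its analytic gradient against central finite differences. The sketch below is plain Python on a single row. `softcap_cross_entropy` and `finite_difference_grad` are hypothetical names, not functions from this project, and the cap of 30.0 is only an example value.

```python
import math


def softcap_cross_entropy(z, target, cap):
    """Hypothetical helper: CE loss on softcapped logits for one row, plus the
    gradient w.r.t. the raw (uncapped) logits z."""
    t = [math.tanh(zi / cap) for zi in z]
    s = [cap * ti for ti in t]                  # softcapped logits
    m = max(s)                                  # shift for numerical stability
    exps = [math.exp(si - m) for si in s]
    total = sum(exps)
    p = [e / total for e in exps]               # softmax(s)
    loss = -(s[target] - m - math.log(total))   # -log p[target]
    # dL/ds_i = p_i - 1[i == target]; ds_i/dz_i = 1 - tanh^2(z_i / cap)
    grad = [(pi - (1.0 if i == target else 0.0)) * (1.0 - ti * ti)
            for i, (pi, ti) in enumerate(zip(p, t))]
    return loss, grad


def finite_difference_grad(z, target, cap, eps=1e-6):
    """Central differences on the loss, used as a reference gradient."""
    grad = []
    for i in range(len(z)):
        zp = list(z)
        zp[i] += eps
        zm = list(z)
        zm[i] -= eps
        lp, _ = softcap_cross_entropy(zp, target, cap)
        lm, _ = softcap_cross_entropy(zm, target, cap)
        grad.append((lp - lm) / (2 * eps))
    return grad


if __name__ == "__main__":
    z = [12.0, -40.0, 55.0, 3.5]  # arbitrary raw logits, some beyond the cap
    _, analytic = softcap_cross_entropy(z, target=1, cap=30.0)
    numeric = finite_difference_grad(z, target=1, cap=30.0)
    for a, n in zip(analytic, numeric):
        assert abs(a - n) < 1e-5, (a, n)
    print("softcap CE gradient matches finite differences")
```

The same comparison could be run against a reference PyTorch implementation in a unit test, independent of the convergence test.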
🚀 The feature, motivation and pitch
FusedLinearCrossEntropy (FLCE) needs special handling for the final logit soft capping in Gemma2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054
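Below is a rough sketch of where the softcap could sit in a chunked fused linear cross-entropy. Logits are built for one chunk of rows at a time, then overwritten in place with the gradient w.r.t. the raw logits (softcap backward included). Finally they are pushed back through the linear layer. This is plain Python for illustration, not the Triton implementation. `fused_linear_softcap_ce`, `chunk_size`, and the mean reduction are all assumptions.

```python
import math


def fused_linear_softcap_ce(h, w, targets, cap, chunk_size=2):
    """Hypothetical sketch: mean cross-entropy of softcapped logits h @ w^T.

    h: N x D hidden states, w: V x D lm_head weight, targets: N class ids.
    Returns (loss, grad_h, grad_w); logits are built one chunk at a time.
    """
    n, d, v = len(h), len(h[0]), len(w)
    grad_h = [[0.0] * d for _ in range(n)]
    grad_w = [[0.0] * d for _ in range(v)]
    total_loss = 0.0
    for start in range(0, n, chunk_size):
        rows = range(start, min(start + chunk_size, n))
        # 1) Materialize raw logits only for this chunk: (chunk x V).
        chunk = [[sum(a * b for a, b in zip(h[r], w[j])) for j in range(v)]
                 for r in rows]
        # 2) Softcap + CE forward, then overwrite the logits with dL/dz in place.
        for r, z in zip(rows, chunk):
            t = [math.tanh(zj / cap) for zj in z]
            s = [cap * tj for tj in t]
            m = max(s)
            exps = [math.exp(sj - m) for sj in s]
            tot = sum(exps)
            total_loss += -(s[targets[r]] - m - math.log(tot))
            for j in range(v):
                onehot = 1.0 if j == targets[r] else 0.0
                # (softmax - onehot) * softcap derivative, divided by n for the mean
                z[j] = (exps[j] / tot - onehot) * (1.0 - t[j] * t[j]) / n
        # 3) Backprop through the linear layer for this chunk.
        for r, gz in zip(rows, chunk):
            for j in range(v):
                for k in range(d):
                    grad_h[r][k] += gz[j] * w[j][k]
                    grad_w[j][k] += gz[j] * h[r][k]
    return total_loss / n, grad_h, grad_w
```

The logits buffer of each chunk is reused for its gradient. Memory for logits therefore scales with `chunk_size * V` rather than `N * V`, and the results should not depend on `chunk_size`.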