Support Z Loss in CE #239
base: main
Conversation
…s and update comments
Passed all tests. Ready for review!
loss_stride,
n_cols,
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
lse_square_scale: tl.constexpr,
I'm not sure that making label_smoothing and lse_square_scale tl.constexpr is the correct move. I'm not familiar with model training; are these two parameters often changed in practice? I'm worried that it might cause the same issue as #146. Flash-attention's implementation creates a new constexpr for this in triton.heuristics to solve the branching issue.
I wonder what the difference is between:
1. declaring label_smoothing as a constexpr, and
2. doing the calculation in triton.heuristics and then assigning the result to the constexpr HAS_SMOOTHING.

My assumption is that in case 1 the kernel is re-JITed every time label_smoothing changes, while in case 2 it is re-JITed only when HAS_SMOOTHING changes, because the calculation on label_smoothing collapses many values into one flag. If so, I will go with flash-attn's approach; a sketch of that approach follows below.
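For reference, a minimal sketch of the heuristics approach, with an illustrative kernel (the name, body, and signature are mine, not the PR's):

```python
import triton
import triton.language as tl

# triton.heuristics derives a boolean constexpr from the runtime argument, so
# the kernel is re-JITed only when HAS_SMOOTHING flips, not on every distinct
# label_smoothing value.
@triton.heuristics({"HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0})
@triton.jit
def _smoothing_demo_kernel(
    X_ptr,
    n_cols,
    label_smoothing,              # plain runtime scalar; changing it does not trigger a recompile
    HAS_SMOOTHING: tl.constexpr,  # filled in by the heuristic above
    BLOCK_SIZE: tl.constexpr,
):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_cols
    x = tl.load(X_ptr + offsets, mask=mask, other=0.0)
    if HAS_SMOOTHING:  # branch is resolved at compile time
        x = x * (1.0 - label_smoothing)
    tl.store(X_ptr + offsets, x, mask=mask)
```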
Ignore the OOM errors; the current custom CrossEntropyWithZLoss (a torch.nn.Module), used as the ground-truth implementation, has precision issues in its gradient calculation with bfloat16 and reduction="sum". The LigerCrossEntropyLoss in this PR passes the tests when compared against flash-attn's CrossEntropyLoss. The current goal is to make the custom torch implementation on par with flash-attn's. Update: problems solved.
All passed
lgtm. @ByronHsu for a second check
Summary
This PR aims to resolve #197 by implementing z loss in LigerCrossEntropy.
note: lse_square_scale is not exposed at flce yet; having issues passing the tests.

Details
For loss:
We can use $m = \max(X_i)$ and $d = \sum_i e^{X_i - m}$, obtained from the online softmax algorithm, to calculate $lse$ directly as $lse = \log\sum_i e^{X_i} = m + \log d$ (a small sketch follows below). The z loss is then $z\_loss = lse\_square\_scale \cdot lse^2$.
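As an illustration, here is a plain-Python sketch (not the kernel code) of how the running max and denominator from the online softmax pass yield $lse$:

```python
import math

# Streaming pass: track the running max m and the rescaled denominator d,
# as the online softmax algorithm does; lse falls out as m + log(d).
def online_lse(xs):
    m, d = float("-inf"), 0.0
    for x in xs:
        m_new = max(m, x)
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m + math.log(d)

xs = [0.5, -1.2, 3.3, 0.0]
assert abs(online_lse(xs) - math.log(sum(math.exp(x) for x in xs))) < 1e-9
```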
For gradients:
First, we calculate the derivative of $lse$:

$$\frac{\partial\, lse}{\partial X_i} = \frac{e^{X_i}}{\sum_j e^{X_j}} = \mathrm{softmax}(X)_i$$

Then we can obtain the derivative of z_loss by the chain rule:

$$\frac{\partial\, z\_loss}{\partial X_i} = 2 \cdot lse\_square\_scale \cdot lse \cdot \mathrm{softmax}(X)_i$$

and we have the derivative of cross entropy loss with label smoothing:

$$\frac{\partial\, ce\_loss}{\partial X_i} = \mathrm{softmax}(X)_i - \frac{\epsilon}{K} - (1 - \epsilon)\,\mathbb{1}\{i = y\}$$

where $\epsilon$ is label_smoothing and $K$ is the number of total classes. Thus, the derivative of the total loss is

$$\frac{\partial\, loss}{\partial X_i} = (1 + 2 \cdot lse\_square\_scale \cdot lse)\,\mathrm{softmax}(X)_i - \frac{\epsilon}{K} - (1 - \epsilon)\,\mathbb{1}\{i = y\}$$
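As a sanity check of the closed-form gradient, here is a hedged PyTorch reference for a single sample (function and variable names are mine, not the PR's; it assumes the PaLM-style definition $z\_loss = lse\_square\_scale \cdot lse^2$):

```python
import torch

# Reference sketch (illustrative, not the PR's kernel): cross entropy with
# label smoothing plus z loss for one sample, matching the derivation above.
def ce_with_z_loss(X, y, eps, lam):
    K = X.shape[0]
    lse = torch.logsumexp(X, dim=0)   # lse = m + log d
    log_p = X - lse                   # log softmax
    ce = -(1 - eps) * log_p[y] - (eps / K) * log_p.sum()
    return ce + lam * lse**2          # z_loss = lam * lse^2

K, eps, lam = 8, 0.1, 1e-4
X = torch.randn(K, requires_grad=True)
y = torch.tensor(3)
ce_with_z_loss(X, y, eps, lam).backward()

# Closed-form gradient from the formula above.
with torch.no_grad():
    lse = torch.logsumexp(X, dim=0)
    grad = (1 + 2 * lam * lse) * torch.softmax(X, dim=0) - eps / K
    grad[y] -= 1 - eps
assert torch.allclose(X.grad, grad, atol=1e-6)
```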
Reference
PaLM: Scaling Language Modeling with Pathways
Chameleon: Mixed-Modal Early-Fusion Foundation Models
Testing Done
benchmark gist
Negligible error in the speed benchmark. This benchmark was done on my machine, so it is probably not fully accurate.
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence