Fix dtype mismatch in fused_linear_cross_entropy_forward #307
Fixes #305
Fix the dtype mismatch in the `fused_linear_cross_entropy_forward` function by casting `logits_chunk` to the data type of `_input_chunk` before performing operations on it. I tested this in Colab after the change and it solved the problem.
{
"epoch": 1.0,
"eval_loss": 1.885668396949768,
"eval_runtime": 0.1708,
"eval_samples_per_second": 5.856,
"eval_steps_per_second": 5.856,
"total_flos": 1766475165597696.0,
"train_loss": 1.9928909236309575,
"train_runtime": 115.5799,
"train_samples_per_second": 0.441,
"train_steps_per_second": 0.441
}
For more details, open the Copilot Workspace session.