[BUG] pipeline parallelism+fp16+moe isn't working #6714
Comments
Can you provide the whole script to reproduce it?
Here is a simple example adapted from DeepSpeedExamples/training/cifar. After running this code with …
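The rest of this comment (the full script and the launch command) is cut off above. As a rough sketch of the kind of setup being described, the following puts a deepspeed.moe.layer.MoE layer into only one stage of a 4-stage PipelineModule and trains with fp16. MoEBlock, the layer sizes, and the config values are illustrative assumptions, not the reporter's actual script, and it assumes a 4-GPU launch with the deepspeed launcher:

```python
# Rough sketch: 4-stage pipeline, MoE layer in only one stage, fp16 enabled.
# Assumed launch: deepspeed --num_gpus=4 this_script.py
import torch
import torch.nn as nn
import deepspeed
from deepspeed.pipe import PipelineModule
from deepspeed.moe.layer import MoE


class MoEBlock(nn.Module):
    """Wraps deepspeed.moe.layer.MoE so the pipeline stage passes on a single tensor."""

    def __init__(self, hidden_size, num_experts=4):
        super().__init__()
        expert = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.moe = MoE(hidden_size=hidden_size, expert=expert,
                       num_experts=num_experts, k=1)

    def forward(self, x):
        out, _aux_loss, _expert_counts = self.moe(x)
        return out


deepspeed.init_distributed()

hidden = 256
# Only one of the four stages ends up holding the MoE layer; the rest are dense.
layers = [
    nn.Flatten(), nn.Linear(3 * 32 * 32, hidden), nn.ReLU(),
    MoEBlock(hidden), nn.ReLU(),
    nn.Linear(hidden, hidden), nn.ReLU(),
    nn.Linear(hidden, 10),
]
net = PipelineModule(layers=layers, num_stages=4, loss_fn=nn.CrossEntropyLoss())

ds_config = {
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 4,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {"enabled": True},
}
engine, _, _, _ = deepspeed.initialize(
    model=net,
    model_parameters=[p for p in net.parameters() if p.requires_grad],
    config=ds_config,
)

# Random CIFAR-shaped batches just to drive train_batch(); the real example uses torchvision.
def batches():
    while True:
        x = torch.randn(4, 3, 32, 32, dtype=torch.half, device=engine.device)
        y = torch.randint(0, 10, (4,), device=engine.device)
        yield x, y

engine.train_batch(data_iter=batches())  # reported to hang inside FP16_Optimizer.step()
```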
Describe the bug
My model uses DeepSpeed's PipelineModule(num_stages=4) to split into 4 pipeline stages, and deepspeed.moe.layer.MoE is only placed in the pipeline stage 1 layer. When my model runs train_batch, the program gets stuck; the specific issue occurs in FP16_Optimizer.step(). Here is our deepspeed config:
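The config itself did not make it into the report text above. A minimal sketch of the kind of config this combination needs, written as a Python dict that could be passed to deepspeed.initialize (or saved as JSON), is below; every value is an assumption, not the reporter's actual settings:

```python
# Hypothetical DeepSpeed config for pipeline parallelism + fp16 + MoE.
ds_config = {
    "train_batch_size": 16,                    # micro_batch * grad_accum * data-parallel size
    "train_micro_batch_size_per_gpu": 4,
    "steps_per_print": 10,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {
        "enabled": True,      # with fp16 on (and no ZeRO), DeepSpeed typically wraps the optimizer in FP16_Optimizer
        "loss_scale": 0,      # 0 selects dynamic loss scaling
        "initial_scale_power": 16,
    },
    # The 4-way pipeline split comes from PipelineModule(num_stages=4) in the model code,
    # and the MoE layer from deepspeed.moe.layer.MoE; neither is declared in this config.
}
```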
Source code with issues
My pipeline_parallel_world_size is 4, so the code enters the branch below, but my MoE layer is only placed in pipeline stage 1, and the all_reduce makes the program hang: a collective only returns once every rank in its group has called it, so if the stages without MoE parameters never reach this call, the rank that does blocks forever. If I delete this code, training runs successfully.
DeepSpeed/deepspeed/runtime/utils.py
Lines 892 to 893 in 10ba3dd
I don't know why the all_reduce needs to be done here; it doesn't seem meaningful.
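For reference, the hang pattern itself is generic and not DeepSpeed-specific: dist.all_reduce only completes once every rank in the participating process group has entered it. The sketch below (plain torch.distributed with the gloo backend, with rank 1 standing in for "the only pipeline stage that owns MoE parameters") reproduces the same kind of stall:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Stand-in for "only pipeline stage 1 has MoE parameters".
    owns_moe_params = (rank == 1)
    local_norm_sq = torch.tensor([float(rank) ** 2])

    if owns_moe_params:
        # Only rank 1 ever reaches this collective. all_reduce is issued over the
        # default group (all 4 ranks), so rank 1 blocks here forever waiting for
        # peers that never call it -- the whole job looks "stuck".
        dist.all_reduce(local_norm_sq, op=dist.ReduceOp.SUM)

    print(f"rank {rank} finished")  # never printed by rank 1
    dist.destroy_process_group()


if __name__ == "__main__":
    # Warning: this intentionally hangs (rank 1 never returns); stop it with Ctrl-C.
    mp.spawn(worker, args=(4,), nprocs=4, join=True)
```

Either every rank in the group has to take the branch so the collective is matched, or the reduction has to run over a group restricted to the ranks that actually hold MoE parameters; which of those the DeepSpeed code intends here is exactly what the reporter is asking.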