
[BUG] pipeline parallelism+fp16+moe isn't working #6714

Open
NeferpitouS3 opened this issue Nov 5, 2024 · 2 comments

@NeferpitouS3

Describe the bug
My model is split into 4 parts with DeepSpeed's PipelineModule(num_stages=4), and the deepspeed.moe.layer.MoE layer is placed only in pipeline stage 1. When the model calls train_batch, the program gets stuck; the specific issue occurs in the FP16_Optimizer step.

Here is our DeepSpeed config:

{
   "train_batch_size": 4,
   "train_micro_batch_size_per_gpu" 1,
   "fp16": {
      "enabled": true,
      "auto_cast": true
   },
   "optimizer": {
      "type": "AdamW",
      "params": {
         "lr": 0.001,
         "betas": [
            0.9,
            0.95
         ],
         "weight_decay": 0.05
      }
   },
   "zero_optimization": {
      "stage": 0
   }
}
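
(Side note: DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × data-parallel world size. With pure pipeline parallelism the data-parallel size is 1, so this config implies 4 gradient-accumulation micro-batches per optimizer step: 4 = 1 × 4 × 1.)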

Source code with issues
My pipeline_parallel_world_size is 4, so the code enters the branch below, but my MoE layer exists only on pipeline stage 1, so the all_reduce makes the program hang. If I delete this code, it runs successfully.

elif bwc_pipeline_parallel_world_size(mpu) > 1:
    dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=bwc_pipeline_parallel_group(mpu))

I don't know why an all_reduce needs to be done here; it doesn't seem meaningful.
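
For context on why this hangs: all_reduce is a blocking collective, so every rank in bwc_pipeline_parallel_group(mpu) must call it. If only the stage that owns MoE parameters reaches this line, the other stages never join and the call blocks forever. A minimal sketch of what full participation would look like (reduce_moe_grad_norm, pp_group, and has_moe_params are hypothetical names for illustration, not DeepSpeed API):

import torch
import torch.distributed as dist

def reduce_moe_grad_norm(device_total_norm, pp_group, has_moe_params):
    # Every rank in pp_group must enter the collective, or the ranks
    # that do call it block forever. Stages that own no MoE parameters
    # can still participate by contributing zero to the summed norm.
    if not has_moe_params:
        device_total_norm = torch.zeros_like(device_total_norm)
    dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=pp_group)
    return device_total_norm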

@ranzhejiang
Contributor

Can you provide the whole script to reproduce it?

@NeferpitouS3
Author

Here is a simple example adapted from DeepSpeedExamples/training/cifar.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import deepspeed
from deepspeed.pipe import PipelineModule

def get_ds_config():
    # Same config as shown above (fp16 enabled, ZeRO stage 0)
    return {
        "train_batch_size": 4,
        "train_micro_batch_size_per_gpu": 1,
        "fp16": {"enabled": True, "auto_cast": True},
        "optimizer": {
            "type": "AdamW",
            "params": {"lr": 0.001, "betas": [0.9, 0.95], "weight_decay": 0.05}
        },
        "zero_optimization": {"stage": 0}
    }

class net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = layer1()
        self.layer2 = layer2()
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

class layer1(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        return x

class layer2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = DeepSpeedMoEMlp()
        self.fc3 = nn.Linear(120, 10)
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

class DeepSpeedMoEMlp(nn.Module):
    def __init__(self):
        super(DeepSpeedMoEMlp, self).__init__()
        self.fc2 = nn.Linear(120, 120)
        self._moe_layer = deepspeed.moe.layer.MoE(
            hidden_size=120,
            expert=self.fc2,
            num_experts=4,
            k=1,
            capacity_factor=1.25,
            use_tutel=True
        )
    def forward(self, x):
        x, _, _ = self._moe_layer(x)
        return x

if __name__ == "__main__":
    deepspeed.init_distributed()
    model = net()
    layer = torch.nn.Sequential(
        model.layer1,
        model.layer2
    )
    criterion = nn.CrossEntropyLoss()
    pipeline_model = PipelineModule(layers=layer, loss_fn=criterion, num_stages=2, partition_method="uniform")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=False,
                                            transform=transform)
    engine, _, data_loader, _ = deepspeed.initialize(
        model=pipeline_model,
        model_parameters=pipeline_model.parameters(),
        config=get_ds_config(),
        training_data=trainset
    )
    for epoch in range(3):
        for i in range(len(trainset)):
            # train_batch() pulls micro-batches from the engine's own dataloader
            loss = engine.train_batch()
            print(f"step{i}, loss: {loss}")

After running this code with deepspeed --include="localhost:0,1" test.py --deepspeed, the problem I mentioned above will be reproduced.
