diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 2ee9b87b..15ec8fbd 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -60,26 +60,30 @@ def build_mesh(self, device_type): ): if d > 1: dims.append(d) - if (name == "dp_replicate" and self.dp_shard == 1) or ( - name == "dp_shard" and self.dp_replicate == 1 - ): - names.append("dp") - else: - names.append(name) + names.append(name) logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") names = tuple(names) mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + # Create all the submesh here to ensure all required process groups are - # initialized - if self.dp_replicate > 1 and self.dp_shard > 1: # HSDP - mesh["dp_replicate", "dp_shard"]._flatten(mesh_dim_name="dp") + # initialized: + # Mesh for data loading + dp_mesh_dim_names = [] + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + + if self.dp_shard_enabled: + dp_mesh_dim_names.append("dp_shard") + + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") if self.cp > 1: if self.dp_replicate > 1 and self.dp_shard > 1: # HSDP mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp") elif self.dp_shard > 1: # FSDP - mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp") + mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp") return mesh