[torchtitan][debug] integrated CommDebugMode into TorchTitan
ghstack-source-id: 7e9de7b83a376eb320a403c416b891a0c5b5321e
Pull Request resolved: #480
sinhaanshul committed Jul 24, 2024
1 parent 0f70507 commit 834b8f6
Showing 2 changed files with 45 additions and 4 deletions.
26 changes: 26 additions & 0 deletions torchtitan/config_manager.py
@@ -482,6 +482,32 @@ def __init__(self):
""",
)

# commdebugmode configs
self.parser.add_argument(
"--comm_debug.enable_comm_debug_mode",
default=False,
action="store_true",
help="Whether to enable CommDebugMode, should be used only on first step",
)
self.parser.add_argument(
"--comm_debug.dump_file",
type=str,
default="torchtitan_comm_debug_dump.txt",
help="Which file to dump CommDebugMode's output to",
)
self.parser.add_argument(
"--comm_debug.dump_json",
type=str,
default="torchtitan_comm_debug_log.json",
help="Which file to dump CommDebugMode's json to",
)
self.parser.add_argument(
"--comm_debug.noise_level",
type=int,
default=2,
help="Sets noise level for CommDebugMode's output, controls how much info is displayed",
)

         # communications library settings
         self.parser.add_argument(
             "--comm.init_timeout_seconds",
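A quick sketch of exercising the new flags, assuming JobConfig.parse_args accepts a list of CLI-style arguments (the way torchtitan's unit tests drive it) and maps dotted flag names onto nested sections:

from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--comm_debug.enable_comm_debug_mode",
        "--comm_debug.noise_level",
        "1",
    ]
)
assert config.comm_debug.enable_comm_debug_mode is True
print(config.comm_debug.dump_file)   # torchtitan_comm_debug_dump.txt (default)
print(config.comm_debug.dump_json)   # torchtitan_comm_debug_log.json (default)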
23 changes: 19 additions & 4 deletions train.py
@@ -23,7 +23,7 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
-
+from torch.distributed._tensor.debug import CommDebugMode
 from torchtitan.checkpoint import CheckpointManager
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_hf_data_loader, create_tokenizer
@@ -138,18 +138,22 @@ def zero_grad(self):
     return OptimizersContainer([_build_optimizer(model) for model in model_parts])
 
 
-def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
+def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool, enable_comm_debug_mode: bool):
     @contextlib.contextmanager
     def context():
+        context_managers = {}
         with contextlib.ExitStack() as stack:
             if enable_loss_parallel:
                 stack.enter_context(loss_parallel())
             if enable_compiled_autograd:
                 stack.enter_context(
                     torch._dynamo.utils.maybe_enable_compiled_autograd(True)
                 )
+            if enable_comm_debug_mode:
+                comm_mode = stack.enter_context(CommDebugMode())
+                context_managers["comm_mode"] = comm_mode
 
-            yield
+            yield context_managers
 
     return context

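For orientation, a minimal standalone sketch of what CommDebugMode records when entered via the ExitStack above. The single-rank gloo/CPU setup, port, mesh size, and the full_tensor() trigger are illustrative assumptions; on one rank the recorded counts may be empty, so run under torchrun with several ranks to see real collectives:

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh

# single-process process group so the sketch runs standalone
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
dtensor = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

comm_mode = CommDebugMode()
with comm_mode:
    # full_tensor() redistributes to a replicated layout, which
    # normally issues an all-gather across the mesh
    full = dtensor.full_tensor()

print(comm_mode.get_total_counts())  # total number of collectives seen
print(comm_mode.get_comm_counts())   # per-collective breakdown
dist.destroy_process_group()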
@@ -214,6 +218,7 @@ def main(job_config: JobConfig):
     train_context = get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
+        job_config.comm_debug.enable_comm_debug_mode,
     )
 
     # loss fn can be shared by pipeline-parallel or non-pp execution
@@ -390,6 +395,10 @@ def loss_fn(pred, labels):
                 else:
                     pp_schedule.step()
 
+            if job_config.comm_debug.enable_comm_debug_mode and train_state.step == 1:
+                comm_mode = tc["comm_mode"]
+                comm_mode.log_comm_debug_tracing_table_to_file(file_name=job_config.comm_debug.dump_file, noise_level=job_config.comm_debug.noise_level)
+
             # accumulate losses across pipeline microbatches
             loss = (
                 torch.mean(torch.stack(losses))
@@ -398,13 +407,19 @@ def loss_fn(pred, labels):
             )
         else:
             # Non-PP forward / backward
-            with train_context():
+            with train_context() as tc:
                 pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
                 # need to free pred before bwd to avoid peaking memory
                 del pred
                 loss.backward()
 
+            if job_config.comm_debug.enable_comm_debug_mode and train_state.step == 1:
+                comm_mode = tc["comm_mode"]
+                comm_mode.log_comm_debug_tracing_table_to_file(file_name=job_config.comm_debug.dump_file, noise_level=job_config.comm_debug.noise_level)
+                comm_mode.generate_json_dump(file_name=job_config.comm_debug.dump_json, noise_level=job_config.comm_debug.noise_level)
+
+
         # clip gradients
         for model in model_parts:
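After a run with --comm_debug.enable_comm_debug_mode finishes its first step, the tracing table and JSON dump land at the paths configured above. The JSON file is plain JSON; a minimal sketch of loading it, assuming the default dump path in the working directory:

import json

with open("torchtitan_comm_debug_log.json") as f:
    comm_trace = json.load(f)

print(type(comm_trace))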
