diff --git a/torchtitan/models/llama/model.py b/torchtitan/models/llama/model.py
index 2060519a4..e3c2bdb62 100644
--- a/torchtitan/models/llama/model.py
+++ b/torchtitan/models/llama/model.py
@@ -439,8 +439,8 @@ def forward(self, tokens: torch.Tensor):
             h = layer(h, self.freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        output = self.output(h).float() if self.output else h
-        return output
+        chunks = h.chunk(16, dim=1)  # TODO: 16 is from the default `num_chunks`
+        return [self.output(chunk) for chunk in chunks]
 
     @classmethod
     def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index aa07f25fb..f8cb77f6c 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -312,15 +312,15 @@ def apply_fsdp(
             # all-gathers, which can be expensive and non-overlapped
             reshard_after_forward = False
         else:
-            # As an optimization, do not reshard after forward for the last
-            # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+            # For small models (e.g. GPT-2), parameter memory is low, so there
+            # is no need to reshard after forward
+            reshard_after_forward = False
         fully_shard(
             transformer_block,
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+    fully_shard(model, **fsdp_config)
 
     logger.info("Applied FSDP to the model")
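Note (editor's sketch, not part of the patch): chunking the final projection in model.py is numerically equivalent to projecting the full hidden states at once; the memory win is realized by the chunked loss below, which upcasts one chunk at a time to fp32 instead of materializing the full (batch_size, num_tokens, vocab_size) float logits. A minimal self-check, with toy sizes and a stand-in Linear chosen purely for illustration:

import torch

batch, seq_len, dim, vocab, num_chunks = 2, 64, 32, 128, 16
output = torch.nn.Linear(dim, vocab, bias=False)  # stand-in for self.output
h = torch.randn(batch, seq_len, dim)              # stand-in for the final hidden states

full = output(h)                                           # one (2, 64, 128) logits tensor
chunked = [output(c) for c in h.chunk(num_chunks, dim=1)]  # sixteen (2, 4, 128) tensors

# Concatenating the chunks along the token dim recovers the unchunked logits.
torch.testing.assert_close(torch.cat(chunked, dim=1), full)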
diff --git a/train.py b/train.py
index 3f07d3c7b..adcfac51a 100644
--- a/train.py
+++ b/train.py
@@ -8,6 +8,7 @@
 import os
 import time
 from datetime import timedelta
+from typing import List
 
 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -44,6 +45,36 @@ def context():
 
     return context
 
+
+class TokenChunkedCrossEntropyLoss(torch.nn.Module):
+    def __init__(self, num_chunks: int = 16, ignore_index: int = -100):
+        super(TokenChunkedCrossEntropyLoss, self).__init__()
+        self.num_chunks = num_chunks
+        self.ignore_index = ignore_index
+        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(
+            reduction="sum", ignore_index=self.ignore_index
+        )
+
+    @torch.compile()
+    def _compute_cross_entropy(self, logits: torch.Tensor, labels: torch.Tensor):
+        return self.cross_entropy_loss(logits.float(), labels)
+
+    def forward(self, logits: List[torch.Tensor], labels: torch.Tensor):
+        """
+        Args:
+            logits (List[torch.Tensor]): List of chunked logits of length
+                ``self.num_chunks``, where each chunk has shape
+                (batch_size, num_tokens / num_chunks, vocab_size).
+            labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens).
+        """
+        total_elements = (labels != self.ignore_index).sum()
+        labels = [target_chunk.reshape(-1) for target_chunk in labels.chunk(self.num_chunks, dim=1)]
+        logits = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
+        total_loss = 0.0
+        for logits_chunk, labels_chunk in zip(logits, labels):
+            total_loss += self._compute_cross_entropy(logits_chunk, labels_chunk)
+        return total_loss / total_elements
+
+
 # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
@@ -132,9 +163,13 @@ def main(job_config: JobConfig):
         f"{color.blue}Model {model_name} {job_config.model.flavor} "
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
+    token_chunked_cross_entropy_loss = TokenChunkedCrossEntropyLoss()
 
     # loss function to be shared by Pipeline Parallel and SPMD training
     def loss_fn(pred, labels):
-        return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
-        )
+        if isinstance(pred, torch.Tensor):
+            pred_chunks = pred.chunk(token_chunked_cross_entropy_loss.num_chunks, dim=1)
+        else:
+            assert isinstance(pred, list)
+            pred_chunks = pred
+        return token_chunked_cross_entropy_loss(pred_chunks, labels)
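Note (editor's sketch, not part of the patch): with reduction="sum" per chunk and division by the number of non-ignored tokens, the chunked loss matches the default mean-reduced cross entropy on unchunked logits, even when ignored labels are spread unevenly across chunks. A quick parity check, assuming the TokenChunkedCrossEntropyLoss class from the train.py hunk above is in scope and num_tokens divides evenly by num_chunks:

import torch
import torch.nn.functional as F

batch_size, num_tokens, vocab_size = 2, 64, 128  # toy sizes for illustration
logits = torch.randn(batch_size, num_tokens, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, num_tokens))

loss_fn = TokenChunkedCrossEntropyLoss(num_chunks=16)
chunked = loss_fn(list(logits.chunk(16, dim=1)), labels)

# Reference: mean cross entropy over all batch * num_tokens positions.
reference = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
torch.testing.assert_close(chunked, reference)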