diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index e3c19cc91c..5eacced549 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -514,6 +514,10 @@ def profile_packing(
     big_batch = next(iter(train_dataloader))
 
     # Cut everything down to size
+    if 'total_tokens' in big_batch:
+        del big_batch['total_tokens']
+    if 'loss_generating_tokens' in big_batch:
+        del big_batch['loss_generating_tokens']
     sizes, trimmed_examples = _trim_batch(big_batch)
 
     def profile(raw_batch_size: int) -> tuple[Optional[float], Optional[float]]:
diff --git a/llmfoundry/data/utils.py b/llmfoundry/data/utils.py
index 21c28d9183..8038430259 100644
--- a/llmfoundry/data/utils.py
+++ b/llmfoundry/data/utils.py
@@ -20,6 +20,55 @@
 log = logging.getLogger(__name__)
 
 
+class LossGeneratingTokensCollatorWrapper:
+    """Collator wrapper to add loss generating token counts to batch."""
+
+    def __init__(
+        self,
+        base_collator: Callable,
+        token_counting_func: Callable[[Batch], Union[int, dict[str, int]]],
+    ):
+        self.base_collator = base_collator
+        self.token_counting_func = token_counting_func
+
+        self._token_count_batch_keys = [
+            'input_ids',
+            'attention_mask',
+            'labels',
+            'decoder_attention_mask',
+        ]
+
+    def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]:
+        batch = self.base_collator(examples)
+
+        # Add token counts to batch as a list, one for each row, so that
+        # microbatch splitting works
+        output = {
+            'total_tokens': [],
+            'loss_generating_tokens': [],
+        }
+        num_rows = batch['input_ids'].shape[0]
+        for row in range(num_rows):
+            row_batch = {}
+            for key in self._token_count_batch_keys:
+                if key in batch:
+                    row_batch[key] = batch[key][row:row + 1]
+
+            num_tokens = self.token_counting_func(row_batch)
+            if isinstance(num_tokens, dict):
+                output['total_tokens'].append(num_tokens['total'])
+                output['loss_generating_tokens'].append(
+                    num_tokens['loss_generating'],
+                )
+            else:
+                output['total_tokens'].append(num_tokens)
+                output['loss_generating_tokens'].append(num_tokens)
+
+        batch['total_tokens'] = output['total_tokens']
+        batch['loss_generating_tokens'] = output['loss_generating_tokens']
+
+        return batch
+
+
 def _validate_cfg(
     dataset_cfg: dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
@@ -109,6 +158,13 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
                 'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key',
             )
 
+        # Short cut if the dataloader has already calculated the number of tokens
+        if 'total_tokens' in batch and 'loss_generating_tokens' in batch:
+            return {
+                'total': sum(batch['total_tokens']),
+                'loss_generating': sum(batch['loss_generating_tokens']),
+            }
+
         # Count number of non padding tokens in batch
         if 'attention_mask' in batch:
             input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
@@ -117,16 +173,10 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
 
         loss_generating_tokens = None
         if 'labels' in batch:
-            loss_generating_tokens = int(
-                torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(),
-            )
-
-            # Subtract one for each example in the batch that starts with a non -100,
-            # because those will be shifted off
-            loss_generating_tokens -= int(
-                torch.sum(
-                    batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX,
-                ).item(),
+            loss_generating_tokens = (
+                batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)
+            ) - torch.count_nonzero(
+                torch.eq(batch['labels'][..., 1:], CROSS_ENTROPY_IGNORE_INDEX),
             )
 
         # For encoder decoder models only
@@ -151,7 +201,8 @@ def get_text_collator(
     tokenizer: PreTrainedTokenizerBase,
     dataset_batch_size: int,
 ) -> tuple[Union[transformers.DataCollatorForLanguageModeling,
-                 ConcatenatedSequenceCollatorWrapper], int]:
+                 ConcatenatedSequenceCollatorWrapper,
+                 LossGeneratingTokensCollatorWrapper], int]:
     dataset_cfg = dataloader_cfg.get('dataset')
     assert isinstance(dataset_cfg, dict)
     eos_token_id = dataset_cfg.get('eos_token_id', None)
@@ -171,6 +222,11 @@ def get_text_collator(
             bos_token_id=bos_token_id,
         )
 
+    collate_fn = LossGeneratingTokensCollatorWrapper(
+        collate_fn,
+        get_tokens_per_batch_func(),
+    )
+
     return collate_fn, dataset_batch_size
 
 
@@ -178,5 +234,15 @@ def get_finetuning_collator(
     dataloader_cfg: dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
     dataset_batch_size: int,
-) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
-    return build_collate_fn(dataloader_cfg, tokenizer, dataset_batch_size)
+) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator,
+                 LossGeneratingTokensCollatorWrapper], int]:
+    collate_fn, dataset_batch_size = build_collate_fn(
+        dataloader_cfg,
+        tokenizer,
+        dataset_batch_size,
+    )
+    collate_fn = LossGeneratingTokensCollatorWrapper(
+        collate_fn,
+        get_tokens_per_batch_func(),
+    )
+    return collate_fn, dataset_batch_size
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index f599ebbc16..809babece9 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1573,7 +1573,8 @@ def test_mptmoe_huggingface_conversion_callback(
     # Check output equivalence
     loaded_model = loaded_model.cuda().bfloat16()  # type: ignore
     for k, v in batch.items():
-        batch[k] = v.cuda()
+        if isinstance(v, torch.Tensor):
+            batch[k] = v.cuda()
     loaded_model_logits = loaded_model(
         input_ids=batch.get('input_ids', None),
         attention_mask=batch.get('attention_mask', None),
diff --git a/tests/data/test_packing.py b/tests/data/test_packing.py
index 0fad6c0d53..48713f8a19 100644
--- a/tests/data/test_packing.py
+++ b/tests/data/test_packing.py
@@ -16,6 +16,7 @@
 from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
 from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset
 from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
+from llmfoundry.data.utils import LossGeneratingTokensCollatorWrapper
 from llmfoundry.utils.builders import build_tokenizer
 
 
@@ -253,7 +254,8 @@ def test_packing_with_dataloader(packing_ratio: Any):
     ).dataloader
 
     assert isinstance(loader, DataLoader)
-    pack_collator = loader.collate_fn
+    assert isinstance(loader.collate_fn, LossGeneratingTokensCollatorWrapper)
+    pack_collator = loader.collate_fn.base_collator
     assert isinstance(pack_collator, BinPackCollator)
 
     batch_ix = 0
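
For reviewers, a minimal sketch of how the new `LossGeneratingTokensCollatorWrapper` composes with a base collator. The `toy_base_collator` and `toy_token_counting_func` below are hypothetical stand-ins for illustration only; in this PR the wrapper is actually constructed inside `get_text_collator` / `get_finetuning_collator` with `get_tokens_per_batch_func()`. The per-row counts are what allow Composer's microbatch splitting to slice them alongside the tensors.

```python
# Usage sketch (toy collator and counting function are hypothetical, not from this PR).
import torch

from llmfoundry.data.utils import LossGeneratingTokensCollatorWrapper

CROSS_ENTROPY_IGNORE_INDEX = -100  # assumed ignore index, matching utils.py


def toy_base_collator(examples):
    # Stack pre-tokenized rows into a batch; stand-in for a real HF collator.
    return {
        'input_ids': torch.stack([ex['input_ids'] for ex in examples]),
        'attention_mask': torch.stack([ex['attention_mask'] for ex in examples]),
        'labels': torch.stack([ex['labels'] for ex in examples]),
    }


def toy_token_counting_func(batch):
    # Count non-padding tokens and non-ignored label positions for one row.
    return {
        'total': int(batch['attention_mask'].sum().item()),
        'loss_generating':
            int((batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).sum().item()),
    }


collate_fn = LossGeneratingTokensCollatorWrapper(
    toy_base_collator,
    toy_token_counting_func,
)

examples = [
    {
        'input_ids': torch.tensor([1, 2, 3, 0]),
        'attention_mask': torch.tensor([1, 1, 1, 0]),
        'labels': torch.tensor([-100, 2, 3, -100]),
    },
    {
        'input_ids': torch.tensor([4, 5, 0, 0]),
        'attention_mask': torch.tensor([1, 1, 0, 0]),
        'labels': torch.tensor([-100, 5, -100, -100]),
    },
]

batch = collate_fn(examples)
# One count per row, so microbatch splitting can carry them with the tensors.
print(batch['total_tokens'])            # [3, 2]
print(batch['loss_generating_tokens'])  # [2, 1]
```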