Skip to content

Commit

Permalink
Move loss generating token counting to the dataloader (#1632)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Nov 4, 2024
1 parent fe69619 commit 47dd036
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 15 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/data/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,10 @@ def profile_packing(
big_batch = next(iter(train_dataloader))

# Cut everything down to size
if 'total_tokens' in big_batch:
del big_batch['total_tokens']
if 'loss_generating_tokens' in big_batch:
del big_batch['loss_generating_tokens']
sizes, trimmed_examples = _trim_batch(big_batch)

def profile(raw_batch_size: int) -> tuple[Optional[float], Optional[float]]:
Expand Down
92 changes: 79 additions & 13 deletions llmfoundry/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,55 @@
log = logging.getLogger(__name__)


class LossGeneratingTokensCollatorWrapper:
"""Collator wrapper to add loss generating token counts to batch."""

def __init__(
self,
base_collator: Callable,
token_counting_func: Callable[[Batch], Union[int, dict[str, int]]],
):
self.base_collator = base_collator
self.token_counting_func = token_counting_func

self._token_count_batch_keys = [
'input_ids',
'attention_mask',
'labels',
'decoder_attention_mask',
]

def __call__(self, examples: list[Any]) -> dict[str, torch.Tensor]:
batch = self.base_collator(examples)

# Add token counts to batch as a list, one for each row, so that microbatch splitting works
output = {
'total_tokens': [],
'loss_generating_tokens': [],
}
num_rows = batch['input_ids'].shape[0]
for row in range(num_rows):
row_batch = {}
for key in self._token_count_batch_keys:
if key in batch:
row_batch[key] = batch[key][row:row + 1]

num_tokens = self.token_counting_func(row_batch)
if isinstance(num_tokens, dict):
output['total_tokens'].append(num_tokens['total'])
output['loss_generating_tokens'].append(
num_tokens['loss_generating'],
)
else:
output['total_tokens'].append(num_tokens)
output['loss_generating_tokens'].append(num_tokens)

batch['total_tokens'] = output['total_tokens']
batch['loss_generating_tokens'] = output['loss_generating_tokens']

return batch


def _validate_cfg(
dataset_cfg: dict[str, Any],
tokenizer: PreTrainedTokenizerBase,
Expand Down Expand Up @@ -109,6 +158,13 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
'get_tokens_per_batch_func() for encoder decoder requires a batch with a decoder_attention_mask key',
)

# Short cut if the dataloader has already calculated the number of tokens
if 'total_tokens' in batch and 'loss_generating_tokens' in batch:
return {
'total': sum(batch['total_tokens']),
'loss_generating': sum(batch['loss_generating_tokens']),
}

# Count number of non padding tokens in batch
if 'attention_mask' in batch:
input_ids_tokens = int(torch.sum(batch['attention_mask']).item())
Expand All @@ -117,16 +173,10 @@ def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:

loss_generating_tokens = None
if 'labels' in batch:
loss_generating_tokens = int(
torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(),
)

# Subtract one for each example in the batch that starts with a non -100,
# because those will be shifted off
loss_generating_tokens -= int(
torch.sum(
batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX,
).item(),
loss_generating_tokens = (
batch['labels'].shape[0] * (batch['labels'].shape[1] - 1)
) - torch.count_nonzero(
torch.eq(batch['labels'][..., 1:], CROSS_ENTROPY_IGNORE_INDEX),
)

# For encoder decoder models only
Expand All @@ -151,7 +201,8 @@ def get_text_collator(
tokenizer: PreTrainedTokenizerBase,
dataset_batch_size: int,
) -> tuple[Union[transformers.DataCollatorForLanguageModeling,
ConcatenatedSequenceCollatorWrapper], int]:
ConcatenatedSequenceCollatorWrapper,
LossGeneratingTokensCollatorWrapper], int]:
dataset_cfg = dataloader_cfg.get('dataset')
assert isinstance(dataset_cfg, dict)
eos_token_id = dataset_cfg.get('eos_token_id', None)
Expand All @@ -171,12 +222,27 @@ def get_text_collator(
bos_token_id=bos_token_id,
)

collate_fn = LossGeneratingTokensCollatorWrapper(
collate_fn,
get_tokens_per_batch_func(),
)

return collate_fn, dataset_batch_size


def get_finetuning_collator(
dataloader_cfg: dict[str, Any],
tokenizer: PreTrainedTokenizerBase,
dataset_batch_size: int,
) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
return build_collate_fn(dataloader_cfg, tokenizer, dataset_batch_size)
) -> tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator,
LossGeneratingTokensCollatorWrapper], int]:
collate_fn, dataset_batch_size = build_collate_fn(
dataloader_cfg,
tokenizer,
dataset_batch_size,
)
collate_fn = LossGeneratingTokensCollatorWrapper(
collate_fn,
get_tokens_per_batch_func(),
)
return collate_fn, dataset_batch_size
3 changes: 2 additions & 1 deletion tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,8 @@ def test_mptmoe_huggingface_conversion_callback(
# Check output equivalence
loaded_model = loaded_model.cuda().bfloat16() # type: ignore
for k, v in batch.items():
batch[k] = v.cuda()
if isinstance(v, torch.Tensor):
batch[k] = v.cuda()
loaded_model_logits = loaded_model(
input_ids=batch.get('input_ids', None),
attention_mask=batch.get('attention_mask', None),
Expand Down
4 changes: 3 additions & 1 deletion tests/data/test_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.utils import LossGeneratingTokensCollatorWrapper
from llmfoundry.utils.builders import build_tokenizer


Expand Down Expand Up @@ -253,7 +254,8 @@ def test_packing_with_dataloader(packing_ratio: Any):
).dataloader

assert isinstance(loader, DataLoader)
pack_collator = loader.collate_fn
assert isinstance(loader.collate_fn, LossGeneratingTokensCollatorWrapper)
pack_collator = loader.collate_fn.base_collator
assert isinstance(pack_collator, BinPackCollator)

batch_ix = 0
Expand Down

0 comments on commit 47dd036

Please sign in to comment.