From dfd9ab395e26fa83189ae45acacd83c195298f7e Mon Sep 17 00:00:00 2001
From: ylfeng
Date: Wed, 18 Sep 2024 06:06:38 +0800
Subject: [PATCH] 1. flatting_packing does not need to reserve a token for
 padding 2. Fix Mistral assistant message

---
 .../data/processors/supervised.py | 19 ++++++++++++-------
 src/llamafactory/data/template.py |  1 +
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 68ce2ffd73..1406f16a15 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -127,10 +127,13 @@ def preprocess_packed_supervised_dataset(
     # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
     # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
     valid_num = 0
+    invalid_num = 0
     batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
     lengths = []
     length2indexes = defaultdict(list)
-    count_drop = 0
+
+    # reserve a token for padding; flatting_packing does not need it
+    num_reserved = 0 if data_args.flatting_packing else 1
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
             logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
@@ -146,13 +149,13 @@
             template=template,
             tokenizer=tokenizer,
             processor=processor,
-            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
+            cutoff_len=data_args.cutoff_len - num_reserved,
             train_on_prompt=data_args.train_on_prompt,
             mask_history=data_args.mask_history,
         )
         length = len(input_ids)
-        if length > data_args.cutoff_len - 1:  # reserved for the padding token
-            count_drop += 1
+        if length > data_args.cutoff_len - num_reserved:
+            invalid_num += 1
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -162,11 +165,13 @@
             batch_videos.append(examples["_videos"][i] or [])
             valid_num += 1
 
-    if count_drop > 0:
-        logger.warning("Dropped lengthy {} example with length > {}.".format(count_drop, data_args.cutoff_len - 1))
+    if invalid_num > 0:
+        logger.warning(
+            "Dropped {} lengthy examples with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
+        )
 
     model_inputs = defaultdict(list)
-    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
+    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # reserved for the padding token
     for knapsack in knapsacks:
         packed_input_ids, packed_attention_masks, packed_labels = [], [], []
         packed_images, packed_videos = [], []
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index e4ae4457d2..89d19be01a 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -724,6 +724,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+    format_assistant=StringFormatter(slots=[" {{content}}"]),  # mistral adds a space here
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     format_function=MistralFunctionFormatter(slots=[], tool_format="mistral"),
     format_observation=MistralObservationFormatter(tool_format="mistral"),
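
Note (not part of the patch): the sketch below illustrates the packing rule the first change relies on. Each knapsack may hold at most cutoff_len - num_reserved tokens, and num_reserved drops to 0 when flatting_packing is enabled because no padding token has to be appended afterwards. The function name pack_lengths and the longest-first bin-filling strategy are assumptions made for this standalone example; LLaMA-Factory's actual greedy_knapsack may choose bins differently.

# Illustrative, standalone sketch (assumed helper, not LLaMA-Factory code):
# pack sequence lengths into knapsacks of capacity cutoff_len - num_reserved,
# where num_reserved is 0 for flatting_packing and 1 otherwise.
from typing import List


def pack_lengths(lengths: List[int], cutoff_len: int, flatting_packing: bool) -> List[List[int]]:
    num_reserved = 0 if flatting_packing else 1  # same rule as in the patch
    capacity = cutoff_len - num_reserved
    bins: List[List[int]] = []   # packed lengths per knapsack
    filled: List[int] = []       # tokens already placed in each knapsack
    for length in sorted(lengths, reverse=True):  # longest-first greedy fill
        for i, used in enumerate(filled):
            if used + length <= capacity:
                bins[i].append(length)
                filled[i] = used + length
                break
        else:  # no existing knapsack fits this sequence, open a new one
            bins.append([length])
            filled.append(length)
    return bins


if __name__ == "__main__":
    lengths = [3, 5, 2, 4, 6]
    print(pack_lengths(lengths, cutoff_len=8, flatting_packing=True))   # capacity 8: [[6, 2], [5, 3], [4]]
    print(pack_lengths(lengths, cutoff_len=8, flatting_packing=False))  # capacity 7: [[6], [5, 2], [4, 3]]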