Flatting Packing / maybe fix #5443 and #5426 #5458
Open
AlongWY wants to merge 2 commits into hiyouga:main from AlongWY:main
+155 −54
Changes from all commits (2 commits)
@@ -11,23 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import greedy_knapsack, infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..mm_plugin import ImageInput, VideoInput
    from ..template import Template


logger = get_logger(__name__)

@@ -53,13 +51,16 @@ def _encode_supervised_example(
        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= cutoff_len:
        if total_length >= cutoff_len and cutoff_len > 0:
            break

        source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
        source_ids = source_ids[:source_len]
        target_ids = target_ids[:target_len]
        total_length += source_len + target_len
        if cutoff_len > 0:
            source_len, target_len = infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)
            source_ids = source_ids[:source_len]
            target_ids = target_ids[:target_len]
            total_length += source_len + target_len
        else:
            source_len, target_len = len(source_ids), len(target_ids)

        if train_on_prompt:
            source_label = source_ids

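For context, `infer_seqlen` (imported above from `processor_utils`) decides how many source and target tokens survive when a turn has to fit into the remaining budget. The following is only a rough sketch written from the call site, not the repository's actual implementation:

```python
from typing import Tuple


def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    # Illustrative only, not the repository's exact code.
    if source_len + target_len <= cutoff_len:
        return source_len, target_len  # the turn already fits

    # Otherwise split the remaining budget roughly in proportion to the original
    # lengths, which is the mid-turn truncation the review discussion below describes.
    new_source_len = int(cutoff_len * source_len / (source_len + target_len))
    new_target_len = cutoff_len - new_source_len
    return new_source_len, new_target_len
```

With the new `else` branch, a `cutoff_len` of 0 skips this splitting entirely and keeps every turn whole.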
@@ -112,7 +113,7 @@ def preprocess_supervised_dataset(
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
            cutoff_len=data_args.cutoff_len if data_args.allow_truncation else 0,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )

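The branch above keys off a new `data_args.allow_truncation` flag, and the packed path below additionally checks `data_args.flat_packing`. The diff does not show where those fields are declared, so the following is only a guess at how they might look among the existing DataArguments; the defaults and help strings are assumptions:

```python
from dataclasses import dataclass, field


@dataclass
class DataArguments:
    # ...existing fields elided...
    allow_truncation: bool = field(
        default=False,
        metadata={"help": "Whether to allow truncating over-length examples instead of keeping them whole."},
    )
    flat_packing: bool = field(
        default=False,
        metadata={"help": "Whether to pack examples as nested sequences and build position_ids in the collator."},
    )
```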
@@ -132,13 +133,16 @@ def preprocess_packed_supervised_dataset(
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[Any]]:
    # TODO: use `position_ids` to achieve packing
    # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
    # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
    valid_num = 0
    invalid_num = 0
    batch_input_ids, batch_labels, batch_images, batch_videos = [], [], [], []
    lengths = []
    length2indexes = defaultdict(list)

    # reserve one slot for the padding token; flat_packing does not need it
    num_reserved = 0 if data_args.flat_packing else 1
    for i in range(len(examples["_prompt"])):
        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))

@@ -154,13 +158,13 @@ def preprocess_packed_supervised_dataset(
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
            cutoff_len=data_args.cutoff_len - num_reserved if data_args.allow_truncation else 0,
            train_on_prompt=data_args.train_on_prompt,
            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
        if length > data_args.cutoff_len - num_reserved:
            invalid_num += 1
        else:
            lengths.append(length)
            length2indexes[length].append(valid_num)

@@ -170,36 +174,52 @@ def preprocess_packed_supervised_dataset(
            batch_videos.append(examples["_videos"][i] or [])
            valid_num += 1

    if invalid_num > 0:
        logger.warning(
            "Dropped {} lengthy examples with length > {}.".format(invalid_num, data_args.cutoff_len - num_reserved)
        )

    model_inputs = defaultdict(list)
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1)  # reserved for the padding token
    knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - num_reserved)  # reserved for the padding token
    for knapsack in knapsacks:
        packed_input_ids, packed_attention_masks, packed_labels = [], [], []
        packed_images, packed_videos = [], []
        for i, length in enumerate(knapsack):
            index = length2indexes[length].pop()
            packed_input_ids += batch_input_ids[index]
            packed_labels += batch_labels[index]
            packed_images += batch_images[index]
            packed_videos += batch_videos[index]
            if data_args.neat_packing:
                packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
            else:
                packed_attention_masks += [1] * len(batch_input_ids[index])

        if len(packed_input_ids) < data_args.cutoff_len:
            pad_length = data_args.cutoff_len - len(packed_input_ids)
            packed_input_ids += [tokenizer.pad_token_id] * pad_length
            packed_labels += [IGNORE_INDEX] * pad_length
            if data_args.neat_packing:
                packed_attention_masks += [0] * pad_length
            else:
                packed_attention_masks += [1] * pad_length  # more efficient flash_attn

        if len(packed_input_ids) != data_args.cutoff_len:
            raise ValueError("The length of packed example should be identical to the cutoff length.")

        if data_args.flat_packing:
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids.append(batch_input_ids[index])
                packed_labels.append(batch_labels[index])
                packed_images.append(batch_images[index])
                packed_videos.append(batch_videos[index])
Review comment on the lines above: processing is deferred here; position_ids are not returned at this stage, they are assembled and returned in the collator.
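A minimal sketch of what that collator-side step could look like: flatten the nested sub-sequences of one packed example and rebuild position_ids so that each sub-sequence restarts at 0. The helper name and surrounding structure are assumptions for illustration, not code from this PR:

```python
from typing import Dict, List


def flatten_packed_example(input_ids: List[List[int]], labels: List[List[int]]) -> Dict[str, List[int]]:
    """Hypothetical collator helper: concatenate packed sub-sequences and derive position_ids."""
    flat_input_ids: List[int] = []
    flat_labels: List[int] = []
    position_ids: List[int] = []
    for seq_ids, seq_labels in zip(input_ids, labels):
        flat_input_ids.extend(seq_ids)
        flat_labels.extend(seq_labels)
        position_ids.extend(range(len(seq_ids)))  # positions restart for each sub-sequence
    return {"input_ids": flat_input_ids, "labels": flat_labels, "position_ids": position_ids}


# Example: two sub-sequences of lengths 3 and 2 packed together yield
# position_ids == [0, 1, 2, 0, 1]
```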
        else:
            for i, length in enumerate(knapsack):
                index = length2indexes[length].pop()
                packed_input_ids += batch_input_ids[index]
                packed_labels += batch_labels[index]
                packed_images += batch_images[index]
                packed_videos += batch_videos[index]
                if data_args.neat_packing:
                    packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
                else:
                    packed_attention_masks += [1] * len(batch_input_ids[index])

            # flat_packing does not need attention masks
            if len(packed_input_ids) < data_args.cutoff_len:
                pad_length = data_args.cutoff_len - len(packed_input_ids)
                packed_input_ids += [tokenizer.pad_token_id] * pad_length
                packed_labels += [IGNORE_INDEX] * pad_length
                if data_args.neat_packing:
                    packed_attention_masks += [0] * pad_length
                else:
                    packed_attention_masks += [1] * pad_length  # more efficient flash_attn

            # flatting packing does not need padding
            if len(packed_input_ids) != data_args.cutoff_len:
                raise ValueError("The length of packed example should be identical to the cutoff length.")
            model_inputs["attention_mask"].append(packed_attention_masks)

        model_inputs["input_ids"].append(packed_input_ids)
        model_inputs["attention_mask"].append(packed_attention_masks)
        model_inputs["labels"].append(packed_labels)
        model_inputs["images"].append(packed_images or None)
        model_inputs["videos"].append(packed_videos or None)

@@ -213,3 +233,12 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
    print("label_ids:\n{}".format(example["labels"]))
    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))


def print_flatting_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_labels = list(filter(lambda x: x != IGNORE_INDEX, itertools.chain(*example["labels"])))
    input_ids = list(itertools.chain(*example["input_ids"]))
    print("input_ids:\n{}".format(input_ids))
    print("inputs:\n{}".format(tokenizer.decode(input_ids, skip_special_tokens=False)))
    print("label_ids:\n{}".format(list(itertools.chain(*example["labels"]))))
    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
Review discussion

Comment: This is what causes the instruction data to be truncated unexpectedly (#5426). Maybe consider introducing a new parameter to control whether truncation is allowed? My samples are two-turn tool calls, and once they are truncated the model only learns to emit the tool_calls output without the final answer. Moreover, the current truncation implementation can cut into both the user and the assistant content. With the mistral template, for example, it produces `[INST] xxxxxxx` while the trailing `xxxxx[/INST]` disappears, which is clearly incorrect.
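A toy illustration of the failure mode described above (the token strings and budget are made up): slicing the encoded prompt to fit the remaining budget can drop the closing control tag.

```python
# Hypothetical tokenized prompt for a mistral-style template.
source_tokens = ["[INST]", "What", "is", "the", "weather", "today", "?", "[/INST]"]

remaining_budget = 5  # pretend the length split allotted only 5 tokens to the prompt
truncated = source_tokens[:remaining_budget]
# -> ["[INST]", "What", "is", "the", "weather"]; the "[/INST]" marker is gone,
# so the assistant turn that follows is no longer well-formed.
```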
Reply: I don't think the problem is here? Non-packing shows the same behavior.

Reply: That said, I do think we need a parameter to control this, because in some cases a sample must not be cut off in the middle.

Reply: If the prompt is not truncated, where would the assistant part go?

Reply: Just skip it and drop that sample.

Reply: Added a parameter to control whether truncation is allowed; truncation is disabled by default.