merge easycontext #4733

Open

wants to merge 39 commits into base: main

Changes from all commits (39 commits)
ab6bfb4  merge easycontext (qianhao0713, Jul 9, 2024)
3a0f5b4  Merge pull request #1 from ZJLab-DataHub-Security/qianhao (qianhao0713, Jul 9, 2024)
7544719  merge easycontext into llamafactory (qianhao0713, Jul 9, 2024)
49bd709  remove useless line (qianhao0713, Jul 9, 2024)
62367a4  Merge pull request #2 from ZJLab-DataHub-Security/bxf (qianhao0713, Jul 9, 2024)
751c6a1  hybrid dp & sp for sft (qianhao0713, Jul 10, 2024)
ba4e768  add sp_group_size params for launch shell (qianhao0713, Jul 10, 2024)
b4c8d23  fix sp bug in hybrid sp&dp mode (qianhao0713, Jul 11, 2024)
fccd0c3  fix sp bug in hybrid sp&dp mode (qianhao0713, Jul 11, 2024)
c039a82  fix bug when sequence_parallel_size=-1 (qianhao0713, Jul 11, 2024)
b7007ec  fix loss normalizer for sp&dp hybrid (qianhao0713, Jul 12, 2024)
5544a75  fix bug (qianhao0713, Jul 12, 2024)
37d097a  fix bug (qianhao0713, Jul 12, 2024)
d5e5135  rename variables (qianhao0713, Jul 15, 2024)
d31c8f7  fix 70b launch shell (qianhao0713, Jul 15, 2024)
4a4ea30  fix bug (qianhao0713, Jul 15, 2024)
9b360dd  Merge pull request #3 from ZJLab-DataHub-Security/qianhao (luckyqsz, Jul 15, 2024)
92554c2  Merge pull request #4 from ZJLab-DataHub-Security/qianhao_dev (qianhao0713, Jul 15, 2024)
34f70ce  add dp&sp hybrid for cpt (qianhao0713, Jul 15, 2024)
70bd600  add cpt test launch shell (qianhao0713, Jul 15, 2024)
89c28fb  fix compute_loss for cpt (qianhao0713, Jul 15, 2024)
0543306  Merge pull request #5 from ZJLab-DataHub-Security/qianhao (qianhao0713, Jul 16, 2024)
8f43fc1  arange launch shell (qianhao0713, Jul 30, 2024)
f4f9659  free buffer (qianhao0713, Aug 2, 2024)
da7a8cc  use global_buffer to load/unload activation (qianhao0713, Aug 9, 2024)
9c0ce85  global_buffer add hidden_states and attention_mask (qianhao0713, Aug 14, 2024)
7690f83  merge offload buffer (qianhao0713, Aug 27, 2024)
09f67e5  pycuda -> cudart (qianhao0713, Sep 3, 2024)
949763f  add offload examples (qianhao0713, Sep 4, 2024)
5b23b07  Merge pull request #6 from ZJLab-DataHub-Security/qianhao_dev (qianhao0713, Sep 11, 2024)
cf5e491  Merge pull request #7 from ZJLab-DataHub-Security/qianhao (qianhao0713, Oct 9, 2024)
d81e737  add dp&sp for zigzag and ulysses (qianhao0713, Oct 22, 2024)
da1f5b9  Merge pull request #8 from ZJLab-DataHub-Security/dpsp (qianhao0713, Oct 24, 2024)
6b964cd  fix bug (qianhao0713, Oct 28, 2024)
19236bb  update example (qianhao0713, Oct 28, 2024)
a215921  update transformers to 4.46.3 (qianhao0713, Dec 6, 2024)
358cbad  fix bug (qianhao0713, Dec 9, 2024)
1eead84  fix bug when parallel_mode=data_parallel (qianhao0713, Dec 11, 2024)
e6f1947  remove per_instance_loss (qianhao0713, Dec 11, 2024)
8 changes: 8 additions & 0 deletions .gitignore
@@ -1,3 +1,11 @@
# local directories to ignore
/output/
*.json
*llmtuner*
/wandb/
/examples/
/data/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
8 changes: 7 additions & 1 deletion data/dataset_info.json
@@ -8,6 +8,12 @@
"alpaca_zh_demo": {
"file_name": "alpaca_zh_demo.json"
},
"long_sft_32k": {
"file_name": "sample_long_sft_32k_48M.json"
},
"long_sft_128k": {
"file_name": "sample_long_sft_128k.parquet"
},
"glaive_toolcall_en_demo": {
"file_name": "glaive_toolcall_en_demo.json",
"formatting": "sharegpt",
@@ -551,4 +557,4 @@
},
"folder": "python"
}
}
}
15 changes: 15 additions & 0 deletions examples/accelerate/ds_multi_nodes.yaml
@@ -0,0 +1,15 @@
debug: false
deepspeed_config:
  deepspeed_config_file: examples/deepspeed/ds_z3_offload_config.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: true
distributed_type: DEEPSPEED
num_processes: 16
downcast_bf16: 'no'
main_training_function: main
rdzv_backend: c10d
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
12 changes: 4 additions & 8 deletions examples/deepspeed/ds_z3_offload_config.json
@@ -17,13 +17,8 @@
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
"device": "cpu"
},
"overlap_comm": true,
"contiguous_gradients": true,
@@ -34,5 +29,6 @@
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
},
"steps_per_print":1
}
3 changes: 2 additions & 1 deletion src/llamafactory/data/__init__.py
@@ -1,4 +1,4 @@
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SeqParallelDataCollatorForLanguageModeling
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -7,6 +7,7 @@
__all__ = [
    "KTODataCollatorWithPadding",
    "PairwiseDataCollatorWithPadding",
    "SeqParallelDataCollatorForLanguageModeling",
    "Role",
    "split_dataset",
    "get_dataset",
87 changes: 85 additions & 2 deletions src/llamafactory/data/collator.py
@@ -2,8 +2,9 @@
from typing import Any, Dict, Sequence

import torch
from transformers import DataCollatorForSeq2Seq

from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from llamafactory.easy_context import prepare_seq_parallel_sft_inputs

@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
@@ -79,3 +80,85 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch

@dataclass
class SeqParallelDataCollator(DataCollatorForSeq2Seq):
    r"""
    Data collator for sequence parallelism in the supervised fine-tuning (SFT) stage.
    """
    seq_algo: str = "data_parallel"
    sp_size: int = -1
    rank: int = 0
    world_size: int = 8
    device: Optional[Any] = None

    def __call__(self, features: Sequence[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]:
        batch = super().__call__(features, return_tensors)
        if self.seq_algo == "data_parallel":
            return batch
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        world_size = self.world_size
        sp_rank = self.rank
        if self.sp_size != -1:
            dp_rank = self.rank // self.sp_size
            sp_rank = self.rank % self.sp_size
            world_size = self.sp_size
            bs = len(input_ids)
            dp_size = self.world_size // self.sp_size
            group_bs = bs // dp_size
            input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs]
            attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs]
            labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs]
        batch = prepare_seq_parallel_sft_inputs(self.seq_algo,
                                                input_ids=input_ids,
                                                attention_mask=attention_mask,
                                                position_ids=None,
                                                labels=labels,
                                                rank=sp_rank,
                                                world_size=world_size,
                                                device=self.device)
        return batch


@dataclass
class SeqParallelDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    r"""
    Data collator for sequence parallelism in the pre-training (PT) stage.
    Reuses the sequence-parallel input preparation from the SFT stage.
    """
    seq_algo: str = "data_parallel"
    sp_size: int = -1
    rank: int = 0
    world_size: int = 8
    device: Optional[Any] = None

    def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().__call__(examples)
        if self.seq_algo == "data_parallel":
            return batch
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        world_size = self.world_size
        sp_rank = self.rank
        if self.sp_size != -1:
            dp_rank = self.rank // self.sp_size
            sp_rank = self.rank % self.sp_size
            world_size = self.sp_size
            bs = len(input_ids)
            dp_size = self.world_size // self.sp_size
            group_bs = bs // dp_size
            input_ids = input_ids[dp_rank * group_bs: (dp_rank + 1) * group_bs]
            attention_mask = attention_mask[dp_rank * group_bs: (dp_rank + 1) * group_bs]
            labels = labels[dp_rank * group_bs: (dp_rank + 1) * group_bs]
        batch = prepare_seq_parallel_sft_inputs(self.seq_algo,
                                                input_ids=input_ids,
                                                attention_mask=attention_mask,
                                                position_ids=None,
                                                labels=labels,
                                                rank=sp_rank,
                                                world_size=world_size,
                                                device=self.device)
        return batch
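
The two collators above share the same hybrid DP/SP bookkeeping: each global rank keeps one data-parallel slice of the batch (dp_rank) and one sequence-parallel shard of the sequence (sp_rank). A minimal, self-contained sketch of that rank arithmetic, with the 16-GPU figures chosen purely for illustration:

```python
# Illustration only: mirrors the rank bookkeeping inside SeqParallelDataCollator,
# it is not code from this PR.
def split_rank(rank: int, world_size: int, sp_size: int):
    """Map a global rank to (dp_rank, sp_rank, sp_world_size)."""
    if sp_size == -1:              # a single SP group spans every rank
        return 0, rank, world_size
    dp_rank = rank // sp_size      # which slice of the global batch this rank keeps
    sp_rank = rank % sp_size       # position inside its sequence-parallel group
    return dp_rank, sp_rank, sp_size

# 16 GPUs with SP groups of 8 gives 2 data-parallel replicas:
assert split_rank(rank=11, world_size=16, sp_size=8) == (1, 3, 8)
```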
17 changes: 15 additions & 2 deletions src/llamafactory/data/loader.py
@@ -143,6 +143,18 @@ def get_dataset(
        if has_tokenized_data(data_args.tokenized_path):
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.tokenized_path)
            # ---lsy---
            to_remove = [col for col in dataset.column_names if col != "input_ids"]
            # import copy
            # first_item = copy.deepcopy(dataset[0]['input_ids'])
            def update_column(example):
                example['input_ids'] = example['input_ids'][:data_args.cutoff_len]
                # example['input_ids'] = first_item[:data_args.cutoff_len]
                return example

            # Use the map method to add the new column
            dataset = dataset.map(update_column, remove_columns=to_remove)
            # ---lsy---
            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
            if data_args.streaming:
                dataset = dataset.to_iterable_dataset()
@@ -166,6 +178,7 @@ def get_dataset(
data_args, training_args, stage, template, tokenizer, processor
)
column_names = list(next(iter(dataset)).keys())
logger.debug(f"remove_columns:{column_names}")
kwargs = {}
if not data_args.streaming:
kwargs = dict(
@@ -175,9 +188,9 @@
)

dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

if data_args.tokenized_path is not None:
if training_args.should_save:
if training_args.should_save:
dataset.save_to_disk(data_args.tokenized_path)
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
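
The truncation added to `get_dataset` above is easy to exercise in isolation. Below is a hypothetical, self-contained reproduction on a toy in-memory dataset; the real code runs on the tokenized dataset loaded from `data_args.tokenized_path`, and `cutoff_len = 4` is an arbitrary stand-in for `data_args.cutoff_len`.

```python
# Toy reproduction of the cutoff_len truncation above; needs only the `datasets` package.
from datasets import Dataset

cutoff_len = 4  # stand-in for data_args.cutoff_len
dataset = Dataset.from_dict({"input_ids": [[1, 2, 3, 4, 5, 6]], "labels": [[1, 2, 3, 4, 5, 6]]})
to_remove = [col for col in dataset.column_names if col != "input_ids"]

def update_column(example):
    example["input_ids"] = example["input_ids"][:cutoff_len]
    return example

dataset = dataset.map(update_column, remove_columns=to_remove)
print(dataset[0])  # {'input_ids': [1, 2, 3, 4]}
```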
85 changes: 85 additions & 0 deletions src/llamafactory/easy_context/__init__.py
@@ -0,0 +1,85 @@
from .dist_flash_attn.prepare_input import prepare_dist_flash_attn_inputs, prepare_dist_flash_attn_sft_inputs
from .dist_flash_attn.monkey_patch import apply_dist_flash_attn_monkey_patch_llama
from .zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs, prepare_zigzag_ring_attn_sft_inputs
from .zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama
from .unsloth_offloaded_gradient_checkpoint.monkey_patch import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
from .ulysses_attn.prepare_inputs import prepare_ulysses_attn_inputs, prepare_ulysses_attn_sft_inputs
from .ulysses_attn.monkey_patch import apply_ulysses_attn_monkey_patch_llama
import torch
import torch.nn.functional as F

def prepare_seq_parallel_inputs(
    seq_algo, input_ids, position_ids, target_ids, rank, world_size, device
):
    if seq_algo == "zigzag_ring_attn":
        return prepare_zigzag_ring_attn_inputs(
            input_ids, position_ids, target_ids, rank, world_size, device
        )
    elif seq_algo == "dist_flash_attn":
        return prepare_dist_flash_attn_inputs(
            input_ids, position_ids, target_ids, rank, world_size, device
        )
    elif seq_algo == "ulysses_attn":
        return prepare_ulysses_attn_inputs(
            input_ids, position_ids, target_ids, rank, world_size, device
        )
    elif seq_algo == "data_parallel":
        return {
            "local_input_ids": input_ids.to(device),
            "local_position_ids": position_ids.to(device),
            "local_target_ids": target_ids.to(device),
        }
    else:
        raise ValueError(f"Invalid seq_algo: {seq_algo}")

def prepare_seq_parallel_sft_inputs(
    seq_algo, input_ids, attention_mask, position_ids, labels, rank, world_size, device
):
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    shift_labels = F.pad(labels, [0, 1], 'constant', -100)[:, 1:]
    if seq_algo == "zigzag_ring_attn":
        return prepare_zigzag_ring_attn_sft_inputs(
            input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device
        )
    elif seq_algo == "dist_flash_attn":
        return prepare_dist_flash_attn_sft_inputs(
            input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device
        )
    elif seq_algo == "ulysses_attn":
        return prepare_ulysses_attn_sft_inputs(
            input_ids, attention_mask, position_ids, shift_labels, rank, world_size, device
        )
    elif seq_algo == "data_parallel":
        return {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "target_ids": labels,
        }
    else:
        raise ValueError(f"Invalid seq_algo: {seq_algo}")

def apply_seq_parallel_monkey_patch(
    seq_algo, model, sp_size=None, enable_offload=False, offload_percent=0.
):
    assert seq_algo in ["zigzag_ring_attn", "dist_flash_attn", "ulysses_attn", "data_parallel"], f"Invalid seq_algo: {seq_algo}"
    assert model in ["llama", "mistral"], f"Invalid model: {model}"
    if seq_algo == "data_parallel":
        return
    elif seq_algo == "zigzag_ring_attn" and model == "llama":
        apply_zigzag_ring_attn_monkey_patch_llama(sp_size=sp_size)
    elif seq_algo == "dist_flash_attn" and model == "llama":
        apply_dist_flash_attn_monkey_patch_llama(sp_size=sp_size, enable_offload=enable_offload, offload_percent=offload_percent)
    elif seq_algo == "ulysses_attn" and model == "llama":
        apply_ulysses_attn_monkey_patch_llama(sp_size=sp_size)
    else:
        raise ValueError(f"Invalid seq_algo: {seq_algo} or model: {model}")

def prepare_dataloader(seq_algo, dataloader, accelerator):
    if seq_algo == "data_parallel":
        return accelerator.prepare(dataloader)
    else:
        return dataloader
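
One detail worth noting in `prepare_seq_parallel_sft_inputs` is the label shift: labels are shifted left by one and right-padded with -100 before the sequence is sharded, so every rank can compute a next-token loss on its own shard without asking its neighbour for the first token of the next shard. A stand-alone illustration (not part of the PR):

```python
# Reproduces only the shift_labels line above; torch is the only dependency.
import torch
import torch.nn.functional as F

labels = torch.tensor([[10, 11, 12, -100]])   # -100 marks ignored positions
shift_labels = F.pad(labels, [0, 1], "constant", -100)[:, 1:]
print(shift_labels)                           # tensor([[  11,   12, -100, -100]])
```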
11 changes: 11 additions & 0 deletions src/llamafactory/easy_context/dist_flash_attn/README.md
@@ -0,0 +1,11 @@
# LightSeq
Taken from https://github.com/RulinShao/LightSeq. All credits to the authors.

```
@article{li2023lightseq,
  title={LightSeq: Sequence Level Parallelism for Distributed Training of Long Context Transformers},
  author={Li, Dacheng and Shao, Rulin and Xie, Anze and Xing, Eric P and Gonzalez, Joseph E and Stoica, Ion and Ma, Xuezhe and Zhang, Hao},
  journal={arXiv preprint arXiv:2310.03294},
  year={2023}
}
```