From cdd3af25c56d50582b15b9779cf352711ab798f5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 26 Aug 2024 13:14:59 -0400 Subject: [PATCH 01/12] jamba liger fused linear+xentropy --- src/liger_kernel/transformers/model/jamba.py | 168 +++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 src/liger_kernel/transformers/model/jamba.py diff --git a/src/liger_kernel/transformers/model/jamba.py b/src/liger_kernel/transformers/model/jamba.py new file mode 100644 index 00000000..a10b4b43 --- /dev/null +++ b/src/liger_kernel/transformers/model/jamba.py @@ -0,0 +1,168 @@ +from typing import Optional, Tuple, Union + +import torch +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import MoeCausalLMOutputWithPast +from transformers.models.jamba.modeling_jamba import ( + _CONFIG_FOR_DOC, + JAMBA_INPUTS_DOCSTRING, + HybridMambaAttentionDynamicCache, + load_balancing_loss_func, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + + +@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, +) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, JambaForCausalLM + + >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
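+
+    >>> # Illustration only (not from the upstream Jamba docstring): with
+    >>> # `num_logits_to_keep=1`, logits are computed for the last position only.
+    >>> model(**inputs, num_logits_to_keep=1).logits.shape  # (batch, 1, vocab_size)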
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + if num_logits_to_keep is None: + logits = self.lm_head(hidden_states) + else: + logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :]) + logits = logits.float() + + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) From d05f82b020d93fdd48ffbe7b004a24167ce57eb8 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Wed, 28 Aug 2024 08:23:04 +0000 Subject: [PATCH 02/12] fix apply jamba function add test logits --- src/liger_kernel/transformers/__init__.py | 1 + src/liger_kernel/transformers/monkey_patch.py | 38 +++++++++++++++++++ test/convergence/test_mini_models.py | 33 +++++++++++++++- .../convergence/test_mini_models_no_logits.py | 29 +++++++++++++- 4 files changed, 99 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 5c44a154..f5a20110 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -15,6 +15,7 @@ 
apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_jamba, ) from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 6ecb670e..775618b8 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -10,6 +10,7 @@ from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward +from liger_kernel.transformers.model.jamba import lce_forward as jamba_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import ( @@ -283,6 +284,42 @@ def apply_liger_kernel_to_phi3( modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward +def apply_liger_kernel_to_jamba( + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Jamba models + to make GPU go burrr. + + # Note: Jamba model does not use rotary position embedding(RoPE). + + Args: + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused lienar cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. + """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
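+
+    # The assignments below rebind module-level names in modeling_jamba: the RMSNorm
+    # and MLP swaps only affect Jamba models constructed after this call, while
+    # assigning JambaForCausalLM.forward patches the class in place so that, with
+    # fused_linear_cross_entropy enabled, training-mode forward passes take the fused
+    # linear + cross-entropy path and never materialize the full logits tensor.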
+ + from transformers.models.jamba import modeling_jamba + if rms_norm: + modeling_jamba.JambaRMSNorm = LigerRMSNorm + if cross_entropy: + modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss + if swiglu: + modeling_jamba.JambaMLP = LigerSwiGLUMLP + if fused_linear_cross_entropy: + modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward + + # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py MODEL_TYPE_TO_APPLY_LIGER_FN = { "gemma": apply_liger_kernel_to_gemma, @@ -292,6 +329,7 @@ def apply_liger_kernel_to_phi3( "mixtral": apply_liger_kernel_to_mixtral, "qwen2": apply_liger_kernel_to_qwen2, "phi3": apply_liger_kernel_to_phi3, + "jamba": apply_liger_kernel_to_jamba, } diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 95c832e1..afaa6b34 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -20,10 +20,12 @@ from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM +from transformers.models.jamba import JambaConfig, JambaForCausalLM from liger_kernel.transformers import ( apply_liger_kernel_to_gemma, apply_liger_kernel_to_gemma2, + apply_liger_kernel_to_jamba, apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, @@ -279,6 +281,31 @@ attn_implementation="eager", ), ), + "mini_jamba": MiniModelConfig( + liger_kernel_patch_func=functools.partial( + apply_liger_kernel_to_jamba, fused_linear_cross_entropy=False + ), + model_class=JambaForCausalLM, + mini_model_config=JambaConfig( + attention_dropout=0.0, + num_experts_per_tok=1, + num_experts=2, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=1024, # 3072 + initializer_range=0.02, + intermediate_size=2048, # 8192 + max_position_embeddings=32768, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + rms_norm_eps=1e-5, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + ), + ), } @@ -316,8 +343,12 @@ def run_mini_model( kwargs["geglu"] = True else: kwargs["swiglu"] = True + if model_name == "mini_jamba": + del kwargs["rope"] MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) + from torch.profiler import profile, record_function, ProfilerActivity + model = create_model(model_name).to(dtype).to("cuda") train_dataset = load_from_disk(DEFAULT_DATASET_PATH) @@ -328,7 +359,6 @@ def run_mini_model( optimizer = torch.optim.AdamW(model.parameters(), lr=lr) loss_list = [] - for i in range(num_steps): batch = next(loader_iter).to(model.device) output = model(**batch) @@ -460,6 +490,7 @@ def run_mini_model( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), + ("mini_jamba", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), ], ) def test_mini_model( diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py index 67ce443c..584a6729 100644 --- a/test/convergence/test_mini_models_no_logits.py +++ b/test/convergence/test_mini_models_no_logits.py @@ -18,6 +18,7 @@ from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM +from transformers import JambaConfig, JambaForCausalLM from liger_kernel.transformers import ( 
apply_liger_kernel_to_gemma, @@ -27,6 +28,7 @@ apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_jamba, ) MINI_MODEL_SETUPS = { @@ -247,6 +249,29 @@ attention_dropout=0.0, ), ), + "mini_jamba": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_jamba, + model_class=JambaForCausalLM, + mini_model_config=JambaConfig( + attention_dropout=0.0, + num_experts_per_tok=1, + num_experts=2, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=1024, # 3072 + initializer_range=0.02, + intermediate_size=2048, # 8192 + max_position_embeddings=32768, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + rms_norm_eps=1e-5, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + ), + ), } @@ -283,7 +308,8 @@ def run_mini_model( kwargs["geglu"] = True else: kwargs["swiglu"] = True - + if model_name == "mini_jamba": + del kwargs["rope"] model_support_flce = "gemma2" not in model_name if model_support_flce: kwargs["fused_linear_cross_entropy"] = True @@ -446,6 +472,7 @@ def run_mini_model( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), + ("mini_jamba", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), ], ) def test_mini_model( From 99d3553aab198dd16d250e0ba3aefedc780a0d96 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Thu, 5 Sep 2024 06:59:57 +0000 Subject: [PATCH 03/12] lint --- src/liger_kernel/transformers/__init__.py | 2 +- src/liger_kernel/transformers/model/jamba.py | 35 ++++++++++--------- src/liger_kernel/transformers/monkey_patch.py | 15 ++++---- test/convergence/test_mini_models.py | 4 +-- .../convergence/test_mini_models_no_logits.py | 4 +-- 5 files changed, 31 insertions(+), 29 deletions(-) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index f5a20110..98435a27 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -10,12 +10,12 @@ from liger_kernel.transformers.monkey_patch import ( # noqa: F401 apply_liger_kernel_to_gemma, apply_liger_kernel_to_gemma2, + apply_liger_kernel_to_jamba, apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_jamba, ) from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 diff --git a/src/liger_kernel/transformers/model/jamba.py b/src/liger_kernel/transformers/model/jamba.py index a10b4b43..8bef7304 100644 --- a/src/liger_kernel/transformers/model/jamba.py +++ b/src/liger_kernel/transformers/model/jamba.py @@ -1,9 +1,6 @@ from typing import Optional, Tuple, Union import torch -from liger_kernel.transformers.fused_linear_cross_entropy import ( - LigerFusedLinearCrossEntropyLoss, -) from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.models.jamba.modeling_jamba import ( @@ -17,26 +14,30 @@ replace_return_docstrings, ) +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) + @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) def lce_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: 
Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: Optional[Union[int, None]] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 775618b8..594a118f 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -5,12 +5,12 @@ from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward +from liger_kernel.transformers.model.jamba import lce_forward as jamba_lce_forward from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward -from liger_kernel.transformers.model.jamba import lce_forward as jamba_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import ( @@ -285,15 +285,15 @@ def apply_liger_kernel_to_phi3( def apply_liger_kernel_to_jamba( - cross_entropy: bool = False, - fused_linear_cross_entropy: bool = True, - rms_norm: bool = True, - swiglu: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, ) -> None: """ Apply Liger kernels to replace original implementation in HuggingFace Jamba models to make GPU go burrr. - + # Note: Jamba model does not use rotary position embedding(RoPE). Args: @@ -306,10 +306,11 @@ def apply_liger_kernel_to_jamba( geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True. """ assert not ( - cross_entropy and fused_linear_cross_entropy + cross_entropy and fused_linear_cross_entropy ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 
from transformers.models.jamba import modeling_jamba + if rms_norm: modeling_jamba.JambaRMSNorm = LigerRMSNorm if cross_entropy: diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index afaa6b34..70d54243 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -15,12 +15,12 @@ from torch.utils.data import DataLoader from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM +from transformers.models.jamba import JambaConfig, JambaForCausalLM from transformers.models.llama import LlamaConfig, LlamaForCausalLM from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM -from transformers.models.jamba import JambaConfig, JambaForCausalLM from liger_kernel.transformers import ( apply_liger_kernel_to_gemma, @@ -347,7 +347,7 @@ def run_mini_model( del kwargs["rope"] MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) - from torch.profiler import profile, record_function, ProfilerActivity + from torch.profiler import ProfilerActivity, profile, record_function model = create_model(model_name).to(dtype).to("cuda") train_dataset = load_from_disk(DEFAULT_DATASET_PATH) diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py index 584a6729..9247f0b3 100644 --- a/test/convergence/test_mini_models_no_logits.py +++ b/test/convergence/test_mini_models_no_logits.py @@ -11,6 +11,7 @@ import torch from datasets import load_from_disk from torch.utils.data import DataLoader +from transformers import JambaConfig, JambaForCausalLM from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM from transformers.models.llama import LlamaConfig, LlamaForCausalLM @@ -18,17 +19,16 @@ from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM -from transformers import JambaConfig, JambaForCausalLM from liger_kernel.transformers import ( apply_liger_kernel_to_gemma, apply_liger_kernel_to_gemma2, + apply_liger_kernel_to_jamba, apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, apply_liger_kernel_to_phi3, apply_liger_kernel_to_qwen2, - apply_liger_kernel_to_jamba, ) MINI_MODEL_SETUPS = { From 7602cf74f13bb5a7a4565983a5677f70e9be8d78 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Thu, 5 Sep 2024 16:39:03 +0000 Subject: [PATCH 04/12] remove profiler --- test/convergence/test_mini_models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 70d54243..359b5467 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -347,8 +347,6 @@ def run_mini_model( del kwargs["rope"] MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) - from torch.profiler import ProfilerActivity, profile, record_function - model = create_model(model_name).to(dtype).to("cuda") train_dataset = load_from_disk(DEFAULT_DATASET_PATH) From 204742f35d916bfd49d848accc0967307133f594 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Sep 
2024 06:57:28 +0000 Subject: [PATCH 05/12] add jamba required deps into dev --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 0cba2e2a..67fce122 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ dev = [ "pytest>=7.1.2", "datasets>=2.19.2", "jupyter==1.0.0", + "causal-conv1d>=1.4.0", + "mamba-ssm>=2.2.2", "seaborn", ] From 2cef594d96be48849cb3c6ee9c982f4bfb170e98 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Sep 2024 07:39:05 +0000 Subject: [PATCH 06/12] bump setuptool version for causal-conv1d installation --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 67fce122..1336cfcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=42", "wheel"] +requires = ["setuptools>=69.5.1", "wheel"] build-backend = "setuptools.build_meta" [project] From 65597cf5c749b30e64740733535bc87595e905b2 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Fri, 6 Sep 2024 07:58:41 +0000 Subject: [PATCH 07/12] change ci yaml to install torch beforehead --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f41afdb6..e00ec119 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 isort black + pip install setuptools==69.5.1 flake8 isort black torch - name: Run checkstyle run: make checkstyle \ No newline at end of file From cfd1eda4b493aa13330a865b6699d476582b4715 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Sat, 7 Sep 2024 06:38:47 +0000 Subject: [PATCH 08/12] split conv1d and mamba to separate dependencies --- .github/workflows/ci.yml | 2 +- pyproject.toml | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e00ec119..f41afdb6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install setuptools==69.5.1 flake8 isort black torch + pip install flake8 isort black - name: Run checkstyle run: make checkstyle \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 01a2226b..8c5f76fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=69.5.1", "wheel"] +requires = ["setuptools>=61.0", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -29,6 +29,11 @@ dev = [ "seaborn", ] +test = [ + "causal-conv1d>=1.4.0", + "mamba-ssm>=2.2.2", +] + [tool.setuptools.packages.find] where = ["src"] include = ["liger_kernel"] From e46fb171d20b14660d6112b13c7749dcd22ef780 Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Sat, 7 Sep 2024 07:03:39 +0000 Subject: [PATCH 09/12] fix contributing guide --- CONTRIBUTING.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2ea925a5..23b5d0c7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,6 +40,13 @@ Add a benchmarking script under `benchmark/scripts` using the naming convention ## Run tests +### Install test dependencies +To run convergence test, you will need to install mamba dependencies separately after installing dev dependencies. 
+This is due to a ongoing build [issue](https://github.com/state-spaces/mamba/issues/481) with mamba-ssm and causal-conv1d. +``` +pip install .'[test]' +``` + ### Use Makefile to run full tests 1. Run `make test` to ensure correctness. 2. Run `make checkstyle` to ensure code style. From c64791f7fdb93048e16fbf008b1d03b2f78eea3c Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Sat, 7 Sep 2024 07:04:21 +0000 Subject: [PATCH 10/12] remove old deps --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c5f76fb..4664606f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,6 @@ dev = [ "pytest>=7.1.2", "datasets>=2.19.2", "jupyter==1.0.0", - "causal-conv1d>=1.4.0", - "mamba-ssm>=2.2.2", "seaborn", ] From cf61cb2f670cf02694049567d28603d1e579446f Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Sat, 7 Sep 2024 07:16:04 +0000 Subject: [PATCH 11/12] remove jamba from test --- CONTRIBUTING.md | 5 ----- test/convergence/test_mini_models.py | 3 ++- test/convergence/test_mini_models_no_logits.py | 3 ++- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 23b5d0c7..354d8340 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -41,11 +41,6 @@ Add a benchmarking script under `benchmark/scripts` using the naming convention ## Run tests ### Install test dependencies -To run convergence test, you will need to install mamba dependencies separately after installing dev dependencies. -This is due to a ongoing build [issue](https://github.com/state-spaces/mamba/issues/481) with mamba-ssm and causal-conv1d. -``` -pip install .'[test]' -``` ### Use Makefile to run full tests 1. Run `make test` to ensure correctness. diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 359b5467..4727d3a7 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -488,7 +488,8 @@ def run_mini_model( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - ("mini_jamba", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # To run this test, you need to first run `pip install . '[test]'` + # ("mini_jamba", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), ], ) def test_mini_model( diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py index 9247f0b3..819d0252 100644 --- a/test/convergence/test_mini_models_no_logits.py +++ b/test/convergence/test_mini_models_no_logits.py @@ -472,7 +472,8 @@ def run_mini_model( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), ), - ("mini_jamba", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), + # To run this test, you need to first run `pip install . '[test]'` + # ("mini_jamba", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5), ], ) def test_mini_model( From 68494dfeb0aab25df773153359be1f2cf6f7bb2b Mon Sep 17 00:00:00 2001 From: Yubo Wang Date: Sat, 7 Sep 2024 07:18:13 +0000 Subject: [PATCH 12/12] fix contribute guide --- CONTRIBUTING.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 354d8340..2ea925a5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,8 +40,6 @@ Add a benchmarking script under `benchmark/scripts` using the naming convention ## Run tests -### Install test dependencies - ### Use Makefile to run full tests 1. Run `make test` to ensure correctness. 2. Run `make checkstyle` to ensure code style.
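
Taken together, the series patches Jamba's RMSNorm, SwiGLU MLP, and causal-LM forward and wires the model into the convergence tests. Below is a minimal usage sketch, not taken from the patches themselves: the config values mirror the `mini_jamba` test config, while the device, batch shape, and randomly initialized weights are illustrative assumptions (the `causal-conv1d`/`mamba-ssm` extras from the `[test]` group should only be needed for the fast Mamba path).

```python
import torch
from transformers import JambaConfig, JambaForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_jamba

# Patch RMSNorm, the SwiGLU MLP, and JambaForCausalLM.forward before building the model.
apply_liger_kernel_to_jamba()

config = JambaConfig(
    hidden_size=1024,
    intermediate_size=2048,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_experts=2,
    num_experts_per_tok=1,
    vocab_size=32064,
)
model = JambaForCausalLM(config).to("cuda")  # training mode by default

input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")
# In training mode the patched forward uses LigerFusedLinearCrossEntropyLoss, so the
# full (batch, seq_len, vocab_size) logits tensor is never materialized.
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
```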