diff --git a/cpp/metadata/model.cc b/cpp/metadata/model.cc index 66b71a8efe..ffa366f8e0 100644 --- a/cpp/metadata/model.cc +++ b/cpp/metadata/model.cc @@ -46,6 +46,7 @@ ModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON( kv_cache_metadata.head_dim = json::Lookup(json, "head_dim"); kv_cache_metadata.num_attention_heads = json::Lookup(json, "num_attention_heads"); kv_cache_metadata.num_key_value_heads = json::Lookup(json, "num_key_value_heads"); + kv_cache_metadata.kv_nbits = json::Lookup(json, "kv_nbits"); return kv_cache_metadata; } @@ -73,7 +74,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata, result.kv_cache_metadata = {/*num_hidden_layers=*/0, /*head_dim=*/0, /*num_attention_heads=*/0, - /*num_key_value_heads=*/0}; + /*num_key_value_heads=*/0, + /*kv_nbits=*/0}; } { std::vector& params = result.params; diff --git a/cpp/metadata/model.h b/cpp/metadata/model.h index e677918e21..dd3173f90e 100644 --- a/cpp/metadata/model.h +++ b/cpp/metadata/model.h @@ -67,6 +67,7 @@ struct ModelMetadata { int64_t num_attention_heads; int64_t num_key_value_heads; int64_t head_dim; + int64_t kv_nbits; static KVCacheMetadata FromJSON(const picojson::object& json); }; diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 4df2bd8af9..b1e52e7006 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -567,8 +567,10 @@ Result EstimateMemoryUsageOnMode( int64_t head_dim = model_metadata[i].kv_cache_metadata.head_dim; int64_t num_qo_heads = model_metadata[i].kv_cache_metadata.num_attention_heads; int64_t num_kv_heads = model_metadata[i].kv_cache_metadata.num_key_value_heads; + int64_t kv_nbits = model_metadata[i].kv_cache_metadata.kv_nbits; int64_t hidden_size = head_dim * num_qo_heads; - kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25; + double kv_nbytes = kv_nbits / 8.0f; + kv_bytes_per_token += head_dim * num_kv_heads * num_layers * kv_nbytes * 2 + 1.25; kv_aux_workspace_bytes += (max_num_sequence + 1) * 88 + prefill_chunk_size * (num_qo_heads + 1) * 8 + prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + 48 * 1024 * 1024; diff --git a/python/mlc_llm/bench/metrics.py b/python/mlc_llm/bench/metrics.py index ab414c2ad9..ff62253e72 100644 --- a/python/mlc_llm/bench/metrics.py +++ b/python/mlc_llm/bench/metrics.py @@ -1,4 +1,5 @@ """ MLC LLM bench Metrics""" + import json from typing import Any, Callable, Dict, List, Optional, Union diff --git a/python/mlc_llm/bench/replay.py b/python/mlc_llm/bench/replay.py index 65fb325c34..fcf6a747d5 100644 --- a/python/mlc_llm/bench/replay.py +++ b/python/mlc_llm/bench/replay.py @@ -1,4 +1,5 @@ """MLC LLM bench replay request""" + import asyncio import json from datetime import datetime diff --git a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py index 20e4c7bdd9..69611f13f4 100644 --- a/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py +++ b/python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py @@ -6,6 +6,7 @@ from tvm import IRModule, relax from mlc_llm.nn import RopeMode, kv_cache +from mlc_llm.quantization import PagedKVCacheQuantization, get_kv_storage_dtype def extract_creation_args(func: relax.Function) -> Dict[str, Any]: @@ -20,13 +21,13 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: assert isinstance(args[0], relax.ExternFunc) assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic" - assert len(args) == 11 + assert len(args) == 12 assert 
isinstance(args[1], relax.ShapeExpr) assert len(args[1].values) == 5 - for i in range(2, 10): + for i in range(2, 11): assert isinstance(args[i], relax.PrimValue) assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm)) - assert isinstance(args[10], relax.DataTypeImm) + assert isinstance(args[11], relax.DataTypeImm) return { "max_batch_size": args[1].values[0], @@ -42,7 +43,8 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]: "rope_scale": args[7].value.value, "rope_theta": args[8].value.value, "rotary_dim": args[9].value.value, - "dtype": args[10].value, + "kv_quantization": args[10].value.value, + "dtype": args[11].value, } @@ -105,6 +107,10 @@ def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]): "num_attention_heads": kwargs["num_attention_heads"], "num_key_value_heads": kwargs["num_key_value_heads"], "head_dim": kwargs["head_dim"], + "kv_nbits": get_kv_storage_dtype( + kv_quant_scheme=PagedKVCacheQuantization(kwargs["kv_quantization"]).name.lower(), + model_dtype=kwargs["dtype"], + ).bits, } def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None: @@ -156,6 +162,8 @@ def create_flashinfer_paged_kv_cache( ) # filter by attention group size or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8] + # MLC-LLM with FlashInfer and KV Quantization not supported yet + or kwargs["kv_quantization"] != PagedKVCacheQuantization.KV_NO_QUANT ): return diff --git a/python/mlc_llm/model/baichuan/baichuan_model.py b/python/mlc_llm/model/baichuan/baichuan_model.py index bce68b830a..072d5c349c 100644 --- a/python/mlc_llm/model/baichuan/baichuan_model.py +++ b/python/mlc_llm/model/baichuan/baichuan_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -41,6 +42,7 @@ class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attribute tensor_parallel_shards: int = 1 max_batch_size: int = 1 head_dim: int = 0 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -203,6 +205,7 @@ def __init__(self, config: BaichuanConfig): self.vocab_size = config.vocab_size self.rope_theta = 10000 self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -291,6 +294,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/baichuan/baichuan_quantization.py b/python/mlc_llm/model/baichuan/baichuan_quantization.py index 2bad7e3349..86fbb78fc3 100644 --- a/python/mlc_llm/model/baichuan/baichuan_quantization.py +++ b/python/mlc_llm/model/baichuan/baichuan_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a BaichuanLM-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = BaichuanForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git 
a/python/mlc_llm/model/chatglm3/chatglm3_model.py b/python/mlc_llm/model/chatglm3/chatglm3_model.py index fa4b24e87a..c538ce243d 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_model.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -43,6 +44,7 @@ class GLMConfig(ConfigBase): # pylint: disable=too-many-instance-attributes tensor_parallel_shards: int = 1 head_dim: int = 0 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -279,6 +281,7 @@ def __init__(self, config: GLMConfig): self.vocab_size = config.vocab_size self.rope_theta = 10000 self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -367,6 +370,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/chatglm3/chatglm3_quantization.py b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py index 172188a557..833d4b70de 100644 --- a/python/mlc_llm/model/chatglm3/chatglm3_quantization.py +++ b/python/mlc_llm/model/chatglm3/chatglm3_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a ChatGLM-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = ChatGLMForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/eagle/eagle_model.py b/python/mlc_llm/model/eagle/eagle_model.py index 9d7820b841..e2e33fb446 100644 --- a/python/mlc_llm/model/eagle/eagle_model.py +++ b/python/mlc_llm/model/eagle/eagle_model.py @@ -90,6 +90,7 @@ def __init__(self, config: EagleConfig): self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def fuse_embed_hidden_states(self, input_embed: Tensor, hidden_states: Tensor): @@ -177,6 +178,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/eagle/eagle_quantization.py b/python/mlc_llm/model/eagle/eagle_quantization.py index 4510a17d2c..03bb016846 100644 --- a/python/mlc_llm/model/eagle/eagle_quantization.py +++ b/python/mlc_llm/model/eagle/eagle_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Eagle-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = EagleForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py index 
b3ee189a51..3e6ba5530f 100644 --- a/python/mlc_llm/model/gemma/gemma_model.py +++ b/python/mlc_llm/model/gemma/gemma_model.py @@ -9,6 +9,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -36,6 +37,7 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -228,6 +230,7 @@ def __init__(self, config: GemmaConfig): self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -316,6 +319,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/gemma/gemma_quantization.py b/python/mlc_llm/model/gemma/gemma_quantization.py index 48a5bbfedc..eae539e355 100644 --- a/python/mlc_llm/model/gemma/gemma_quantization.py +++ b/python/mlc_llm/model/gemma/gemma_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Gemma-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = GemmaForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/gpt2/gpt2_model.py b/python/mlc_llm/model/gpt2/gpt2_model.py index d24b73955b..8d05fd7fc3 100644 --- a/python/mlc_llm/model/gpt2/gpt2_model.py +++ b/python/mlc_llm/model/gpt2/gpt2_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -36,6 +37,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes tensor_parallel_shards: int = 1 head_dim: int = 0 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -223,6 +225,7 @@ def __init__(self, config: GPT2Config): self.n_head = config.n_head self.head_dim = config.head_dim self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -311,6 +314,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NONE, rope_scale=-1, rope_theta=-1, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/gpt2/gpt2_quantization.py b/python/mlc_llm/model/gpt2/gpt2_quantization.py index 8b722f4b06..b825136fe9 100644 --- a/python/mlc_llm/model/gpt2/gpt2_quantization.py +++ b/python/mlc_llm/model/gpt2/gpt2_quantization.py @@ -16,6 +16,7 @@ def group_quant( 
quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a GPT-2-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = GPT2LMHeadModel(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py index fd84601112..9359a1c2b8 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -35,6 +36,7 @@ class GPTBigCodeConfig(ConfigBase): # pylint: disable=too-many-instance-attribu prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -190,6 +192,7 @@ def __init__(self, config: GPTBigCodeConfig): self.num_kv_heads = 1 self.head_dim = config.n_embd // config.n_head self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -278,6 +281,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NONE, rope_scale=-1, rope_theta=-1, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py index f6f1ff3cda..2e4b4ecf63 100644 --- a/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py +++ b/python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a GPTBigCode-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = GPTBigCodeForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py index c7832ea68e..88d41dc8db 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_model.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_model.py @@ -13,6 +13,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -39,6 +40,7 @@ class GPTNeoXConfig(ConfigBase): # pylint: disable=too-many-instance-attributes tensor_parallel_shards: int = 1 ffn_out_dtype: str = "float32" max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -253,6 +255,7 @@ def __init__(self, config: GPTNeoXConfig): self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + 
self.kv_quantization = config.kv_quantization self.dtype = "float32" self.rotary_pct = config.rotary_pct @@ -342,6 +345,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, rotary_dim=int(self.head_dim * self.rotary_pct), ) diff --git a/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py b/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py index 61dbe6d6ae..f85fad3e46 100644 --- a/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py +++ b/python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a GPTNeoX-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = GPTNeoXForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/internlm/internlm_model.py b/python/mlc_llm/model/internlm/internlm_model.py index 4c7793ca2a..ac6f931f2c 100644 --- a/python/mlc_llm/model/internlm/internlm_model.py +++ b/python/mlc_llm/model/internlm/internlm_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -40,6 +41,7 @@ class InternLMConfig(ConfigBase): # pylint: disable=too-many-instance-attribute tensor_parallel_shards: int = 1 max_batch_size: int = 1 head_dim: int = 0 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -214,6 +216,7 @@ def __init__(self, config: InternLMConfig): self.vocab_size = config.vocab_size self.rope_theta = 10000 self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -302,6 +305,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/internlm/internlm_quantization.py b/python/mlc_llm/model/internlm/internlm_quantization.py index de302686ca..d5c4ec0d24 100644 --- a/python/mlc_llm/model/internlm/internlm_quantization.py +++ b/python/mlc_llm/model/internlm/internlm_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a InternLM-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = InternLMForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/internlm2/internlm2_model.py b/python/mlc_llm/model/internlm2/internlm2_model.py index 75af3b86a8..3eb0727a0d 100644 --- a/python/mlc_llm/model/internlm2/internlm2_model.py +++ b/python/mlc_llm/model/internlm2/internlm2_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as 
tp from mlc_llm.support.config import ConfigBase @@ -42,6 +43,7 @@ class InternLM2Config(ConfigBase): # pylint: disable=too-many-instance-attribut tensor_parallel_shards: int = 1 max_batch_size: int = 1 head_dim: int = 0 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -222,6 +224,7 @@ def __init__(self, config: InternLM2Config): self.head_dim = config.head_dim self.rope_theta = config.rope_theta self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) @@ -309,6 +312,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/internlm2/internlm2_quantization.py b/python/mlc_llm/model/internlm2/internlm2_quantization.py index 38d6bea342..f9bdcddbbe 100644 --- a/python/mlc_llm/model/internlm2/internlm2_quantization.py +++ b/python/mlc_llm/model/internlm2/internlm2_quantization.py @@ -1,5 +1,6 @@ """This file specifies how MLC's InternLM2 parameters are quantized using group quantization or other formats.""" + from typing import Tuple from tvm.relax.frontend import nn @@ -15,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a InternLM2-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = InternLM2ForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/llama/llama_model.py b/python/mlc_llm/model/llama/llama_model.py index 62c07ba324..f7bbe8778b 100644 --- a/python/mlc_llm/model/llama/llama_model.py +++ b/python/mlc_llm/model/llama/llama_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -37,6 +38,7 @@ class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes head_dim: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -209,6 +211,7 @@ def __init__(self, config: LlamaConfig): self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -344,6 +347,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/llama/llama_quantization.py b/python/mlc_llm/model/llama/llama_quantization.py index 26b6e0e728..18f8493c78 100644 --- a/python/mlc_llm/model/llama/llama_quantization.py +++ b/python/mlc_llm/model/llama/llama_quantization.py @@ -22,6 +22,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama-architecture 
model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = LlamaForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/llava/llava_model.py b/python/mlc_llm/model/llava/llava_model.py index ed2c585c59..e9ebdda68d 100644 --- a/python/mlc_llm/model/llava/llava_model.py +++ b/python/mlc_llm/model/llava/llava_model.py @@ -25,6 +25,7 @@ from mlc_llm import op as op_ext from mlc_llm.model.model_preset import MODEL_PRESETS from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from ...support.config import ConfigBase from ..llama.llama_model import LlamaConfig, LlamaForCasualLM @@ -72,6 +73,7 @@ class LlavaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes tensor_parallel_shards: int = 1 max_batch_size: int = 1 text_architecture: str = "LlamaForCausalLM" + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self) -> None: @@ -369,6 +371,7 @@ def __init__(self, config: LlavaConfig): self.multi_modal_projector = LlavaMultiModalProjector(config) self.language_model = ARCHITECTURE_MAP[config.text_architecture](config.text_config) self.vocab_size = config.vocab_size + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -450,6 +453,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.language_model.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/llava/llava_quantization.py b/python/mlc_llm/model/llava/llava_quantization.py index 79bd6ecdcb..a7bd1b525b 100644 --- a/python/mlc_llm/model/llava/llava_quantization.py +++ b/python/mlc_llm/model/llava/llava_quantization.py @@ -15,6 +15,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llava model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = LlavaForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/mistral/mistral_model.py b/python/mlc_llm/model/mistral/mistral_model.py index 8179b99552..c67bbdd33d 100644 --- a/python/mlc_llm/model/mistral/mistral_model.py +++ b/python/mlc_llm/model/mistral/mistral_model.py @@ -11,6 +11,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -38,6 +39,7 @@ class MistralConfig(ConfigBase): # pylint: disable=too-many-instance-attributes attention_sink_size: int = 4 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): # pylint: disable=too-many-branches @@ -228,6 +230,7 @@ def __init__(self, config: MistralConfig): self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards self.sliding_window_size = config.sliding_window_size + self.kv_quantization = config.kv_quantization self.dtype = "float32" def 
to(self, dtype: Optional[str] = None): @@ -316,6 +319,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/mistral/mistral_quantization.py b/python/mlc_llm/model/mistral/mistral_quantization.py index aac8bd0974..c38665ea6c 100644 --- a/python/mlc_llm/model/mistral/mistral_quantization.py +++ b/python/mlc_llm/model/mistral/mistral_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Mistral-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = MistralForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/mixtral/mixtral_quantization.py b/python/mlc_llm/model/mixtral/mixtral_quantization.py index eb4983738b..2c5f74024f 100644 --- a/python/mlc_llm/model/mixtral/mixtral_quantization.py +++ b/python/mlc_llm/model/mixtral/mixtral_quantization.py @@ -22,6 +22,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Mixtral-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = MixtralForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/orion/orion_model.py b/python/mlc_llm/model/orion/orion_model.py index 8ab70b8ba8..3bfbaea5a1 100644 --- a/python/mlc_llm/model/orion/orion_model.py +++ b/python/mlc_llm/model/orion/orion_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -37,6 +38,7 @@ class OrionConfig(ConfigBase): # pylint: disable=too-many-instance-attributes head_dim: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -210,6 +212,7 @@ def __init__(self, config: OrionConfig): self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -298,6 +301,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/orion/orion_quantization.py b/python/mlc_llm/model/orion/orion_quantization.py index eba7976fab..7d40b0599f 100644 --- a/python/mlc_llm/model/orion/orion_quantization.py +++ b/python/mlc_llm/model/orion/orion_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Orion-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = OrionForCasualLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git 
a/python/mlc_llm/model/phi/phi_model.py b/python/mlc_llm/model/phi/phi_model.py index c012736b61..6af5097711 100644 --- a/python/mlc_llm/model/phi/phi_model.py +++ b/python/mlc_llm/model/phi/phi_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -38,6 +39,7 @@ class Phi1Config(ConfigBase): # pylint: disable=too-many-instance-attributes head_dim: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -107,6 +109,7 @@ class PhiConfig(ConfigBase): # pylint: disable=too-many-instance-attributes n_head_kv: int = 0 head_dim: int = 0 tensor_parallel_shards: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -166,6 +169,7 @@ def from_phi1(config: Phi1Config) -> "PhiConfig": n_head_kv=config.num_key_value_heads, head_dim=config.head_dim, tensor_parallel_shards=config.tensor_parallel_shards, + kv_quantization=config.kv_quantization, kwargs=config.kwargs, ) @@ -328,6 +332,7 @@ def __init__(self, config: Union[PhiConfig, Phi1Config]) -> None: self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.rotary_dim = config.rotary_dim self.dtype = "float32" @@ -420,6 +425,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, rotary_dim=self.rotary_dim, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/phi/phi_quantization.py b/python/mlc_llm/model/phi/phi_quantization.py index 854b3e6547..9da5ef786a 100644 --- a/python/mlc_llm/model/phi/phi_quantization.py +++ b/python/mlc_llm/model/phi/phi_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Phi-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = PhiForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/phi3/phi3_model.py b/python/mlc_llm/model/phi3/phi3_model.py index 0bd293e715..17dcf25dce 100644 --- a/python/mlc_llm/model/phi3/phi3_model.py +++ b/python/mlc_llm/model/phi3/phi3_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -38,6 +39,7 @@ class Phi3Config(ConfigBase): # pylint: disable=too-many-instance-attributes head_dim: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -217,6 +219,7 @@ def __init__(self, config: Phi3Config) -> None: self.vocab_size = 
config.vocab_size self.rope_theta = config.position_embedding_base self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -308,6 +311,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/phi3/phi3_quantization.py b/python/mlc_llm/model/phi3/phi3_quantization.py index c0e9fced7d..74a1b66df8 100644 --- a/python/mlc_llm/model/phi3/phi3_quantization.py +++ b/python/mlc_llm/model/phi3/phi3_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Phi-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = Phi3ForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/qwen/qwen_model.py b/python/mlc_llm/model/qwen/qwen_model.py index 7fb7e0eb82..ceb81b7c8d 100644 --- a/python/mlc_llm/model/qwen/qwen_model.py +++ b/python/mlc_llm/model/qwen/qwen_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -38,6 +39,7 @@ class QWenConfig(ConfigBase): # pylint: disable=too-many-instance-attributes tensor_parallel_shards: int = 1 max_batch_size: int = 1 head_dim: int = 0 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -208,6 +210,7 @@ def __init__(self, config: QWenConfig): self.head_dim = config.head_dim self.tensor_parallel_shards = config.tensor_parallel_shards self.rotary_emb_base = config.rotary_emb_base + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -297,6 +300,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rotary_emb_base, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/qwen/qwen_quantization.py b/python/mlc_llm/model/qwen/qwen_quantization.py index 38959512d6..a0be3f00d7 100644 --- a/python/mlc_llm/model/qwen/qwen_quantization.py +++ b/python/mlc_llm/model/qwen/qwen_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a QWen-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = QWenLMHeadModel(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/qwen2/qwen2_model.py b/python/mlc_llm/model/qwen2/qwen2_model.py index 89ca027777..2e6a0b6307 100644 --- a/python/mlc_llm/model/qwen2/qwen2_model.py +++ b/python/mlc_llm/model/qwen2/qwen2_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import 
ConfigBase @@ -40,6 +41,7 @@ class QWen2Config(ConfigBase): # pylint: disable=too-many-instance-attributes head_dim: int = 0 dtype: str = "float32" max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -240,6 +242,7 @@ def __init__(self, config: QWen2Config): self.vocab_size = config.vocab_size self.tensor_parallel_shards = config.tensor_parallel_shards self.head_dim = config.head_dim + self.kv_quantization = config.kv_quantization def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) @@ -337,6 +340,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/qwen2/qwen2_quantization.py b/python/mlc_llm/model/qwen2/qwen2_quantization.py index 3a8546236c..170813bbef 100644 --- a/python/mlc_llm/model/qwen2/qwen2_quantization.py +++ b/python/mlc_llm/model/qwen2/qwen2_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a QWen-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = QWen2LMHeadModel(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py index 59b7ae8375..e5dde07369 100644 --- a/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py +++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py @@ -233,6 +233,7 @@ def __init__(self, config: Qwen2MoeConfig): self.vocab_size = config.vocab_size self.tensor_parallel_shards = config.tensor_parallel_shards self.head_dim = config.head_dim + self.kv_quantization = config.kv_quantization def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) @@ -320,6 +321,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py b/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py index e01289823e..aaa0208cdd 100644 --- a/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py +++ b/python/mlc_llm/model/qwen2_moe/qwen2_moe_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Qwen2MoE-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = Qwen2MoeForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/stable_lm/stablelm_model.py b/python/mlc_llm/model/stable_lm/stablelm_model.py index 4f874af633..02d51a2d55 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_model.py +++ b/python/mlc_llm/model/stable_lm/stablelm_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support import tensor_parallel as tp from mlc_llm.support.config import ConfigBase @@ -39,6 +40,7 @@ class StableLmConfig(ConfigBase): # pylint: disable=too-many-instance-attribute 
prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -218,6 +220,7 @@ def __init__(self, config: StableLmConfig): self.vocab_size = config.vocab_size self.rope_theta = config.rope_theta self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.partial_rotary_factor = config.partial_rotary_factor def to(self, dtype: Optional[str] = None): @@ -306,6 +309,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, rotary_dim=int(self.head_dim * self.partial_rotary_factor), ) diff --git a/python/mlc_llm/model/stable_lm/stablelm_quantization.py b/python/mlc_llm/model/stable_lm/stablelm_quantization.py index 620b769e05..c758d4a244 100644 --- a/python/mlc_llm/model/stable_lm/stablelm_quantization.py +++ b/python/mlc_llm/model/stable_lm/stablelm_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a StableLM-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = StableLmForCausalLM(model_config) model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/model/starcoder2/starcoder2_model.py b/python/mlc_llm/model/starcoder2/starcoder2_model.py index b7d5d942b2..afed8263da 100644 --- a/python/mlc_llm/model/starcoder2/starcoder2_model.py +++ b/python/mlc_llm/model/starcoder2/starcoder2_model.py @@ -12,6 +12,7 @@ from mlc_llm import op as op_ext from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization from mlc_llm.support import logging from mlc_llm.support.config import ConfigBase from mlc_llm.support.style import bold @@ -40,6 +41,7 @@ class Starcoder2Config(ConfigBase): # pylint: disable=too-many-instance-attribu prefill_chunk_size: int = 0 tensor_parallel_shards: int = 1 max_batch_size: int = 1 + kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): @@ -196,6 +198,7 @@ def __init__(self, config: Starcoder2Config): self.vocab_size = config.vocab_size self.rope_theta = config.rope_theta self.tensor_parallel_shards = config.tensor_parallel_shards + self.kv_quantization = config.kv_quantization self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -282,6 +285,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments rope_mode=RopeMode.NORMAL, rope_scale=1, rope_theta=self.rope_theta, + kv_quantization=self.kv_quantization, dtype=self.dtype, ) diff --git a/python/mlc_llm/model/starcoder2/starcoder2_quantization.py b/python/mlc_llm/model/starcoder2/starcoder2_quantization.py index c6ca093cdb..e2e1065d66 100644 --- a/python/mlc_llm/model/starcoder2/starcoder2_quantization.py +++ b/python/mlc_llm/model/starcoder2/starcoder2_quantization.py @@ -16,6 +16,7 @@ def group_quant( quantization: GroupQuantize, ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a InternLM-architecture model using group quantization.""" + model_config.kv_quantization = quantization.kv_quantization model: nn.Module = Starcoder2ForCausalLM(model_config) 
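The per-model edits above all follow one pattern: group_quant copies quantization.kv_quantization onto the model config, the model forwards it to create_paged_kv_cache(kv_quantization=...), the KV-cache dispatch pass records the resulting storage width as "kv_nbits" in the model metadata, and EstimateMemoryUsageOnMode in cpp/serve/config.cc scales its per-token KV estimate by that width. A minimal Python sketch of that final arithmetic follows; the shapes and the 4-bit figure are illustrative assumptions, not values taken from the patch.

# Sketch of the updated per-token KV-cache estimate from cpp/serve/config.cc:
# the former hard-coded "* 4" (2-byte fp16 elements, K plus V) becomes
# "* (kv_nbits / 8) * 2", driven by the new kv_nbits metadata field.
def kv_bytes_per_token(head_dim: int, num_kv_heads: int, num_layers: int, kv_nbits: int) -> float:
    kv_nbytes = kv_nbits / 8.0  # bytes per stored KV element
    # Factor 2 covers both K and V; 1.25 is the small per-token overhead
    # already present in the original estimate.
    return head_dim * num_kv_heads * num_layers * kv_nbytes * 2 + 1.25

# Example shapes (assumed, Llama-7B-like): head_dim 128, 32 KV heads, 32 layers.
print(kv_bytes_per_token(128, 32, 32, 16))  # fp16 cache: ~524,289 bytes/token
print(kv_bytes_per_token(128, 32, 32, 4))   # hypothetical 4-bit cache: ~131,073 bytes/token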
model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) diff --git a/python/mlc_llm/nn/kv_cache.py b/python/mlc_llm/nn/kv_cache.py index 6f7cacf50d..aa3cea48b4 100644 --- a/python/mlc_llm/nn/kv_cache.py +++ b/python/mlc_llm/nn/kv_cache.py @@ -1,6 +1,6 @@ """Attention KV cache modeling.""" -# pylint: disable=too-many-statements,too-many-lines,too-many-arguments +# pylint: disable=too-many-statements,too-many-lines,too-many-arguments,too-many-locals import enum import math from typing import Optional, Tuple @@ -8,12 +8,16 @@ from tvm import relax as rx from tvm import tir from tvm.relax.frontend.nn import Object, Tensor -from tvm.runtime import DataType from tvm.script import tir as T from tvm.target import Target from mlc_llm.op.position_embedding import llama_rope_with_position_map, rope_freq from mlc_llm.op.tree_attn import tree_attn +from mlc_llm.quantization.paged_kv_cache_quantization import ( + BaseKVConfig, + PagedKVCacheQuantization, + get_paged_kv_cache_config, +) from ..support.max_thread_check import ( check_thread_limits, @@ -50,6 +54,7 @@ def create_generic( rope_mode: RopeMode, rope_scale: int, rope_theta: int, + kv_quantization: PagedKVCacheQuantization, dtype: str, rotary_dim: Optional[int] = None, name: str = "paged_kv_cache", @@ -79,6 +84,7 @@ def create_generic( rx.PrimValue(rope_scale), rx.PrimValue(rope_theta), rx.PrimValue(rotary_dim), + rx.PrimValue(kv_quantization), rx.DataTypeImm(dtype), sinfo_args=rx.ObjectStructInfo(), ), @@ -168,6 +174,7 @@ def __init__( # pylint: disable=too-many-locals rope_scale: int, rope_theta: int, rotary_dim: int, + kv_quantization: PagedKVCacheQuantization, dtype: str, target: Target, name: str = "paged_kv_cache", @@ -210,6 +217,18 @@ def __init__( # pylint: disable=too-many-locals if rope_mode == RopeMode.INLINE: assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim." 
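The definitions of get_paged_kv_cache_config, BaseKVConfig, and PagedKVCacheQuantization live in python/mlc_llm/quantization/paged_kv_cache_quantization.py, which this diff does not show. Judging only from how kv_cache.py consumes the returned object (its kind, num_storage, model_dtype, and kv_storage_dtype fields plus the geometry passed in through kwargs), the no-quant path presumably reduces to something like the sketch below; the Sketch-suffixed names are hypothetical stand-ins for whatever the real module provides, and the real call also passes target, which is omitted here.

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class NoQuantKVConfigSketch:
    # Geometry forwarded from the model, as passed via kwargs in the call above.
    head_dim: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    model_dtype: str
    kind: str = "no-quant"  # kv_cache.py branches on this string

    @property
    def num_storage(self) -> int:
        # Without quantization every element is stored as-is, so one page row
        # holds exactly head_dim storage elements.
        return self.head_dim

    @property
    def kv_storage_dtype(self) -> str:
        # Storage dtype matches the model dtype when no quantization is applied.
        return self.model_dtype

def get_paged_kv_cache_config_sketch(kv_quant_scheme: str, model_dtype: str, kwargs: Dict[str, Any]):
    # "kv_no_quant" is PagedKVCacheQuantization.KV_NO_QUANT.name.lower().
    if kv_quant_scheme == "kv_no_quant":
        return NoQuantKVConfigSketch(
            head_dim=kwargs["head_dim"],
            num_hidden_layers=kwargs["num_hidden_layers"],
            num_attention_heads=kwargs["num_attention_heads"],
            num_key_value_heads=kwargs["num_key_value_heads"],
            model_dtype=model_dtype,
        )
    raise NotImplementedError(f"unhandled scheme in this sketch: {kv_quant_scheme}")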
+ kv_cache_config = get_paged_kv_cache_config( + kv_quant_scheme=PagedKVCacheQuantization(kv_quantization).name.lower(), + model_dtype=dtype, + kwargs={ + "head_dim": head_dim, + "num_hidden_layers": num_hidden_layers, + "num_attention_heads": num_attention_heads, + "num_key_value_heads": num_key_value_heads, + "target": target, + }, + ) + bb = rx.BlockBuilder.current() # pylint: disable=invalid-name args = [ rx.ShapeExpr( @@ -228,14 +247,16 @@ def __init__( # pylint: disable=too-many-locals rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), rx.PrimValue(rope_theta), - rx.op.zeros((), dtype), + rx.PrimValue(kv_cache_config.num_storage), + rx.op.zeros((), kv_cache_config.model_dtype), + rx.op.zeros((), kv_cache_config.kv_storage_dtype), # pylint: disable=line-too-long # fmt: off - bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), + bb.add_func(_kv_cache_transpose_append(kv_cache_config), "kv_cache_transpose_append"), rx.extern("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill(kv_cache_config, sliding_window=True), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(kv_cache_config, sliding_window=True), "tir_attention_decode_sliding_window"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), rx.extern("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), @@ -244,11 +265,11 @@ def __init__( # pylint: disable=too-many-locals rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"), rx.extern("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), rx.extern("flashinfer.merge_state_in_place"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), - bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), - bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(llama_rope_with_position_map(kv_cache_config, rope_theta, rope_scale, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(kv_cache_config), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(kv_cache_config), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(kv_cache_config), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(kv_cache_config), "tir_attention_prefill_with_tree_mask"), # fmt: on # pylint: enable=line-too-long ] @@ -280,6 +301,7 @@ def __init__( # pylint: disable=too-many-locals rope_scale: int, rope_theta: int, rotary_dim: int, + kv_quantization: PagedKVCacheQuantization, dtype: str, target: Target, name: str = "paged_kv_cache", @@ -322,6 +344,18 @@ def __init__( # pylint: 
disable=too-many-locals The target to build the model to. """ + kv_cache_config = get_paged_kv_cache_config( + kv_quant_scheme=PagedKVCacheQuantization(kv_quantization).name.lower(), + model_dtype=dtype, + kwargs={ + "head_dim": head_dim, + "num_hidden_layers": num_hidden_layers, + "num_attention_heads": num_attention_heads, + "num_key_value_heads": num_key_value_heads, + "target": target, + }, + ) + bb = rx.BlockBuilder.current() args = [ rx.ShapeExpr( @@ -340,21 +374,23 @@ def __init__( # pylint: disable=too-many-locals rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), rx.PrimValue(rope_theta), - rx.op.zeros((), dtype), + rx.PrimValue(kv_cache_config.num_storage), + rx.op.zeros((), kv_cache_config.model_dtype), + rx.op.zeros((), kv_cache_config.kv_storage_dtype), # pylint: disable=line-too-long # fmt: off - bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_prefill"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, target), "tir_attention_decode"), - bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_prefill_sliding_window"), - bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, target), "tir_attention_decode_sliding_window"), - bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_ragged"), - bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"), - bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"), - bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"), - bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"), - bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, target), "tir_attention_prefill_with_tree_mask"), + bb.add_func(_kv_cache_transpose_append(kv_cache_config), "kv_cache_transpose_append"), + bb.add_func(_attention_prefill(kv_cache_config, sliding_window=False), "tir_attention_prefill"), + bb.add_func(_attention_decode(kv_cache_config, sliding_window=False), "tir_attention_decode"), + bb.add_func(_attention_prefill(kv_cache_config, sliding_window=True), "tir_attention_prefill_sliding_window"), + bb.add_func(_attention_decode(kv_cache_config, sliding_window=True), "tir_attention_decode_sliding_window"), + bb.add_func(_attention_prefill_ragged(kv_cache_config), "tir_attention_prefill_ragged"), + bb.add_func(_merge_state_inplace(kv_cache_config), "tir_attention_merge_state"), + bb.add_func(llama_rope_with_position_map(kv_cache_config, rope_theta, rope_scale, rotary_dim), "tir_split_rotary"), + bb.add_func(_copy_single_page(kv_cache_config), "kv_cache_copy_single_page"), + bb.add_func(_kv_cache_debug_get_kv(kv_cache_config), "kv_cache_debug_get_kv"), + bb.add_func(_compact_kv_copy(kv_cache_config), "kv_cache_compact_kv_copy"), + bb.add_func(tree_attn(kv_cache_config), "tir_attention_prefill_with_tree_mask"), # fmt: on # pylint: enable=line-too-long ] @@ 
-372,51 +408,70 @@ def __init__( # pylint: disable=too-many-locals # pylint: disable=too-many-locals -def _kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): +def _kv_cache_transpose_append(kv_cache_config: BaseKVConfig): """Return the TIR function that appends new k/v data to PagedKVCache.""" - # pylint: disable=line-too-long,invalid-name - # fmt: off - @T.prim_func - def tir_kv_cache_transpose_append( - var_pages: T.handle, - var_k_data: T.handle, - var_v_data: T.handle, - var_position_map: T.handle, - ): - T.func_attr({"tir.noalias": T.bool(True)}) - ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") - num_pages = T.int64() - position_map_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype) - k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) - v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset - ) - for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): - if position_map[global_pos] != T.int32(-1): - with T.block("k_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) - position: T.int32 = position_map[vgpos] # type: ignore - pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_data[vgpos, vh, vf] - with T.block("v_transpose_append"): - vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) - T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) - T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) - position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] - pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] + head_dim = kv_cache_config.head_dim + num_storage = kv_cache_config.num_storage + num_key_value_heads = kv_cache_config.num_key_value_heads + model_dtype = kv_cache_config.model_dtype + kv_storage_dtype = kv_cache_config.kv_storage_dtype + + if kv_cache_config.kind == "no-quant": + # pylint: disable=line-too-long,invalid-name + # fmt: off + @T.prim_func + def tir_kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") + num_pages = T.int64() + position_map_elem_offset = T.int32() + + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, num_storage), kv_storage_dtype) + k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), model_dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), model_dtype) + position_map = T.match_buffer(var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset) + + for global_pos, h, f in T.grid(ntoken, T.int64(num_key_value_heads), T.int64(num_storage)): + if position_map[global_pos] != T.int32(-1): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), 
vf] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], v_data[vgpos, vh, vf]) + T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore[name-defined,no-redef] + pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_data[vgpos, vh, vf] + # fmt: on # pylint: enable=line-too-long,invalid-name - return tir_kv_cache_transpose_append + return ( + tir_kv_cache_transpose_append + if kv_cache_config.kind == "no-quant" + else kv_cache_config.kv_cache_quantize_transpose_append() + ) -def _kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): +def _kv_cache_debug_get_kv(kv_cache_config: BaseKVConfig): """Return the TIR function that fetches the k/v data on given positions and layer.""" + num_hidden_layers = kv_cache_config.num_hidden_layers + num_key_value_heads = kv_cache_config.num_key_value_heads + num_storage = kv_cache_config.num_storage + head_dim = kv_cache_config.head_dim + dtype = kv_cache_config.model_dtype + kv_storage_dtype = kv_cache_config.kv_storage_dtype + # pylint: disable=line-too-long,invalid-name # fmt: off @T.prim_func @@ -432,10 +487,8 @@ def tir_kv_cache_debug_get_kv( page_size = T.SizeVar("page_size", "int64") num_pages = T.int64() position_map_elem_offset = T.int64() - pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype) - position_map = T.match_buffer( - var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset - ) + pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, num_storage), kv_storage_dtype) + position_map = T.match_buffer(var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset) k_data = T.match_buffer(var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) v_data = T.match_buffer(var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype) for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): @@ -444,14 +497,47 @@ def tir_kv_cache_debug_get_kv( T.reads(position_map[vp], pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd]) T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) position: T.int32 = position_map[vp] # type: ignore[name-defined] - k_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd] - v_data[layer_id, vp, vh, vd] = pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd] + k_data[layer_id, vp, vh, vd] = _dequantize(kv_cache_config, pages, (T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd)) + v_data[layer_id, vp, vh, vd] = _dequantize(kv_cache_config, pages, (T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd)) # fmt: on # pylint: enable=line-too-long,invalid-name return tir_kv_cache_debug_get_kv +def _dequantize( + kv_cache_config: BaseKVConfig, + buffer: T.Buffer, + indices: Tuple[tir.Var, ...], +): + return ( + buffer[indices] + if kv_cache_config.kind == "no-quant" + else kv_cache_config.kv_cache_dequantize(buffer, indices) + ) + + +def _rope_dequantize( + kv_cache_config: BaseKVConfig, + buffer: T.Buffer, + offset: tir.Var, + rotary_dim: int, + theta: tir.Var, + scale: tir.Var, + indices: Tuple[tir.Var, ...], + qkv_dtype="float16", +): + d = indices[-1] + cos_freq, sin_freq = rope_freq(offset * 
scale, d, rotary_dim, theta, "float32") + cos = cos_freq * _dequantize(kv_cache_config, buffer, indices).astype("float32") + sin = sin_freq * tir.if_then_else( + d < rotary_dim // 2, + -_dequantize(kv_cache_config, buffer, indices[:-1] + (d + rotary_dim // 2,)), + _dequantize(kv_cache_config, buffer, indices[:-1] + (d - rotary_dim // 2,)), + ).astype("float32") + return (cos + sin).astype(qkv_dtype) + + def _rope( buffer: T.Buffer, offset: tir.Var, @@ -515,22 +601,28 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window): ) -def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: bool, target: Target): +def _attention_prefill(kv_cache_config: BaseKVConfig, sliding_window: bool): # pylint: disable=invalid-name + + h_kv = kv_cache_config.num_key_value_heads + h_q = kv_cache_config.num_attention_heads + d = kv_cache_config.head_dim + dtype = kv_cache_config.model_dtype + kv_storage_dtype = kv_cache_config.kv_storage_dtype + num_storage = kv_cache_config.num_storage + target = kv_cache_config.target + NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + LOAD_VEC = 8 // ((dtype.bits + 7) // 8) # 8 bytes group_size = h_q // h_kv sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = 64 // ((dtype.bits + 7) // 8) // max(d // 128, 1), d, 16 # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): + if str(target.kind) == "webgpu" and ((d + 127) // 128) * ((dtype.bits + 15) // 16) >= 4: tile_z = 8 num_warps = 2 check_thread_limits(target, bdx=bdx, bdy=num_warps, bdz=1, gdz=1) @@ -574,7 +666,7 @@ def batch_prefill_paged_kv( q = T.match_buffer(var_q, (total_len, h_q, d), dtype) q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) - pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, num_storage), kv_storage_dtype) page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) @@ -697,8 +789,8 @@ def batch_prefill_paged_kv( page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore K_smem[i, j] = T.if_then_else( rotary_mode == 1, - _rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), - pages[page_no, 0, by, page_offset, j] + _rope_dequantize(kv_cache_config, pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype), + _dequantize(kv_cache_config, pages, (page_no, 0, by, page_offset, j)) ) else: K_smem[i, j] = 0.0 @@ -713,7 +805,7 @@ def batch_prefill_paged_kv( seq_offset: T.int32(is_size_var=True) = _get_seq_offset(cur_L, b_idx, length_info, sliding_window) # type: ignore page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(seq_offset, 16)] # type: ignore page_offset: T.int32(is_size_var=True) = T.floormod(seq_offset, 16) # type: ignore - V_smem[i, j] = pages[page_no, 1, by, page_offset, j] + V_smem[i, j] = 
_dequantize(kv_cache_config, pages, (page_no, 1, by, page_offset, j)) else: V_smem[i, j] = 0.0 T.tvm_storage_sync("shared") @@ -880,19 +972,18 @@ def apply_to_md(sch, block): return sch.mod["main"].with_attr("tir.is_scheduled", 1) -def _attention_decode( - num_kv_heads, - num_qo_heads, - head_dim, - qkv_dtype, - sliding_window: bool, - target: Target, -): +def _attention_decode(kv_cache_config: BaseKVConfig, sliding_window: bool): # pylint: disable=invalid-name + + H_qo = kv_cache_config.num_attention_heads + H_kv = kv_cache_config.num_key_value_heads + head_dim = kv_cache_config.head_dim + qkv_dtype = kv_cache_config.model_dtype + kv_storage_dtype = kv_cache_config.kv_storage_dtype + num_storage = kv_cache_config.num_storage + target = kv_cache_config.target + qkv_dtype_bytes = 2 - H_qo = num_qo_heads - H_kv = num_kv_heads - D = head_dim THREAD_LIMIT = 512 TILE_SIZE_PER_BDX = 2 @@ -903,8 +994,8 @@ def _attention_decode( thread_limit = min(max_num_threads_per_block, THREAD_LIMIT) GROUP_SIZE = H_qo // H_kv - VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4) - bdx = D // VEC_SIZE + VEC_SIZE = min(max(8 // qkv_dtype_bytes, head_dim // 32), 4) + bdx = head_dim // VEC_SIZE bdy = GROUP_SIZE while bdx * bdy > thread_limit and bdy > 1: bdy //= 2 @@ -948,15 +1039,13 @@ def batch_decode_paged_kv( q_rope_position_elem_offset = T.int32(is_size_var=True) length_info_elem_offset = T.int32(is_size_var=True) - Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype) - pages = T.match_buffer( - pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype - ) + Q = T.match_buffer(Q_handle, (B, H_qo, head_dim), qkv_dtype) + pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, 16, num_storage), kv_storage_dtype) page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset) page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset) q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset) - output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype) + output = T.match_buffer(output_handle, (B, H_qo, head_dim), qkv_dtype) lse = T.match_buffer(lse_handle, (B, H_qo), "float32") # pylint: disable=unused-variable # The length information of the sequences. # - It is in shape `(3, batch_size)` when sliding window is enabled. @@ -968,7 +1057,7 @@ def batch_decode_paged_kv( # denoting the "last_page_len". 
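Note on the load path used by these kernels: both the prefill and decode bodies now fetch K/V through _dequantize / _rope_dequantize instead of indexing pages directly, so the same TIR works for the raw model-dtype layout and for the group-quantized layouts defined later in paged_kv_cache_quantization.py. As a rough host-side illustration only (not code from this patch; the helper name unpack_int4_kv and the NumPy framing are assumptions), the 4-bit scheme unpacks each element from its uint16 storage word and rescales it with the per-group scale stored after the packed weights:

import numpy as np

def unpack_int4_kv(words, scales, head_dim, group_size=32):
    """Sketch of the 4-bit group dequantization done by kv_cache_dequantize.

    `words` are the packed uint16 storage words of one K (or V) head row;
    `scales` are the per-group scales, already reinterpreted back to float16.
    """
    bits = 4
    max_int = (1 << (bits - 1)) - 1      # 7, also used as the zero point
    elems_per_word = 16 // bits          # uint16 storage -> 4 elements per word
    out = np.empty(head_dim, dtype=np.float16)
    for d in range(head_dim):
        word = int(words[d // elems_per_word])
        q = (word >> ((d % elems_per_word) * bits)) & ((1 << bits) - 1)
        out[d] = np.float16(q - max_int) * np.float16(scales[d // group_size])
    return out

The append path (kv_cache_quantize_transpose_append) is the inverse: each value is divided by its group scale, offset by the zero point, clamped to [0, 2 * max_int], and shifted into the shared storage word.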
length_info = _declare_length_info(var_length_info, B, sliding_window, length_info_elem_offset) - sm_scale = 1.0 / math.sqrt(float(D)) * log2e + sm_scale = 1.0 / math.sqrt(float(head_dim)) * log2e for bx in T.thread_binding(B, thread="blockIdx.x"): for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"): @@ -978,9 +1067,9 @@ def batch_decode_paged_kv( with T.block("attn"): Q_local = T.alloc_buffer((VEC_SIZE,), qkv_dtype, scope="local") kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local") - K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") - V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, D), qkv_dtype, scope="shared") - O_allreduce = T.alloc_buffer((bdz, bdy, D), "float32", scope="shared") + K_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, head_dim), qkv_dtype, scope="shared") + V_smem = T.alloc_buffer((bdz * bdy * tile_size_per_bdx, head_dim), qkv_dtype, scope="shared") + O_allreduce = T.alloc_buffer((bdz, bdy, head_dim), "float32", scope="shared") md_allreduce = T.alloc_buffer((bdz, bdy, 2), "float32", scope="shared") S_reduce_local = T.alloc_buffer((1,), "float32", scope="local") t0 = T.alloc_buffer((1,), "float32", scope="local") @@ -1040,10 +1129,10 @@ def batch_decode_paged_kv( for vec in T.vectorized(VEC_SIZE): K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else( rotary_mode == 1, - _rope(pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), - pages[page_no, 0, by, page_offset, tx * VEC_SIZE + vec] + _rope_dequantize(kv_cache_config, pages, k_rope_pos_offset[batch_idx] + row_g, head_dim, rope_theta, rope_scale, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec), qkv_dtype), + _dequantize(kv_cache_config, pages, (page_no, 0, by, page_offset, tx * VEC_SIZE + vec)) ) - V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec] + V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = _dequantize(kv_cache_config, pages, (page_no, 1, by, page_offset, tx * VEC_SIZE + vec)) else: for vec in T.vectorized(VEC_SIZE): K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = 0.0 @@ -1134,18 +1223,21 @@ def batch_decode_paged_kv( return batch_decode_paged_kv -def _merge_state_inplace( - num_heads, head_dim, v_dtype, target: Target -): # pylint: disable=unused-argument +def _merge_state_inplace(kv_cache_config: BaseKVConfig): # pylint: disable=invalid-name + num_attention_heads = kv_cache_config.num_attention_heads + head_dim = kv_cache_config.head_dim + v_dtype = kv_cache_config.model_dtype + target = kv_cache_config.target + v_dtype_bytes = 2 VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) bdx = head_dim // VEC_SIZE - bdy = num_heads + bdy = num_attention_heads max_num_threads_per_block = get_max_num_threads_per_block(target) while bdx * bdy > max_num_threads_per_block and bdy > 1: bdy //= 2 - gdy = num_heads // bdy + gdy = num_attention_heads // bdy check_thread_limits(target, bdx=bdx, bdy=bdy, bdz=1, gdz=1) @T.prim_func @@ -1211,24 +1303,25 @@ def merge_state_inplace( return merge_state_inplace -def _attention_prefill_ragged( - h_kv, h_q, d, dtype, target: Target -): # pylint: disable=unused-argument +def _attention_prefill_ragged(kv_cache_config: BaseKVConfig): # pylint: disable=invalid-name,line-too-long + h_kv = kv_cache_config.num_key_value_heads + h_q = kv_cache_config.num_attention_heads + d = kv_cache_config.head_dim + dtype = kv_cache_config.model_dtype + target = 
kv_cache_config.target + NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + LOAD_VEC = 8 // ((dtype.bits + 7) // 8) # 8 bytes group_size = h_q // h_kv sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = 64 // ((dtype.bits + 7) // 8) // max(d // 128, 1), d, 16 # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): + if str(target.kind) == "webgpu" and ((d + 127) // 128) * ((dtype.bits + 15) // 16) >= 4: tile_z = 8 num_warps = 2 @@ -1547,7 +1640,13 @@ def apply_to_md(sch, block): return sch.mod["main"].with_attr("tir.is_scheduled", 1) -def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): +def _copy_single_page(kv_cache_config: BaseKVConfig): + num_key_value_heads = kv_cache_config.num_key_value_heads + kv_storage_dtype = kv_cache_config.kv_storage_dtype + num_storage = kv_cache_config.num_storage + head_dim = kv_cache_config.head_dim + target = kv_cache_config.target + tx = get_max_num_threads_per_block(target) @T.prim_func @@ -1559,27 +1658,30 @@ def copy_single_page( ): T.func_attr({"tir.is_scheduled": 1}) num_pages = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) + page_size = T.SizeVar("page_size", "int64") + pages = T.match_buffer( + var_pages, (num_pages, 2, num_key_value_heads, page_size, num_storage), kv_storage_dtype + ) for b in T.thread_binding( - (copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + (copy_length * num_key_value_heads * num_storage + tx - 1) // tx, thread="blockIdx.x" ): for t in T.thread_binding(tx, thread="threadIdx.x"): with T.block("copy"): - T.where(b * tx + t < copy_length * num_heads * head_dim) + T.where(b * tx + t < copy_length * num_key_value_heads * head_dim) vh = T.axis.spatial( - num_heads, - T.Cast("int32", (b * tx + t) // (copy_length * head_dim)), + num_key_value_heads, + T.Cast("int32", (b * tx + t) // (copy_length * num_storage)), ) vp = T.axis.spatial( copy_length, - (b * tx + t) % (copy_length * head_dim) // head_dim, + (b * tx + t) % (copy_length * num_storage) // num_storage, ) vd = T.axis.spatial( - head_dim, + num_storage, T.Cast( "int32", - (b * tx + t) % head_dim, + (b * tx + t) % num_storage, ), ) pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd] @@ -1588,7 +1690,12 @@ def copy_single_page( return copy_single_page -def _compact_kv_copy(num_heads, head_dim, dtype, target: Target): +def _compact_kv_copy(kv_cache_config: BaseKVConfig): + num_key_value_heads = kv_cache_config.num_key_value_heads + kv_storage_dtype = kv_cache_config.kv_storage_dtype + num_storage = kv_cache_config.num_storage + target: Target = kv_cache_config.target + tx = get_max_num_threads_per_block(target) @T.prim_func @@ -1603,7 +1710,9 @@ def compact_kv_copy( total_copy_length = T.int32() copy_length_indptr_elem_offset = T.int32() copy_src_dst_pos_elem_offset = T.int32() - pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype) + pages = T.match_buffer( + var_pages, (num_pages, 2, num_key_value_heads, 16, num_storage), kv_storage_dtype + ) copy_length_indptr = T.match_buffer( var_copy_length_indptr, (batch_size + 1,), @@ -1619,13 +1728,13 @@ def compact_kv_copy( with T.block("root"): for bhd_o in T.thread_binding( - (batch_size * 
num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x" + (batch_size * num_key_value_heads * num_storage + tx - 1) // tx, thread="blockIdx.x" ): for bhd_i in T.thread_binding(tx, thread="threadIdx.x"): - b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim) - h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads - d: T.int32 = (bhd_o * tx + bhd_i) % head_dim - if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim: + b: T.int32 = (bhd_o * tx + bhd_i) // (num_key_value_heads * num_storage) + h: T.int32 = (bhd_o * tx + bhd_i) // num_storage % num_key_value_heads + d: T.int32 = (bhd_o * tx + bhd_i) % num_storage + if (bhd_o * tx + bhd_i) < batch_size * num_key_value_heads * num_storage: for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]): src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i] dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i] diff --git a/python/mlc_llm/op/position_embedding.py b/python/mlc_llm/op/position_embedding.py index 4416e8bc9a..d2c61243e3 100644 --- a/python/mlc_llm/op/position_embedding.py +++ b/python/mlc_llm/op/position_embedding.py @@ -170,40 +170,33 @@ def fused_rope( # pylint: disable=too-many-locals def llama_rope_with_position_map( # pylint: disable=too-many-arguments + kv_cache_config, theta: float, scale: float, - head_dim: int, - num_q_heads: int, - num_kv_heads: int, - dtype: str, rotary_dim: Optional[int] = None, ): """Return the TIR function that computes Llama-style RoPE with q position map. Parameters ---------- + kv_cache_config : BaseKVConfig + Page KV Cache configuration. + theta : float The theta value, or "base" in RoPE, which controls the frequency. scale : float The RoPE scaling factor. - head_dim : int - The number of features on each head. - - num_q_heads : int - The number of query heads. - - num_kv_heads : int - The number of key/value heads. It differs from `num_q_heads` in group-query attention. - - dtype : str - The dtype of qkv data. - rotary_dim : int The number of dimensions in the embedding that RoPE is applied to. By default, the rotary_dim is the same as head_dim. """ + head_dim = kv_cache_config.head_dim + num_q_heads = kv_cache_config.num_attention_heads + num_kv_heads = kv_cache_config.num_key_value_heads + dtype = kv_cache_config.model_dtype + fused_heads = num_q_heads + num_kv_heads * 2 if rotary_dim is None: rotary_dim = head_dim diff --git a/python/mlc_llm/op/tree_attn.py b/python/mlc_llm/op/tree_attn.py index 0a9373125d..dd2e8d5f6d 100644 --- a/python/mlc_llm/op/tree_attn.py +++ b/python/mlc_llm/op/tree_attn.py @@ -4,9 +4,7 @@ from typing import Tuple from tvm import tir -from tvm.runtime import DataType from tvm.script import tir as T -from tvm.target import Target from mlc_llm.op.position_embedding import rope_freq @@ -42,21 +40,13 @@ def _tree_mask(row, col, mask_ptr, offset, stride, kv_len): return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1) -def tree_attn(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument +def tree_attn(kv_cache_config): """Generate tree attention kernel for batched tree attention. Parameters ---------- - h_kv : int - Number of heads for key and value. - h_q : int - Number of heads for query. - d : int - Hidden dimension. - dtype : str - Data type. - target : Target - The target device. + kv_cache_config : BaseKVConfig + Page KV Cache configuration. Returns ------- @@ -64,20 +54,23 @@ def tree_attn(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-ar The generated IR module. 
""" # pylint: disable=invalid-name,line-too-long + h_kv = kv_cache_config.num_key_value_heads + h_q = kv_cache_config.num_attention_heads + d = kv_cache_config.head_dim + dtype = kv_cache_config.model_dtype + target = kv_cache_config.target + NUM_BLKS = 16 - LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes + LOAD_VEC = 8 // ((dtype.bits + 7) // 8) # 8 bytes group_size = h_q // h_kv sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) bdx = 32 num_warps = 4 - tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16 + tile_x, tile_y, tile_z = 64 // ((dtype.bits + 7) // 8) // max(d // 128, 1), d, 16 # Otherwise we would exceed maxComputeWorkgroupStorageSize - if ( - str(target.kind) == "webgpu" - and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4 - ): + if str(target.kind) == "webgpu" and ((d + 127) // 128) * ((dtype.bits + 15) // 16) >= 4: tile_z = 8 num_warps = 2 diff --git a/python/mlc_llm/quantization/__init__.py b/python/mlc_llm/quantization/__init__.py index d2c89bb2a1..3f42c8d401 100644 --- a/python/mlc_llm/quantization/__init__.py +++ b/python/mlc_llm/quantization/__init__.py @@ -5,5 +5,11 @@ from .ft_quantization import FTQuantize from .group_quantization import GroupQuantize from .no_quantization import NoQuantize +from .paged_kv_cache_quantization import ( + BaseKVConfig, + PagedKVCacheQuantization, + get_kv_storage_dtype, + get_paged_kv_cache_config, +) from .per_tensor_quantization import PerTensorQuantize from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_llm/quantization/awq_quantization.py b/python/mlc_llm/quantization/awq_quantization.py index d51f0a6020..8e1a55975b 100644 --- a/python/mlc_llm/quantization/awq_quantization.py +++ b/python/mlc_llm/quantization/awq_quantization.py @@ -9,6 +9,7 @@ from mlc_llm.loader import QuantizeMapping +from .paged_kv_cache_quantization import PagedKVCacheQuantization from .utils import convert_uint_to_float, is_final_fc, is_moe_gate @@ -41,6 +42,7 @@ class AWQQuantize: # pylint: disable=too-many-instance-attributes quantize_dtype: str # "int3", "int4", "int8" storage_dtype: str # "uint32" model_dtype: str # "float16", "float32" + kv_quantization: PagedKVCacheQuantization num_elem_per_storage: int = 0 num_storage_per_group: int = 0 diff --git a/python/mlc_llm/quantization/ft_quantization.py b/python/mlc_llm/quantization/ft_quantization.py index 4a15846096..bd04acc9ca 100644 --- a/python/mlc_llm/quantization/ft_quantization.py +++ b/python/mlc_llm/quantization/ft_quantization.py @@ -21,6 +21,7 @@ GroupQuantizeEmbedding, GroupQuantizeLinear, ) +from .paged_kv_cache_quantization import PagedKVCacheQuantization from .utils import is_final_fc, is_moe_gate logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ class FTQuantize: # pylint: disable=too-many-instance-attributes quantize_dtype: Literal["int4", "int8"] storage_dtype: Literal["int8"] model_dtype: Literal["float16"] + kv_quantization: PagedKVCacheQuantization group_size: Optional[int] = None num_elem_per_storage: int = 0 @@ -57,6 +59,7 @@ def fallback_group_quantize(self) -> GroupQuantize: storage_dtype="uint32", model_dtype=self.model_dtype, linear_weight_layout="NK", + kv_quantization=self.kv_quantization, ) def __post_init__(self): diff --git a/python/mlc_llm/quantization/group_quantization.py b/python/mlc_llm/quantization/group_quantization.py index 27cac54212..b1e8c72f7e 100644 --- a/python/mlc_llm/quantization/group_quantization.py +++ b/python/mlc_llm/quantization/group_quantization.py 
@@ -12,6 +12,7 @@ from mlc_llm.nn import MixtralExperts from mlc_llm.support import logging +from .paged_kv_cache_quantization import PagedKVCacheQuantization from .utils import ( apply_sharding, compile_quantize_func, @@ -35,6 +36,7 @@ class GroupQuantize: # pylint: disable=too-many-instance-attributes storage_dtype: Literal["uint32"] model_dtype: Literal["float16", "float32"] linear_weight_layout: Literal["KN", "NK"] + kv_quantization: PagedKVCacheQuantization quantize_embedding: bool = True quantize_final_fc: bool = True diff --git a/python/mlc_llm/quantization/paged_kv_cache_quantization.py b/python/mlc_llm/quantization/paged_kv_cache_quantization.py new file mode 100644 index 0000000000..5d2a87665c --- /dev/null +++ b/python/mlc_llm/quantization/paged_kv_cache_quantization.py @@ -0,0 +1,323 @@ +"""Paged KV cache quantization config""" + +# pylint: disable=too-many-statements,too-many-lines,too-many-arguments,too-many-locals +import enum +import math +from dataclasses import dataclass +from typing import Any, Dict, Tuple + +from tvm import DataType, tir +from tvm.script import tir as T +from tvm.target import Target + + +class PagedKVCacheQuantization(enum.IntEnum): + """The quantization scheme to apply to Paged KV cache. + If it is none, quantization will not be applied to kv cache. + Otherwise, different quantization schemes can be applied to kv cache. + """ + + KV_NO_QUANT = 0 + KV_GROUP_QUANT_INT_3 = 1 + KV_GROUP_QUANT_INT_4 = 2 + + +@dataclass +class BaseKVConfig: # pylint: disable=too-many-instance-attributes + """Base configuration for key-value cache""" + + name: str + kind: str + head_dim: int + num_hidden_layers: int + num_attention_heads: int + num_key_value_heads: int + model_dtype: DataType + target: Target + + +@dataclass +class NoQuantizeKV(BaseKVConfig): + """Configuration for key-value non-quantization""" + + num_storage: int = 0 + kv_storage_dtype: DataType = None + + def __post_init__(self): + assert self.kind == "no-quant" + assert str(self.model_dtype) in ["float16", "float32"] + + self.num_storage = self.head_dim + self.kv_storage_dtype = self.model_dtype + + +@dataclass +class GroupQuantizeKV(BaseKVConfig): # pylint: disable=too-many-instance-attributes + """Configuration for key-value group quantization""" + + group_size: int + kv_quantize_dtype: DataType + + max_int_value: int = 0 + num_elem_per_storage: int = 0 + num_storage_per_group: int = 0 + num_groups: int = 0 + num_storage_weight: int = 0 + num_storage_scale: int = 0 + num_storage: int = 0 + kv_storage_dtype: DataType = None + + def __post_init__(self): + assert self.kind == "group-quant" + assert str(self.kv_quantize_dtype) in ["int3", "int4"] + assert str(self.model_dtype) in ["float16", "float32"] + + self.kv_storage_dtype = { + "float16": DataType("uint16"), + "float32": DataType("uint32"), + }[str(self.model_dtype)] + + self.max_int_value = (2 ** (self.kv_quantize_dtype.bits - 1)) - 1 + self.num_elem_per_storage = self.kv_storage_dtype.bits // self.kv_quantize_dtype.bits + self.num_storage_per_group = self.group_size // self.num_elem_per_storage + self.num_groups = math.ceil(self.head_dim / self.group_size) + self.num_storage_weight = self.num_storage_per_group * self.num_groups + self.num_storage_scale = self.num_groups + self.num_storage = self.num_storage_weight + self.num_storage_scale + + if self.kv_storage_dtype.bits < self.kv_quantize_dtype.bits: + raise ValueError("Storage unit should be greater or equal to quantized element") + if self.group_size % self.num_elem_per_storage != 0: + raise 
ValueError("Group size should be divisible by numbers of elements per storage") + + def kv_cache_quantize_transpose_append(self): + """ + Return the TIR function that appends new k/v data to PagedKVCache + (fused w/ kv quantization). + """ + + # pylint: disable=line-too-long,invalid-name + # fmt: off + @T.prim_func + def tir_kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") + num_pages = T.int64() + position_map_elem_offset = T.int32() + pages = T.match_buffer(var_pages, (num_pages, 2, self.num_key_value_heads, 16, self.num_storage), self.kv_storage_dtype) + k_data = T.match_buffer(var_k_data, (ntoken, self.num_key_value_heads, self.head_dim), self.model_dtype) + v_data = T.match_buffer(var_v_data, (ntoken, self.num_key_value_heads, self.head_dim), self.model_dtype) + position_map = T.match_buffer(var_position_map, (ntoken,), "int32", elem_offset=position_map_elem_offset) + + k_max_abs_value = T.alloc_buffer((ntoken, self.num_key_value_heads, self.num_groups), self.model_dtype) + v_max_abs_value = T.alloc_buffer((ntoken, self.num_key_value_heads, self.num_groups), self.model_dtype) + k_scale = T.alloc_buffer((ntoken, self.num_key_value_heads, self.num_groups), self.model_dtype) + v_scale = T.alloc_buffer((ntoken, self.num_key_value_heads, self.num_groups), self.model_dtype) + k_compute = T.alloc_buffer((ntoken, self.num_key_value_heads, self.num_storage_weight), self.kv_storage_dtype) + v_compute = T.alloc_buffer((ntoken, self.num_key_value_heads, self.num_storage_weight), self.kv_storage_dtype) + + for i0, i1, i2, r in T.grid(ntoken, T.int64(self.num_key_value_heads), T.int64(self.num_groups), T.int64(self.group_size)): + with T.block("k_max_abs_value"): + v_i0, v_i1, v_i2, v_r = T.axis.remap("SSSR", [i0, i1, i2, r]) + T.reads(k_data[v_i0, v_i1, v_i2 * self.group_size + v_r]) + T.writes(k_max_abs_value[v_i0, v_i1, v_i2]) + with T.init(): + k_max_abs_value[v_i0, v_i1, v_i2] = T.min_value(self.model_dtype) + k_max_abs_value[v_i0, v_i1, v_i2] = T.max( + k_max_abs_value[v_i0, v_i1, v_i2], + T.if_then_else( + v_i2 * self.group_size + v_r < self.head_dim, + T.fabs(k_data[v_i0, v_i1, v_i2 * self.group_size + v_r]), + T.min_value(self.model_dtype), + ), + ) + for i0, i1, i2, r in T.grid(ntoken, T.int64(self.num_key_value_heads), T.int64(self.num_groups), T.int64(self.group_size)): + with T.block("v_max_abs_value"): + v_i0, v_i1, v_i2, v_r = T.axis.remap("SSSR", [i0, i1, i2, r]) + T.reads(v_data[v_i0, v_i1, v_i2 * self.group_size + v_r]) + T.writes(v_max_abs_value[v_i0, v_i1, v_i2]) + with T.init(): + v_max_abs_value[v_i0, v_i1, v_i2] = T.min_value(self.model_dtype) + v_max_abs_value[v_i0, v_i1, v_i2] = T.max( + v_max_abs_value[v_i0, v_i1, v_i2], + T.if_then_else( + v_i2 * self.group_size + v_r < self.head_dim, + T.fabs(v_data[v_i0, v_i1, v_i2 * self.group_size + v_r]), + T.min_value(self.model_dtype), + ), + ) + + for i0, i1, i2 in T.grid(ntoken, T.int64(self.num_key_value_heads), T.int64(self.num_groups)): + with T.block("scale"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(k_max_abs_value[v_i0, v_i1, v_i2], v_max_abs_value[v_i0, v_i1, v_i2]) + T.writes(k_scale[v_i0, v_i1, v_i2], v_scale[v_i0, v_i1, v_i2]) + k_scale[v_i0, v_i1, v_i2] = k_max_abs_value[v_i0, v_i1, v_i2] / self.max_int_value + v_scale[v_i0, v_i1, v_i2] = v_max_abs_value[v_i0, v_i1, v_i2] / self.max_int_value + + for i0, i1, 
i2, r in T.grid(ntoken, T.int64(self.num_key_value_heads), T.int64(self.num_storage_weight), T.int64(self.num_elem_per_storage)): + with T.block("k_compute_pack"): + v_i0, v_i1, v_i2, v_r = T.axis.remap("SSSR", [i0, i1, i2, r]) + T.reads( + k_data[v_i0, v_i1, v_i2 * self.num_elem_per_storage + v_r], + k_scale[v_i0, v_i1, v_i2 // self.num_storage_per_group], + ) + T.writes(k_compute[v_i0, v_i1, v_i2]) + with T.init(): + k_compute[v_i0, v_i1, v_i2] = 0 + k_compute[v_i0, v_i1, v_i2] = k_compute[v_i0, v_i1, v_i2] + T.if_then_else( + v_i2 * self.num_elem_per_storage + v_r < self.head_dim, + T.shift_left( + T.Cast( + self.kv_storage_dtype, + T.min( + T.max( + T.round( + k_data[v_i0, v_i1, v_i2 * self.num_elem_per_storage + v_r] + / k_scale[v_i0, v_i1, v_i2 // self.num_storage_per_group] + + self.max_int_value + ), + 0.0, + ), + self.max_int_value * 2.0, + ), + ), + T.Cast(self.kv_storage_dtype, v_r * self.kv_quantize_dtype.bits), + ), + tir.const(0, str(self.kv_storage_dtype)), + ) + for i0, i1, i2, r in T.grid(ntoken, T.int64(self.num_key_value_heads), T.int64(self.num_storage_weight), T.int64(self.num_elem_per_storage)): + with T.block("v_compute_pack"): + v_i0, v_i1, v_i2, v_r = T.axis.remap("SSSR", [i0, i1, i2, r]) + T.reads( + v_data[v_i0, v_i1, v_i2 * self.num_elem_per_storage + v_r], + v_scale[v_i0, v_i1, v_i2 // self.num_storage_per_group], + ) + T.writes(v_compute[v_i0, v_i1, v_i2]) + with T.init(): + v_compute[v_i0, v_i1, v_i2] = 0 + v_compute[v_i0, v_i1, v_i2] = v_compute[v_i0, v_i1, v_i2] + T.if_then_else( + v_i2 * self.num_elem_per_storage + v_r < self.head_dim, + T.shift_left( + T.Cast( + self.kv_storage_dtype, + T.min( + T.max( + T.round( + v_data[v_i0, v_i1, v_i2 * self.num_elem_per_storage + v_r] + / v_scale[v_i0, v_i1, v_i2 // self.num_storage_per_group] + + self.max_int_value + ), + 0.0, + ), + self.max_int_value * 2.0, + ), + ), + T.Cast(self.kv_storage_dtype, v_r * self.kv_quantize_dtype.bits), + ), + tir.const(0, str(self.kv_storage_dtype)), + ) + + for global_pos, h, f in T.grid(ntoken, T.int64(self.num_key_value_heads), T.int64(self.num_storage)): + if position_map[global_pos] != T.int32(-1): + with T.block("transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads( + position_map[vgpos], + k_compute[vgpos, vh, 0:self.num_storage_weight], + v_compute[vgpos, vh, 0:self.num_storage_weight], + k_scale[vgpos, vh, 0:self.num_storage_scale], + v_scale[vgpos, vh, 0:self.num_storage_scale], + ) + T.writes(pages[position_map[vgpos] // 16, 0:2, vh, position_map[vgpos] % 16, vf]) + position: T.int32 = position_map[vgpos] # type: ignore + + if vf < self.num_storage_weight: + pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = k_compute[vgpos, vh, vf] + pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = v_compute[vgpos, vh, vf] + else: + pages[T.floordiv(position, 16), 0, vh, T.floormod(position, 16), vf] = T.reinterpret( + self.kv_storage_dtype, k_scale[vgpos, vh, vf - self.num_storage_weight] + ) + pages[T.floordiv(position, 16), 1, vh, T.floormod(position, 16), vf] = T.reinterpret( + self.kv_storage_dtype, v_scale[vgpos, vh, vf - self.num_storage_weight] + ) + + # fmt: on + # pylint: enable=line-too-long,invalid-name + + return tir_kv_cache_transpose_append + + def kv_cache_dequantize( + self, + buffer: T.Buffer, + indices: Tuple[tir.Var, ...], + ): + """TIR dequantizae kv""" + + d = indices[-1] + bin_mask = (1 << self.kv_quantize_dtype.bits) - 1 + + quantized_data = T.Cast( + self.model_dtype, + T.bitwise_and( + 
T.shift_right( + buffer[indices[:-1] + (d // self.num_elem_per_storage,)], + T.Cast("int32", (d % self.num_elem_per_storage) * self.kv_quantize_dtype.bits), + ), + bin_mask, + ), + ) + data = ( + quantized_data - tir.const(self.max_int_value, str(self.model_dtype)) + ) * T.reinterpret( + self.model_dtype, + buffer[indices[:-1] + (self.num_storage_weight + (d // self.group_size),)], + ) + return data + + +def get_kv_storage_dtype(kv_quant_scheme: str, model_dtype: str) -> DataType: + """Get Cache storage dtype according to quantization scheme""" + return { + "kv_no_quant": DataType(model_dtype), + "kv_group_quant_int_3": DataType("int3"), + "kv_group_quant_int_4": DataType("int4"), + }[kv_quant_scheme] + + +def get_paged_kv_cache_config( + kv_quant_scheme: str, + model_dtype: str, + kwargs: Dict[str, Any], +) -> BaseKVConfig: + """Get Paged KV Cache configuration""" + return { + "kv_no_quant": NoQuantizeKV( + name="kv_no_quant", + kind="no-quant", + model_dtype=DataType(model_dtype), + **kwargs, + ), + "kv_group_quant_int_3": GroupQuantizeKV( + name="kv_group_quant_int_3", + kind="group-quant", + group_size=40, + kv_quantize_dtype=get_kv_storage_dtype("kv_group_quant_int_3", model_dtype), + model_dtype=DataType(model_dtype), + **kwargs, + ), + "kv_group_quant_int_4": GroupQuantizeKV( + name="kv_group_quant_int_4", + kind="group-quant", + group_size=32, + kv_quantize_dtype=get_kv_storage_dtype("kv_group_quant_int_4", model_dtype), + model_dtype=DataType(model_dtype), + **kwargs, + ), + }[kv_quant_scheme] diff --git a/python/mlc_llm/quantization/per_tensor_quantization.py b/python/mlc_llm/quantization/per_tensor_quantization.py index ff20c7e7dd..cc61e9d207 100644 --- a/python/mlc_llm/quantization/per_tensor_quantization.py +++ b/python/mlc_llm/quantization/per_tensor_quantization.py @@ -12,6 +12,7 @@ from mlc_llm.nn import MixtralExperts from mlc_llm.support import logging +from .paged_kv_cache_quantization import PagedKVCacheQuantization from .utils import ( apply_sharding, compile_quantize_func, @@ -34,6 +35,7 @@ class PerTensorQuantize: # pylint: disable=too-many-instance-attributes weight_dtype: Literal["e4m3_float8", "e5m2_float8"] storage_dtype: Literal["uint32", "e4m3_float8", "e5m2_float8"] model_dtype: Literal["float16"] + kv_quantization: PagedKVCacheQuantization quantize_embedding: bool = True quantize_final_fc: bool = True quantize_linear: bool = True diff --git a/python/mlc_llm/quantization/quantization.py b/python/mlc_llm/quantization/quantization.py index 1a5719a63f..a57d3b8f63 100644 --- a/python/mlc_llm/quantization/quantization.py +++ b/python/mlc_llm/quantization/quantization.py @@ -6,6 +6,7 @@ from .ft_quantization import FTQuantize from .group_quantization import GroupQuantize from .no_quantization import NoQuantize +from .paged_kv_cache_quantization import PagedKVCacheQuantization from .per_tensor_quantization import PerTensorQuantize Quantization = Any @@ -48,6 +49,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr linear_weight_layout="KN", quantize_embedding=True, quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "q3f16_1": GroupQuantize( name="q3f16_1", @@ -59,6 +61,19 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr linear_weight_layout="NK", quantize_embedding=True, quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, + ), + "q3f16kv": GroupQuantize( + name="q3f16kv", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + 
storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="NK", + quantize_embedding=True, + quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_GROUP_QUANT_INT_3, ), "q4f16_0": GroupQuantize( name="q4f16_0", @@ -70,6 +85,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr linear_weight_layout="KN", quantize_embedding=True, quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "q4f16_1": GroupQuantize( name="q4f16_1", @@ -81,6 +97,19 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr linear_weight_layout="NK", quantize_embedding=True, quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, + ), + "q4f16kv": GroupQuantize( + name="q4f16kv", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + linear_weight_layout="NK", + quantize_embedding=True, + quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_GROUP_QUANT_INT_4, ), "q4f32_1": GroupQuantize( name="q4f32_1", @@ -92,6 +121,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr linear_weight_layout="NK", quantize_embedding=True, quantize_final_fc=True, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "q4f16_2": GroupQuantize( name="q4f16_2", @@ -103,6 +133,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr linear_weight_layout="NK", quantize_embedding=False, quantize_final_fc=False, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "q4f16_autoawq": AWQQuantize( name="q4f16_autoawq", @@ -111,6 +142,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr quantize_dtype="int4", storage_dtype="uint32", model_dtype="float16", + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "q4f16_ft": FTQuantize( name="q4f16_ft", @@ -118,6 +150,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr quantize_dtype="int4", storage_dtype="int8", model_dtype="float16", + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "e5m2_e5m2_f16": PerTensorQuantize( name="e5m2_e5m2_f16", @@ -130,6 +163,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr quantize_embedding=False, quantize_linear=True, use_scale=False, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "e4m3_e4m3_f16": PerTensorQuantize( name="e4m3_e4m3_f16", @@ -143,6 +177,7 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr quantize_linear=True, use_scale=True, calibration_mode="inference", + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), "e4m3_e4m3_f16_max_calibrate": PerTensorQuantize( name="e4m3_e4m3_f16_max_calibrate", @@ -156,5 +191,6 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr quantize_linear=True, use_scale=True, calibration_mode="max", + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, ), } diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py index 3e3afb92cc..ed3dc3f683 100644 --- a/tests/python/model/test_kv_cache.py +++ b/tests/python/model/test_kv_cache.py @@ -7,6 +7,7 @@ from tvm.script import tir as T from mlc_llm.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache, RopeMode +from mlc_llm.quantization import PagedKVCacheQuantization # mypy: disable-error-code="attr-defined" # pylint: 
disable=invalid-name,unused-argument,too-many-locals,too-many-statements @@ -88,6 +89,7 @@ def create_paged_kv_cache( rope_scale=1, rope_theta=10000, rotary_dim=128, + kv_quantization=PagedKVCacheQuantization.KV_NO_QUANT, dtype="float16", ) diff --git a/tests/python/op/test_tree_attn.py b/tests/python/op/test_tree_attn.py index a23231b52e..20cc2546a4 100644 --- a/tests/python/op/test_tree_attn.py +++ b/tests/python/op/test_tree_attn.py @@ -6,6 +6,7 @@ import tvm.testing from mlc_llm.op.tree_attn import tree_attn +from mlc_llm.quantization import PagedKVCacheQuantization, get_paged_kv_cache_config # test category "op_correctness" pytestmark = [pytest.mark.op_correctness] @@ -112,9 +113,21 @@ def gen_full_binary_tree(height): mask_tvm = tvm.nd.array(mask, dev) output_tvm = tvm.nd.array(output, dev) lse_tvm = tvm.nd.array(lse, dev) - target = tvm.target.Target("cuda") - kernel = tree_attn(h_kv=h_kv, h_q=h_q, d=d, dtype="float16", target=target) + + kv_cache_config = get_paged_kv_cache_config( + kv_quant_scheme=PagedKVCacheQuantization.KV_NO_QUANT.name.lower(), + model_dtype="float16", + kwargs={ + "head_dim": d, + "num_hidden_layers": -1, + "num_attention_heads": h_q, + "num_key_value_heads": h_kv, + "target": target, + }, + ) + + kernel = tree_attn(kv_cache_config) mod = tvm.build(kernel, target=target) mod( q_tvm,