[Serving] PagedKVCache Quantization #2663

Open · wants to merge 1 commit into base: main
4 changes: 3 additions & 1 deletion cpp/metadata/model.cc
@@ -46,6 +46,7 @@ ModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON(
kv_cache_metadata.head_dim = json::Lookup<int64_t>(json, "head_dim");
kv_cache_metadata.num_attention_heads = json::Lookup<int64_t>(json, "num_attention_heads");
kv_cache_metadata.num_key_value_heads = json::Lookup<int64_t>(json, "num_key_value_heads");
kv_cache_metadata.kv_nbits = json::Lookup<int64_t>(json, "kv_nbits");
return kv_cache_metadata;
}
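With `kv_nbits` now parsed from the model metadata, the compiled model's metadata JSON is expected to carry the field alongside the existing KV-cache entries. An illustrative fragment (the values are made up; only the key names come from the lookups above):

```python
# Hypothetical "kv_cache" section of a compiled model's metadata JSON.
kv_cache_metadata = {
    "num_hidden_layers": 32,
    "head_dim": 128,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "kv_nbits": 16,  # storage width per K/V element; 16 when KV quantization is disabled
}
```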

@@ -73,7 +74,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
result.kv_cache_metadata = {/*num_hidden_layers=*/0,
/*head_dim=*/0,
/*num_attention_heads=*/0,
/*num_key_value_heads=*/0};
/*num_key_value_heads=*/0,
/*kv_nbits=*/0};
}
{
std::vector<ModelMetadata::Param>& params = result.params;
1 change: 1 addition & 0 deletions cpp/metadata/model.h
@@ -67,6 +67,7 @@ struct ModelMetadata {
int64_t num_attention_heads;
int64_t num_key_value_heads;
int64_t head_dim;
int64_t kv_nbits;
static KVCacheMetadata FromJSON(const picojson::object& json);
};

4 changes: 3 additions & 1 deletion cpp/serve/config.cc
@@ -567,8 +567,10 @@ Result<MemUsageEstimationResult> EstimateMemoryUsageOnMode(
int64_t head_dim = model_metadata[i].kv_cache_metadata.head_dim;
int64_t num_qo_heads = model_metadata[i].kv_cache_metadata.num_attention_heads;
int64_t num_kv_heads = model_metadata[i].kv_cache_metadata.num_key_value_heads;
int64_t kv_nbits = model_metadata[i].kv_cache_metadata.kv_nbits;
int64_t hidden_size = head_dim * num_qo_heads;
kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25;
double kv_nbytes = kv_nbits / 8.0f;
kv_bytes_per_token += head_dim * num_kv_heads * num_layers * kv_nbytes * 2 + 1.25;
kv_aux_workspace_bytes +=
(max_num_sequence + 1) * 88 + prefill_chunk_size * (num_qo_heads + 1) * 8 +
prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + 48 * 1024 * 1024;
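The per-token estimate previously hard-coded 4 bytes per element position (2-byte fp16 keys plus 2-byte values); it now derives the element width from `kv_nbits`. A minimal sketch of the updated arithmetic (function and variable names are illustrative, not part of the diff):

```python
def kv_bytes_per_token(head_dim: int, num_kv_heads: int, num_layers: int, kv_nbits: int) -> float:
    """Mirror of the kv_bytes_per_token term in EstimateMemoryUsageOnMode."""
    kv_nbytes = kv_nbits / 8.0  # bytes per stored K/V element
    # x2 because both keys and values are stored; 1.25 is the fixed per-token overhead above.
    return head_dim * num_kv_heads * num_layers * kv_nbytes * 2 + 1.25

# e.g. 32 layers, 8 KV heads, head_dim 128:
#   16-bit storage -> 131073.25 bytes/token
#    4-bit storage ->  32769.25 bytes/token (scale/zero-point overhead not modeled here)
```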
1 change: 1 addition & 0 deletions python/mlc_llm/bench/metrics.py
@@ -1,4 +1,5 @@
""" MLC LLM bench Metrics"""

import json
from typing import Any, Callable, Dict, List, Optional, Union

1 change: 1 addition & 0 deletions python/mlc_llm/bench/replay.py
@@ -1,4 +1,5 @@
"""MLC LLM bench replay request"""

import asyncio
import json
from datetime import datetime
16 changes: 12 additions & 4 deletions python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
@@ -6,6 +6,7 @@
from tvm import IRModule, relax

from mlc_llm.nn import RopeMode, kv_cache
from mlc_llm.quantization import PagedKVCacheQuantization, get_kv_storage_dtype


def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
@@ -20,13 +21,13 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
assert isinstance(args[0], relax.ExternFunc)
assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"

assert len(args) == 11
assert len(args) == 12
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
for i in range(2, 10):
for i in range(2, 11):
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
assert isinstance(args[10], relax.DataTypeImm)
assert isinstance(args[11], relax.DataTypeImm)

return {
"max_batch_size": args[1].values[0],
Expand All @@ -42,7 +43,8 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
"rope_scale": args[7].value.value,
"rope_theta": args[8].value.value,
"rotary_dim": args[9].value.value,
"dtype": args[10].value,
"kv_quantization": args[10].value.value,
"dtype": args[11].value,
}


@@ -105,6 +107,10 @@ def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": kwargs["num_key_value_heads"],
"head_dim": kwargs["head_dim"],
"kv_nbits": get_kv_storage_dtype(
kv_quant_scheme=PagedKVCacheQuantization(kwargs["kv_quantization"]).name.lower(),
model_dtype=kwargs["dtype"],
).bits,
}

def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
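The attached `kv_nbits` records the storage width actually chosen for the cache, keeping the C++ memory estimator above consistent with the quantization scheme. A quick sanity check of the mapping (illustrative; the expected value assumes `get_kv_storage_dtype` returns the model dtype when quantization is disabled):

```python
from mlc_llm.quantization import PagedKVCacheQuantization, get_kv_storage_dtype

# With KV quantization turned off, the storage dtype should simply be the model dtype.
storage_dtype = get_kv_storage_dtype(
    kv_quant_scheme=PagedKVCacheQuantization.KV_NO_QUANT.name.lower(),
    model_dtype="float16",
)
print(storage_dtype.bits)  # expected: 16
```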
@@ -156,6 +162,8 @@ def create_flashinfer_paged_kv_cache(
)
# filter by attention group size
or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8]
# MLC-LLM with FlashInfer and KV Quantization not supported yet
or kwargs["kv_quantization"] != PagedKVCacheQuantization.KV_NO_QUANT
):
return

4 changes: 4 additions & 0 deletions python/mlc_llm/model/baichuan/baichuan_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -41,6 +42,7 @@ class BaichuanConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
max_batch_size: int = 1
head_dim: int = 0
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -203,6 +205,7 @@ def __init__(self, config: BaichuanConfig):
self.vocab_size = config.vocab_size
self.rope_theta = 10000
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -291,6 +294,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/baichuan/baichuan_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a BaichuanLM-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = BaichuanForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
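Every model definition in this PR follows the same three-step pattern, shown for Baichuan above and repeated for ChatGLM3, Eagle, Gemma, GPT-2, GPTBigCode, and GPT-NeoX below: the quantization scheme's `kv_quantization` choice is copied onto the model config in `group_quant`, cached on the model at construction time, and forwarded to `create_paged_kv_cache`. A stripped-down sketch of that flow (the classes here are stand-ins, not the real implementations):

```python
import dataclasses
from mlc_llm.quantization import PagedKVCacheQuantization

@dataclasses.dataclass
class ToyConfig:
    # Each model config gains this field, defaulting to no KV quantization.
    kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT

class ToyModel:
    def __init__(self, config: ToyConfig):
        # Cached on the model so create_paged_kv_cache can forward it later.
        self.kv_quantization = config.kv_quantization

    def create_paged_kv_cache(self, **kwargs):
        # Forwarded into the KV cache alongside dtype and the RoPE parameters.
        return dict(kwargs, kv_quantization=self.kv_quantization)

def group_quant(model_config: ToyConfig, quantization) -> ToyModel:
    # The quantization scheme decides; group_quant copies the choice onto the config
    # before the model is built, exactly as each *_quantization.py change does.
    model_config.kv_quantization = quantization.kv_quantization
    return ToyModel(model_config)
```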
4 changes: 4 additions & 0 deletions python/mlc_llm/model/chatglm3/chatglm3_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -43,6 +44,7 @@ class GLMConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
head_dim: int = 0
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -279,6 +281,7 @@ def __init__(self, config: GLMConfig):
self.vocab_size = config.vocab_size
self.rope_theta = 10000
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -367,6 +370,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/chatglm3/chatglm3_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a ChatGLM-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = ChatGLMForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
2 changes: 2 additions & 0 deletions python/mlc_llm/model/eagle/eagle_model.py
@@ -90,6 +90,7 @@ def __init__(self, config: EagleConfig):
self.vocab_size = config.vocab_size
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def fuse_embed_hidden_states(self, input_embed: Tensor, hidden_states: Tensor):
@@ -177,6 +178,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/eagle/eagle_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a Eagle-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = EagleForCasualLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gemma/gemma_model.py
@@ -9,6 +9,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -36,6 +37,7 @@ class GemmaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
prefill_chunk_size: int = 0
tensor_parallel_shards: int = 1
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -228,6 +230,7 @@ def __init__(self, config: GemmaConfig):
self.vocab_size = config.vocab_size
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -316,6 +319,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/gemma/gemma_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a Gemma-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GemmaForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gpt2/gpt2_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -36,6 +37,7 @@ class GPT2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
head_dim: int = 0
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -223,6 +225,7 @@ def __init__(self, config: GPT2Config):
self.n_head = config.n_head
self.head_dim = config.head_dim
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -311,6 +314,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NONE,
rope_scale=-1,
rope_theta=-1,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/gpt2/gpt2_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a GPT-2-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GPT2LMHeadModel(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -35,6 +36,7 @@ class GPTBigCodeConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
prefill_chunk_size: int = 0
tensor_parallel_shards: int = 1
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -190,6 +192,7 @@ def __init__(self, config: GPTBigCodeConfig):
self.num_kv_heads = 1
self.head_dim = config.n_embd // config.n_head
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -278,6 +281,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NONE,
rope_scale=-1,
rope_theta=-1,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a GPTBigCode-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GPTBigCodeForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gpt_neox/gpt_neox_model.py
@@ -13,6 +13,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold
@@ -39,6 +40,7 @@ class GPTNeoXConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
ffn_out_dtype: str = "float32"
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -253,6 +255,7 @@ def __init__(self, config: GPTNeoXConfig):
self.vocab_size = config.vocab_size
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"
self.rotary_pct = config.rotary_pct

@@ -342,6 +345,7 @@ def create_paged_kv_cache(  # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
rotary_dim=int(self.head_dim * self.rotary_pct),
)
1 change: 1 addition & 0 deletions python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a GPTNeoX-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GPTNeoXForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})