
Commit

kv quantize
davidpissarra committed Jul 16, 2024
1 parent baeb195 commit d2acb93
Showing 59 changed files with 774 additions and 175 deletions.
4 changes: 3 additions & 1 deletion cpp/metadata/model.cc
@@ -46,6 +46,7 @@ ModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON(
kv_cache_metadata.head_dim = json::Lookup<int64_t>(json, "head_dim");
kv_cache_metadata.num_attention_heads = json::Lookup<int64_t>(json, "num_attention_heads");
kv_cache_metadata.num_key_value_heads = json::Lookup<int64_t>(json, "num_key_value_heads");
kv_cache_metadata.kv_nbits = json::Lookup<int64_t>(json, "kv_nbits");
return kv_cache_metadata;
}

@@ -73,7 +74,8 @@ ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata,
result.kv_cache_metadata = {/*num_hidden_layers=*/0,
/*head_dim=*/0,
/*num_attention_heads=*/0,
/*num_key_value_heads=*/0};
/*num_key_value_heads=*/0,
/*kv_nbits=*/0};
}
{
std::vector<ModelMetadata::Param>& params = result.params;
1 change: 1 addition & 0 deletions cpp/metadata/model.h
@@ -67,6 +67,7 @@ struct ModelMetadata {
int64_t num_attention_heads;
int64_t num_key_value_heads;
int64_t head_dim;
int64_t kv_nbits;
static KVCacheMetadata FromJSON(const picojson::object& json);
};

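For reference, the KV-cache block of the model metadata JSON that `KVCacheMetadata::FromJSON` now consumes would look roughly like the following, sketched here as a Python dict. The field values are illustrative only; the new `kv_nbits` entry records the bit-width of the KV-cache storage dtype.

```python
# Hypothetical metadata fragment; only the key names mirror the Lookup calls above.
kv_cache_metadata = {
    "num_hidden_layers": 32,
    "head_dim": 128,
    "num_attention_heads": 32,
    "num_key_value_heads": 32,
    "kv_nbits": 4,  # new in this commit; e.g. 4 for an int4-quantized KV cache
}
```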
4 changes: 3 additions & 1 deletion cpp/serve/config.cc
@@ -567,8 +567,10 @@ Result<MemUsageEstimationResult> EstimateMemoryUsageOnMode(
int64_t head_dim = model_metadata[i].kv_cache_metadata.head_dim;
int64_t num_qo_heads = model_metadata[i].kv_cache_metadata.num_attention_heads;
int64_t num_kv_heads = model_metadata[i].kv_cache_metadata.num_key_value_heads;
int64_t kv_nbits = model_metadata[i].kv_cache_metadata.kv_nbits;
int64_t hidden_size = head_dim * num_qo_heads;
kv_bytes_per_token += head_dim * num_kv_heads * num_layers * 4 + 1.25;
double kv_nbytes = kv_nbits / 8.0f;
kv_bytes_per_token += head_dim * num_kv_heads * num_layers * kv_nbytes * 2 + 1.25;
kv_aux_workspace_bytes +=
(max_num_sequence + 1) * 88 + prefill_chunk_size * (num_qo_heads + 1) * 8 +
prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + 48 * 1024 * 1024;
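To sanity-check the updated estimate, here is a small Python sketch with illustrative (roughly Llama-7B-sized) shapes; the numbers are examples, not taken from this commit. Note that at kv_nbits = 16 the new expression reduces to the previous hard-coded `* 4` (two bytes each for the K and V caches).

```python
# Illustrative shapes; not taken from the commit.
head_dim = 128
num_kv_heads = 32
num_layers = 32

def kv_bytes_per_token(kv_nbits: int) -> float:
    kv_nbytes = kv_nbits / 8.0  # bytes per stored KV element
    # x2 for the K and V caches; +1.25 matches the constant in EstimateMemoryUsageOnMode
    return head_dim * num_kv_heads * num_layers * kv_nbytes * 2 + 1.25

print(kv_bytes_per_token(16))  # 524289.25 bytes/token, same as the old "* 4" estimate
print(kv_bytes_per_token(4))   # 131073.25 bytes/token, ~4x smaller with a 4-bit KV cache
```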
1 change: 1 addition & 0 deletions python/mlc_llm/bench/metrics.py
@@ -1,4 +1,5 @@
""" MLC LLM bench Metrics"""

import json
from typing import Any, Callable, Dict, List, Optional, Union

1 change: 1 addition & 0 deletions python/mlc_llm/bench/replay.py
@@ -1,4 +1,5 @@
"""MLC LLM bench replay request"""

import asyncio
import json
from datetime import datetime
16 changes: 12 additions & 4 deletions python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py
@@ -6,6 +6,7 @@
from tvm import IRModule, relax

from mlc_llm.nn import RopeMode, kv_cache
from mlc_llm.quantization import PagedKVCacheQuantization, get_kv_storage_dtype


def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
@@ -20,13 +21,13 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
assert isinstance(args[0], relax.ExternFunc)
assert args[0].global_symbol == "mlc.create_paged_kv_cache_generic"

assert len(args) == 11
assert len(args) == 12
assert isinstance(args[1], relax.ShapeExpr)
assert len(args[1].values) == 5
for i in range(2, 10):
for i in range(2, 11):
assert isinstance(args[i], relax.PrimValue)
assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))
assert isinstance(args[10], relax.DataTypeImm)
assert isinstance(args[11], relax.DataTypeImm)

return {
"max_batch_size": args[1].values[0],
@@ -42,7 +43,8 @@ def extract_creation_args(func: relax.Function) -> Dict[str, Any]:
"rope_scale": args[7].value.value,
"rope_theta": args[8].value.value,
"rotary_dim": args[9].value.value,
"dtype": args[10].value,
"kv_quantization": args[10].value.value,
"dtype": args[11].value,
}


@@ -105,6 +107,10 @@ def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):
"num_attention_heads": kwargs["num_attention_heads"],
"num_key_value_heads": kwargs["num_key_value_heads"],
"head_dim": kwargs["head_dim"],
"kv_nbits": get_kv_storage_dtype(
kv_quant_scheme=PagedKVCacheQuantization(kwargs["kv_quantization"]).name.lower(),
model_dtype=kwargs["dtype"],
).bits,
}

def create_tir_paged_kv_cache(self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]) -> None:
@@ -156,6 +162,8 @@ def create_flashinfer_paged_kv_cache(
)
# filter by attention group size
or kwargs["num_attention_heads"] // kwargs["num_key_value_heads"] not in [1, 4, 8]
# MLC-LLM with FlashInfer and KV Quantization not supported yet
or kwargs["kv_quantization"] != PagedKVCacheQuantization.KV_NO_QUANT
):
return

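In this pass, the generic cache-creation call now carries a kv_quantization PrimValue at index 10 (shifting the dtype imm to index 11), and attach_kv_cache_metadata records the resulting storage bit-width as kv_nbits. The helper below mirrors that derivation; the example result in the comment is an assumption (only KV_NO_QUANT and the helper names appear in this diff).

```python
from mlc_llm.quantization import PagedKVCacheQuantization, get_kv_storage_dtype

def kv_nbits_for(kv_quantization: int, model_dtype: str) -> int:
    """Mirror of the metadata attachment above: map the enum value carried in the
    cache-creation call to the bit-width of the KV-cache storage dtype."""
    scheme = PagedKVCacheQuantization(kv_quantization).name.lower()  # e.g. "kv_no_quant"
    return get_kv_storage_dtype(kv_quant_scheme=scheme, model_dtype=model_dtype).bits

# Presumed behaviour (not verified against the rest of the patch):
#   kv_nbits_for(PagedKVCacheQuantization.KV_NO_QUANT, "float16")  -> 16
```

Note also the guard added above: the FlashInfer cache path is skipped whenever a quantized KV cache is requested, presumably falling back to the TIR-based cache.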
4 changes: 4 additions & 0 deletions python/mlc_llm/model/baichuan/baichuan_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -41,6 +42,7 @@ class BaichuanConfig(ConfigBase): # pylint: disable=too-many-instance-attribute
tensor_parallel_shards: int = 1
max_batch_size: int = 1
head_dim: int = 0
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -203,6 +205,7 @@ def __init__(self, config: BaichuanConfig):
self.vocab_size = config.vocab_size
self.rope_theta = 10000
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -291,6 +294,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

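The Baichuan changes above are the template: every model file in this commit gains the same wiring — a kv_quantization config field defaulting to KV_NO_QUANT, copied onto the model in __init__ and forwarded by create_paged_kv_cache next to dtype. As a small sketch of how that value travels (it assumes the enum members are integer-valued, which the comparison against KV_NO_QUANT in dispatch_kv_cache_creation.py suggests):

```python
from mlc_llm.quantization import PagedKVCacheQuantization

mode = PagedKVCacheQuantization.KV_NO_QUANT
as_int = mode.value                           # forwarded as a plain integer (a relax PrimValue)
recovered = PagedKVCacheQuantization(as_int)  # reconstructed on the compiler-pass side
print(recovered.name.lower())                 # "kv_no_quant", the scheme name fed to get_kv_storage_dtype
```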
1 change: 1 addition & 0 deletions python/mlc_llm/model/baichuan/baichuan_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a BaichuanLM-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = BaichuanForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
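Each *_quantization.py in turn gains the mirror-image line: group_quant copies the KV-cache setting from the quantization scheme onto the model config before the model is constructed, so the scheme is the single place that decides whether the KV cache is quantized. A minimal sketch of that flow (just the propagation, omitting the quantization mapping):

```python
from mlc_llm.model.baichuan.baichuan_model import BaichuanConfig, BaichuanForCausalLM
from mlc_llm.quantization import GroupQuantize

def build_model(model_config: BaichuanConfig, quantization: GroupQuantize) -> BaichuanForCausalLM:
    # The scheme, not the model config, owns the KV-cache decision:
    model_config.kv_quantization = quantization.kv_quantization
    model = BaichuanForCausalLM(model_config)  # __init__ picks the field up
    model.to(quantization.model_dtype)
    return model
```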
4 changes: 4 additions & 0 deletions python/mlc_llm/model/chatglm3/chatglm3_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -43,6 +44,7 @@ class GLMConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
head_dim: int = 0
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -279,6 +281,7 @@ def __init__(self, config: GLMConfig):
self.vocab_size = config.vocab_size
self.rope_theta = 10000
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -367,6 +370,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/chatglm3/chatglm3_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a ChatGLM-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = ChatGLMForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
2 changes: 2 additions & 0 deletions python/mlc_llm/model/eagle/eagle_model.py
@@ -90,6 +90,7 @@ def __init__(self, config: EagleConfig):
self.vocab_size = config.vocab_size
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def fuse_embed_hidden_states(self, input_embed: Tensor, hidden_states: Tensor):
@@ -177,6 +178,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/eagle/eagle_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a Eagle-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = EagleForCasualLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gemma/gemma_model.py
@@ -9,6 +9,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -36,6 +37,7 @@ class GemmaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
prefill_chunk_size: int = 0
tensor_parallel_shards: int = 1
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -228,6 +230,7 @@ def __init__(self, config: GemmaConfig):
self.vocab_size = config.vocab_size
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -316,6 +319,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/gemma/gemma_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a Gemma-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GemmaForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gpt2/gpt2_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -36,6 +37,7 @@ class GPT2Config(ConfigBase): # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
head_dim: int = 0
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -223,6 +225,7 @@ def __init__(self, config: GPT2Config):
self.n_head = config.n_head
self.head_dim = config.head_dim
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -311,6 +314,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NONE,
rope_scale=-1,
rope_theta=-1,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/gpt2/gpt2_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a GPT-2-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GPT2LMHeadModel(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py
@@ -12,6 +12,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import logging
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
@@ -35,6 +36,7 @@ class GPTBigCodeConfig(ConfigBase): # pylint: disable=too-many-instance-attribu
prefill_chunk_size: int = 0
tensor_parallel_shards: int = 1
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -190,6 +192,7 @@ def __init__(self, config: GPTBigCodeConfig):
self.num_kv_heads = 1
self.head_dim = config.n_embd // config.n_head
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"

def to(self, dtype: Optional[str] = None):
@@ -278,6 +281,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NONE,
rope_scale=-1,
rope_theta=-1,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
)

1 change: 1 addition & 0 deletions python/mlc_llm/model/gpt_bigcode/gpt_bigcode_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a GPTBigCode-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GPTBigCodeForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})
4 changes: 4 additions & 0 deletions python/mlc_llm/model/gpt_neox/gpt_neox_model.py
@@ -13,6 +13,7 @@

from mlc_llm import op as op_ext
from mlc_llm.nn import PagedKVCache, RopeMode
from mlc_llm.quantization import PagedKVCacheQuantization
from mlc_llm.support import tensor_parallel as tp
from mlc_llm.support.config import ConfigBase
from mlc_llm.support.style import bold
@@ -39,6 +40,7 @@ class GPTNeoXConfig(ConfigBase): # pylint: disable=too-many-instance-attributes
tensor_parallel_shards: int = 1
ffn_out_dtype: str = "float32"
max_batch_size: int = 1
kv_quantization: PagedKVCacheQuantization = PagedKVCacheQuantization.KV_NO_QUANT
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

def __post_init__(self):
@@ -253,6 +255,7 @@ def __init__(self, config: GPTNeoXConfig):
self.vocab_size = config.vocab_size
self.rope_theta = config.position_embedding_base
self.tensor_parallel_shards = config.tensor_parallel_shards
self.kv_quantization = config.kv_quantization
self.dtype = "float32"
self.rotary_pct = config.rotary_pct

@@ -342,6 +345,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments
rope_mode=RopeMode.NORMAL,
rope_scale=1,
rope_theta=self.rope_theta,
kv_quantization=self.kv_quantization,
dtype=self.dtype,
rotary_dim=int(self.head_dim * self.rotary_pct),
)
1 change: 1 addition & 0 deletions python/mlc_llm/model/gpt_neox/gpt_neox_quantization.py
@@ -16,6 +16,7 @@ def group_quant(
quantization: GroupQuantize,
) -> Tuple[nn.Module, QuantizeMapping]:
"""Quantize a GPTNeoX-architecture model using group quantization."""
model_config.kv_quantization = quantization.kv_quantization
model: nn.Module = GPTNeoXForCausalLM(model_config)
model.to(quantization.model_dtype)
quant_map = QuantizeMapping({}, {})