Commit 6354eee: Optimize implementation with vLLM (#13)
1 parent: 00ff94e
Showing 11 changed files with 443 additions and 3 deletions.
.gitignore (new file):
@@ -0,0 +1,2 @@
.ipynb_checkpoints
__pycache__
requirements.txt:
@@ -1,2 +1,3 @@
-pytorch==2.3.1
-flash_attn==2.6.2
+torch==2.4.0
+flash_attn==2.6.3
+vllm==0.6.3.post1
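A quick, optional sanity check that an environment picked up the new pins (a sketch, not part of the commit; torch version strings often carry a CUDA suffix, hence the prefix check):

# Verify the upgraded pins are active in the current environment.
import torch
import vllm

assert torch.__version__.startswith("2.4"), torch.__version__
assert vllm.__version__ == "0.6.3.post1", vllm.__version__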
embed_text/__init__.py (new file):
@@ -0,0 +1,94 @@
""" | ||
embed_text package | ||
""" | ||
|
||
import gc | ||
|
||
import torch | ||
from datasets import Dataset | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
|
||
from .utils import register_model | ||
|
||
register_model() | ||
|
||
from vllm import LLM | ||
|
||
|
||
class Embedder: | ||
""" | ||
Instances of this class can embed sentences to embeddings. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize class object. | ||
""" | ||
self.model = None | ||
self.which_model = None | ||
|
||
def load(self, model_name: str, *arg, **kwargs): | ||
""" | ||
Loads class variables: model and tokenizer. | ||
:param model_name: HF model name (used for model and tokenizer) | ||
format: "hf_repo/hf_model" | ||
:type model_name: str | ||
:param self.model: LLM-style model that transforms tokens & attention | ||
mask to embeddings | ||
:type self.model: AutoModel | ||
:param self.tokenizer: Tokenizer mapping strings to key-values | ||
:type self.tokenizer: AutoTokenizer | ||
:param which_model: Variable storing the name of the loaded model | ||
:type which_model: str | ||
""" | ||
self.model = LLM(model=model_name, *arg, **kwargs) | ||
self.which_model = model_name | ||
|
||
def unload(self): | ||
""" | ||
Unloads class variables: model and tokenizer | ||
""" | ||
del self.model | ||
self.model = None | ||
gc.collect() | ||
torch.cuda.empty_cache() | ||
|
||
def get_embeddings(self, dataloader: DataLoader, model_name: str, cols: list): | ||
""" | ||
Function converts sentences to sentence embeddings. Designed to take | ||
dataloader format as input. Dataset of dataloader should contain | ||
sentences in string format. | ||
:param dataloader: Dataloader object of pytorch | ||
:type dataloader: DataLoader | ||
:param model_name: HF model name (used for model and tokenizer) | ||
format: "hf_repo/hf_model" | ||
:type model_name: str | ||
:param cols: list of column names to be embedded | ||
:type cols: list | ||
:return: Dataset with columns cols and embeddings of sentences | ||
:rtype: Dataset | ||
""" | ||
assert ( | ||
model_name == self.which_model | ||
), f"Model '{model_name}' is not preloaded. Loaded model is \ | ||
'{self.which_model}'. Load the correct model by calling the load \ | ||
function." | ||
|
||
emb_dict = {} | ||
|
||
for col in cols: | ||
col_emb = [] | ||
tqdm_dataloader = tqdm(dataloader) | ||
for batch in tqdm_dataloader: | ||
encoded = self.model.encode(batch[col]) | ||
col_emb.extend([x.outputs.embedding for x in encoded]) | ||
|
||
emb_dict[col] = col_emb | ||
# >>> num_cols x dataset_length x hidden_size | ||
emb_dataset = Dataset.from_dict(emb_dict) | ||
return emb_dataset |
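A hypothetical end-to-end usage sketch of Embedder; the checkpoint name and the toy dataset are illustrative assumptions, not part of this commit:

from datasets import Dataset
from torch.utils.data import DataLoader

from embed_text import Embedder

# Toy dataset with one text column; any embedding-capable HF checkpoint
# could stand in for the model name below.
ds = Dataset.from_dict({"text": ["hello world", "vLLM makes embedding fast"]})
loader = DataLoader(ds, batch_size=2)

embedder = Embedder()
embedder.load("intfloat/e5-mistral-7b-instruct")
emb = embedder.get_embeddings(loader, "intfloat/e5-mistral-7b-instruct", ["text"])
print(len(emb["text"][0]))  # dimensionality of one embedding
embedder.unload()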
Package __init__ re-exporting the embedding models (new file):
@@ -0,0 +1,2 @@
from .deepseek import DeepseekEmbeddingModel
from .llama import LlamaEmbeddingModel
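register_model() comes from a .utils module that is among the 11 changed files but is not shown in this view. A minimal sketch of what it is assumed to do, using vLLM's documented out-of-tree model registration; the architecture names here are assumptions:

from vllm import ModelRegistry


def register_model():
    # Imported locally so vLLM model code is only pulled in on demand.
    from .models import DeepseekEmbeddingModel, LlamaEmbeddingModel

    # Map checkpoint architecture names to the custom embedding wrappers.
    ModelRegistry.register_model("DeepseekV2ForCausalLM", DeepseekEmbeddingModel)
    ModelRegistry.register_model("LlamaForCausalLM", LlamaEmbeddingModel)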
deepseek.py (new file):
@@ -0,0 +1,153 @@
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class DeepseekEmbeddingModel(nn.Module, SupportsPP):
    """
    A model that uses DeepseekV2 with additional embedding functionalities.

    This class encapsulates the DeepseekV2Model and provides an interface
    for embedding operations and customized pooling functions.

    Attributes:
        model: An instance of DeepseekV2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = DeepseekV2Model(config, cache_config, quant_config, prefix="model")
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias")
                        and name not in params_dict
                        or name == "lm_head.weight"
                    ):
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
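Both embedding wrappers pool with Pooler(pooling_type=PoolingType.LAST, normalize=True). A minimal standalone sketch of those semantics (not vLLM's actual implementation): take each sequence's last-token hidden state and L2-normalize it:

import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: token states of all sequences concatenated along dim 0,
    # shape (sum(seq_lens), hidden_size); seq_lens: tokens per sequence.
    last_indices = torch.cumsum(seq_lens, dim=0) - 1
    pooled = hidden_states[last_indices]      # (num_seqs, hidden_size)
    return F.normalize(pooled, p=2, dim=-1)   # unit-norm embeddings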
llama.py (new file):
@@ -0,0 +1,106 @@
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class LlamaEmbeddingModel(nn.Module, SupportsPP):
    """
    A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()

        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        for name, loaded_weight in weights:
            # Checkpoint names carry a "model." prefix; the wrapped
            # LlamaModel's parameter names do not.
            name = name.replace("model.", "")
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias")
                    and name not in params_dict
                    or name == "lm_head.weight"
                ):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
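As an illustration of the stacked-parameter mapping above (a standalone sketch, not code from this commit), a checkpoint name like layers.0.self_attn.q_proj.weight is rewritten to the fused qkv_proj parameter, and the shard id tells the weight loader which slice of the fused tensor to fill:

name = "layers.0.self_attn.q_proj.weight"
for param_name, weight_name, shard_id in [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]:
    if weight_name in name:
        # -> "layers.0.self_attn.qkv_proj.weight", shard "q"
        print(name.replace(weight_name, param_name), shard_id)
        break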