Optimize implementation with vLLM (#13)
martinakaduc authored Nov 9, 2024
1 parent 00ff94e commit 6354eee
Showing 11 changed files with 443 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
.ipynb_checkpoints
__pycache__
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ pip install git+https://github.com/stair-lab/embedder.git
```
In your script, include the module:
```bash
-from embed_text_package.embed_text import Embedder
+from embed_text_package.embed_text_v2 import Embedder
```


1 change: 1 addition & 0 deletions pyproject.toml
@@ -7,6 +7,7 @@ name = "embed_text_package"
version = "0.0.1"
authors = [
{ name="Luca Morlok", email="luca.morlok@stanford.edu" },
{ name="Martin Nguyen", email="nqduc@cs.stanford.edu" },
]
description = "A small package providing a function for embedding extraction"
readme = "README.md"
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,2 +1,3 @@
-pytorch==2.3.1
-flash_attn==2.6.2
+torch==2.4.0
+flash_attn==2.6.3
+vllm==0.6.3.post1
2 changes: 2 additions & 0 deletions src/embed_text_package/embed_text.py
@@ -40,6 +40,8 @@ def load(self, model_name: str):
        :type which_model: str
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModel.from_pretrained(model_name, device_map="auto")
        self.which_model = model_name

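For context: many decoder-only checkpoints (Llama-family models in particular) ship without a pad token, so batched tokenization with padding=True fails unless one is set; reusing eos_token is the usual workaround. A minimal illustration, separate from this commit (the model name is only an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # example model
if not tokenizer.pad_token:
    # Reuse the EOS token for padding so batched encoding works.
    tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["a short sentence", "a somewhat longer example sentence"],
    padding=True,
    return_tensors="pt",
)
```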
94 changes: 94 additions & 0 deletions src/embed_text_package/embed_text_v2.py
@@ -0,0 +1,94 @@
"""
embed_text package
"""

import gc

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

from .utils import register_model

register_model()

from vllm import LLM


class Embedder:
"""
Instances of this class can embed sentences to embeddings.
"""

def __init__(self):
"""
Initialize class object.
"""
self.model = None
self.which_model = None

def load(self, model_name: str, *arg, **kwargs):
"""
Loads class variables: model and tokenizer.
:param model_name: HF model name (used for model and tokenizer)
format: "hf_repo/hf_model"
:type model_name: str
:param self.model: LLM-style model that transforms tokens & attention
mask to embeddings
:type self.model: AutoModel
:param self.tokenizer: Tokenizer mapping strings to key-values
:type self.tokenizer: AutoTokenizer
:param which_model: Variable storing the name of the loaded model
:type which_model: str
"""
self.model = LLM(model=model_name, *arg, **kwargs)
self.which_model = model_name

def unload(self):
"""
Unloads class variables: model and tokenizer
"""
del self.model
self.model = None
gc.collect()
torch.cuda.empty_cache()

def get_embeddings(self, dataloader: DataLoader, model_name: str, cols: list):
"""
Function converts sentences to sentence embeddings. Designed to take
dataloader format as input. Dataset of dataloader should contain
sentences in string format.
:param dataloader: Dataloader object of pytorch
:type dataloader: DataLoader
:param model_name: HF model name (used for model and tokenizer)
format: "hf_repo/hf_model"
:type model_name: str
:param cols: list of column names to be embedded
:type cols: list
:return: Dataset with columns cols and embeddings of sentences
:rtype: Dataset
"""
assert (
model_name == self.which_model
), f"Model '{model_name}' is not preloaded. Loaded model is \
'{self.which_model}'. Load the correct model by calling the load \
function."

emb_dict = {}

for col in cols:
col_emb = []
tqdm_dataloader = tqdm(dataloader)
for batch in tqdm_dataloader:
encoded = self.model.encode(batch[col])
col_emb.extend([x.outputs.embedding for x in encoded])

emb_dict[col] = col_emb
# >>> num_cols x dataset_length x hidden_size
emb_dataset = Dataset.from_dict(emb_dict)
return emb_dataset
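A minimal usage sketch of the new vLLM-backed Embedder (not part of this commit; the dataset, column name, and model name are illustrative assumptions):

```python
from datasets import Dataset
from torch.utils.data import DataLoader

from embed_text_package.embed_text_v2 import Embedder

# Toy dataset; any HF Dataset whose columns hold strings works.
ds = Dataset.from_dict({"text": ["hello world", "vLLM speeds up embedding extraction"]})
loader = DataLoader(ds, batch_size=2)

embedder = Embedder()
embedder.load("meta-llama/Meta-Llama-3-8B")  # example model name
emb_ds = embedder.get_embeddings(loader, "meta-llama/Meta-Llama-3-8B", ["text"])
print(len(emb_ds["text"][0]))  # dimensionality of one embedding vector
embedder.unload()
```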
2 changes: 2 additions & 0 deletions src/embed_text_package/models/__init__.py
@@ -0,0 +1,2 @@
from .deepseek import DeepseekEmbeddingModel
from .llama import LlamaEmbeddingModel
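The register_model helper imported by embed_text_v2.py lives in src/embed_text_package/utils.py, which is not shown in the rendered diff. A plausible sketch of what such a helper could look like, using vLLM's ModelRegistry; this is an assumption, not the actual file contents:

```python
# Hypothetical sketch of utils.register_model (the real file is not shown here).
from vllm import ModelRegistry

from .models import DeepseekEmbeddingModel, LlamaEmbeddingModel


def register_model():
    # Make the custom embedding wrappers visible to vLLM's model loader.
    ModelRegistry.register_model("LlamaEmbeddingModel", LlamaEmbeddingModel)
    ModelRegistry.register_model("DeepseekEmbeddingModel", DeepseekEmbeddingModel)
```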
153 changes: 153 additions & 0 deletions src/embed_text_package/models/deepseek.py
@@ -0,0 +1,153 @@
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class DeepseekEmbeddingModel(nn.Module, SupportsPP):
    """
    A model that uses DeepseekV2 with additional embedding functionalities.

    This class encapsulates the DeepseekV2Model and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of DeepseekV2Model used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = DeepseekV2Model(config, cache_config, quant_config, prefix="model")
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias")
                        and name not in params_dict
                        or name == "lm_head.weight"
                    ):
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
106 changes: 106 additions & 0 deletions src/embed_text_package/models/llama.py
@@ -0,0 +1,106 @@
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput


class LlamaEmbeddingModel(nn.Module, SupportsPP):
    """
    A model that uses Llama with additional embedding functionalities.

    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.

    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        super().__init__()

        self.model = LlamaModel(**kwargs)
        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds,
        )

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        return self._pooler(hidden_states, pooling_metadata)

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        for name, loaded_weight in weights:
            name = name.replace("model.", "")
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias")
                    and name not in params_dict
                    or name == "lm_head.weight"
                ):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
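Both wrappers pool with Pooler(pooling_type=PoolingType.LAST, normalize=True), i.e. each sentence embedding is the hidden state of its final token, L2-normalized. A rough sketch of that operation (illustrative only, not vLLM internals):

```python
import torch
import torch.nn.functional as F

hidden_states = torch.randn(4, 128, 4096)    # (batch, seq_len, hidden_size)
seq_lens = torch.tensor([128, 97, 64, 120])  # true lengths before padding

# Take the hidden state of the last real token of each sequence...
last_token = hidden_states[torch.arange(4), seq_lens - 1]  # (batch, hidden_size)
# ...and normalize it to unit length.
embeddings = F.normalize(last_token, p=2, dim=-1)
```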
