diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..212d545
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.ipynb_checkpoints
+__pycache__
diff --git a/README.md b/README.md
index 11071b0..d7fb5bd 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ pip install git+https://github.com/stair-lab/embedder.git
 ```
 
 In your script, include the module:
 ```bash
-from embed_text_package.embed_text import Embedder
+from embed_text_package.embed_text_v2 import Embedder
 ```
 
diff --git a/pyproject.toml b/pyproject.toml
index f7a4edb..a6df1a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,7 @@ name = "embed_text_package"
 version = "0.0.1"
 authors = [
   { name="Luca Morlok", email="luca.morlok@stanford.edu" },
+  { name="Martin Nguyen", email="nqduc@cs.stanford.edu" },
 ]
 description = "A small package providing a function for embedding extraction"
 readme = "README.md"
diff --git a/requirements.txt b/requirements.txt
index a070996..121dbe9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
-pytorch==2.3.1
-flash_attn==2.6.2
\ No newline at end of file
+torch==2.4.0
+flash_attn==2.6.3
+vllm==0.6.3.post1
\ No newline at end of file
diff --git a/src/embed_text_package/embed_text.py b/src/embed_text_package/embed_text.py
index de0c9c8..91345dc 100644
--- a/src/embed_text_package/embed_text.py
+++ b/src/embed_text_package/embed_text.py
@@ -40,6 +40,8 @@ def load(self, model_name: str):
         :type which_model: str
         """
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
         self.model = AutoModel.from_pretrained(model_name, device_map="auto")
         self.which_model = model_name
 
diff --git a/src/embed_text_package/embed_text_v2.py b/src/embed_text_package/embed_text_v2.py
new file mode 100644
index 0000000..6f1c5cc
--- /dev/null
+++ b/src/embed_text_package/embed_text_v2.py
@@ -0,0 +1,94 @@
+"""
+embed_text package
+"""
+
+import gc
+
+import torch
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from .utils import register_model
+
+register_model()
+
+from vllm import LLM
+
+
+class Embedder:
+    """
+    Instances of this class can embed sentences to embeddings.
+    """
+
+    def __init__(self):
+        """
+        Initialize class object.
+        """
+        self.model = None
+        self.which_model = None
+
+    def load(self, model_name: str, *args, **kwargs):
+        """
+        Loads class variables: model and which_model.
+
+        :param model_name: HF model name (or local path) of the embedding
+            model, format: "hf_repo/hf_model"
+        :type model_name: str
+        :param args: additional positional arguments forwarded to vllm.LLM
+        :param kwargs: additional keyword arguments forwarded to vllm.LLM
+        :param self.model: vLLM engine that maps prompts / token sequences
+            to embeddings
+        :type self.model: vllm.LLM
+        :param which_model: Variable storing the name of the loaded model
+        :type which_model: str
+        """
+        self.model = LLM(model=model_name, *args, **kwargs)
+        self.which_model = model_name
+
+    def unload(self):
+        """
+        Unloads the model and frees GPU memory.
+        """
+        del self.model
+        self.model = None
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_embeddings(self, dataloader: DataLoader, model_name: str, cols: list):
+        """
+        Function converts sentences to sentence embeddings. Designed to take
+        dataloader format as input. Dataset of dataloader should contain
+        sentences in string format.
+
+        :param dataloader: Dataloader object of pytorch
+        :type dataloader: DataLoader
+        :param model_name: HF model name; must match the name of the
+            currently loaded model (consistency check)
+        :type model_name: str
+        :param cols: list of column names to be embedded
+        :type cols: list
+
+
+        :return: Dataset with columns cols and embeddings of sentences
+        :rtype: Dataset
+        """
+        assert (
+            model_name == self.which_model
+        ), f"Model '{model_name}' is not preloaded. Loaded model is \
+            '{self.which_model}'. Load the correct model by calling the load \
+            function."
+
+        emb_dict = {}
+
+        for col in cols:
+            col_emb = []
+            tqdm_dataloader = tqdm(dataloader)
+            for batch in tqdm_dataloader:
+                encoded = self.model.encode(batch[col])
+                col_emb.extend([x.outputs.embedding for x in encoded])
+
+            emb_dict[col] = col_emb
+        # >>> num_cols x dataset_length x hidden_size
+        emb_dataset = Dataset.from_dict(emb_dict)
+        return emb_dataset
diff --git a/src/embed_text_package/models/__init__.py b/src/embed_text_package/models/__init__.py
new file mode 100644
index 0000000..33e370f
--- /dev/null
+++ b/src/embed_text_package/models/__init__.py
@@ -0,0 +1,2 @@
+from .deepseek import DeepseekEmbeddingModel
+from .llama import LlamaEmbeddingModel
diff --git a/src/embed_text_package/models/deepseek.py b/src/embed_text_package/models/deepseek.py
new file mode 100644
index 0000000..2b581a2
--- /dev/null
+++ b/src/embed_text_package/models/deepseek.py
@@ -0,0 +1,153 @@
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.utils import is_pp_missing_parameter
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+
+class DeepseekEmbeddingModel(nn.Module, SupportsPP):
+    """
+    A model that uses DeepseekV2 with additional embedding functionalities.
+
+    This class encapsulates the DeepseekV2Model and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of DeepseekV2Model used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        self.model = DeepseekV2Model(config, cache_config, quant_config, prefix="model")
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            intermediate_tensors,
+        )
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if ("mlp.experts." in name) and name not in params_dict:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if (
+                        name.endswith(".bias")
+                        and name not in params_dict
+                        or name == "lm_head.weight"
+                    ):
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
diff --git a/src/embed_text_package/models/llama.py b/src/embed_text_package/models/llama.py
new file mode 100644
index 0000000..3d3d866
--- /dev/null
+++ b/src/embed_text_package/models/llama.py
@@ -0,0 +1,106 @@
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.llama import LlamaModel
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+
+class LlamaEmbeddingModel(nn.Module, SupportsPP):
+    """
+    A model that uses Llama with additional embedding functionalities.
+
+    This class encapsulates the LlamaModel and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of LlamaModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self.model = LlamaModel(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(
+            input_ids,
+            positions,
+            kv_caches,
+            attn_metadata,
+            intermediate_tensors,
+            inputs_embeds,
+        )
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+
+        for name, loaded_weight in weights:
+            name = name.replace("model.", "")
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if (
+                    name.endswith(".bias")
+                    and name not in params_dict
+                    or name == "lm_head.weight"
+                ):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/src/embed_text_package/utils.py b/src/embed_text_package/utils.py
new file mode 100644
index 0000000..2edde6d
--- /dev/null
+++ b/src/embed_text_package/utils.py
@@ -0,0 +1,14 @@
+from vllm import ModelRegistry
+from vllm.model_executor.models.registry import _EMBEDDING_MODELS
+
+from .models import DeepseekEmbeddingModel, LlamaEmbeddingModel
+
+global _EMBEDDING_MODELS
+
+
+def register_model():
+    ModelRegistry.register_model("LlamaEmbModel", LlamaEmbeddingModel)
+    _EMBEDDING_MODELS["LlamaEmbModel"] = LlamaEmbeddingModel
+
+    ModelRegistry.register_model("DeepseekEmbModel", DeepseekEmbeddingModel)
+    _EMBEDDING_MODELS["DeepseekEmbModel"] = DeepseekEmbeddingModel
diff --git a/tests/test_embeds_v2.py b/tests/test_embeds_v2.py
new file mode 100644
index 0000000..759f779
--- /dev/null
+++ b/tests/test_embeds_v2.py
@@ -0,0 +1,65 @@
+import itertools
+import pickle
+
+from datasets import Dataset, load_dataset
+from embed_text_package.embed_text_v2 import Embedder
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from vllm import LLM
+
+ds = load_dataset("stair-lab/questioin_difficulty", split="train")
+model = "/lfs/local/0/nqduc/Llama-3.1-8B-embedding"
+
+if __name__ == "__main__":
+    embedder = Embedder()
+    embedder.load(model, enable_chunked_prefill=False, enforce_eager=True)
+    list_text = []
+    list_score = []
+
+    # ds = Dataset.from_dict(ds[:10])
+
+    for sample in ds:
+        answer = [
+            "Correct answer: " + sample["option_correct_ans"],
+            "Wrong answer 1: " + sample["option_distractor1"],
+            "Wrong answer 2: " + sample["option_distractor2"],
+            "Wrong answer 3: " + sample["option_distractor3"],
+        ]
+        answers = list(itertools.permutations(answer))
+        # answers = answer[:2]
+        for answer in answers:
+            text = (
+                sample["Passage"]
+                + "\n"
+                + sample["QuestionText"]
+                + "\n"
+                + answer[0]
+                + "\n"
+                + answer[1]
+                + "\n"
+                + answer[2]
+                + "\n"
+                + answer[3]
+            )
+            list_text.append(text)
+            list_score.append(sample["pVal"])
+
+    combine_ds = Dataset.from_dict({"text": list_text})
+
+    ds_emb = (
+        embedder.get_embeddings(
+            DataLoader(combine_ds, batch_size=8),
+            embedder.which_model,
+            ["text"],
+        )
+        .data["text"]
+        .to_pylist()
+    )
+
+    new_ds = Dataset.from_dict({"embedding": ds_emb, "score": list_score})
+
+    with open("embedding.pkl", "wb") as f:
+        pickle.dump(new_ds, f)
+
+    new_ds.push_to_hub("stair-lab/question_difficulty_embedded-full")
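For reviewers, a minimal usage sketch of the vLLM-backed `Embedder` added in `embed_text_v2.py`, distilled from `tests/test_embeds_v2.py`. The model path is a placeholder; it assumes a checkpoint that `vllm.LLM` can load as an embedding model (for example one dispatching to the `LlamaEmbModel`/`DeepseekEmbModel` architectures registered in `utils.py`):

```python
from datasets import Dataset
from torch.utils.data import DataLoader

from embed_text_package.embed_text_v2 import Embedder

# Placeholder path: any embedding checkpoint loadable by vllm.LLM works here.
MODEL_NAME = "path/to/embedding-checkpoint"

embedder = Embedder()
# Extra keyword arguments (e.g. enforce_eager=True) are forwarded to vllm.LLM.
embedder.load(MODEL_NAME, enforce_eager=True)

# get_embeddings expects a DataLoader over a datasets.Dataset with string columns.
ds = Dataset.from_dict({"text": ["first sentence", "second sentence"]})
emb_ds = embedder.get_embeddings(
    DataLoader(ds, batch_size=2), embedder.which_model, ["text"]
)
print(len(emb_ds["text"]))  # one embedding vector per input row

embedder.unload()  # release the model and free GPU memory
```

Note that `register_model()` runs when `embed_text_v2` is imported, so the custom architectures are registered with vLLM before `vllm.LLM` is constructed.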
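One assumption worth stating: vLLM picks the model implementation from the `architectures` field of the checkpoint's HF config, so a checkpoint intended for these wrappers (such as the local Llama-3.1-8B-embedding path used in the test) presumably lists `LlamaEmbModel` or `DeepseekEmbModel` there. A hedged sketch of preparing such a checkpoint, not part of this PR:

```python
from transformers import AutoConfig

# Hypothetical: point an existing Llama checkpoint at the registered
# LlamaEmbModel architecture so vLLM dispatches to LlamaEmbeddingModel.
ckpt = "path/to/llama-checkpoint"
config = AutoConfig.from_pretrained(ckpt)
config.architectures = ["LlamaEmbModel"]
config.save_pretrained(ckpt)  # rewrites config.json in place
```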