Update sentence-transformers and allow setting trust_remote_code #379

Merged 3 commits on Aug 12, 2024
2 changes: 2 additions & 0 deletions caikit_nlp/config/config.yml
@@ -38,6 +38,8 @@ training_data_limit:

# Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32
embedding:
# Allow models with remote code.
trust_remote_code: false
# Number of times to retry on error. Most deployments should use 0 retries.
retries: 0
  # Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
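For context, a minimal sketch (not part of this diff) of what the new flag controls once it reaches sentence-transformers. Following the env-var convention noted in the comment above, the setting can presumably also be overridden with EMBEDDING_TRUST_REMOTE_CODE=true (an assumption based on the EMBEDDING_RETRIES example). The model id below is only an illustrative example of a checkpoint that ships custom modeling code:

from sentence_transformers import SentenceTransformer

# With trust_remote_code left at false, models whose config points at custom
# modeling code on the Hugging Face Hub are refused; enabling the flag opts in.
model = SentenceTransformer(
    "nomic-ai/nomic-embed-text-v1",  # illustrative; this checkpoint needs remote code
    trust_remote_code=True,
)
embeddings = model.encode(["hello world"])
print(embeddings.shape)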
42 changes: 40 additions & 2 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -22,6 +22,7 @@
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
TypeVar,
@@ -82,6 +83,8 @@
sentence_transformers = importlib.import_module("sentence_transformers")
# Third Party
from sentence_transformers import SentenceTransformer
from sentence_transformers.model_card import SentenceTransformerModelCardData
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import batch_to_device, cos_sim, dot_score
from sentence_transformers.util import (
normalize_embeddings as normalize, # avoid parameter shadowing
@@ -107,6 +110,7 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
val=embedding_cfg.get("implicit_truncation_errors", True)
)
DEVICE = embedding_cfg.get("device", "")
TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code")

RT = TypeVar("RT") # return type

@@ -183,7 +187,9 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
ipex = cls._get_ipex(IPEX)
device = cls._select_device(ipex, DEVICE)
model = SentenceTransformerWithTruncate(
model_name_or_path=artifacts_path, device=device
model_name_or_path=artifacts_path,
device=device,
trust_remote_code=TRUST_REMOTE_CODE,
)
model.eval() # required for IPEX at least
if device is not None:
@@ -719,7 +725,12 @@ def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
model_name_or_path: str
Model name (Hugging Face hub) or path to model to load.
"""
return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path))
return cls(
model=SentenceTransformer(
model_name_or_path=model_name_or_path,
trust_remote_code=TRUST_REMOTE_CODE,
)
)
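
As a usage sketch (import path assumed from the file layout in this diff), both entry points now forward the configured trust_remote_code value, so a checkpoint that ships custom code can be bootstrapped once the flag is enabled in the embedding config:

from caikit_nlp.modules.text_embedding import EmbeddingModule

# bootstrap() wraps a Hugging Face model id (or local path) in the module;
# the model id here is illustrative and does not itself require remote code.
module = EmbeddingModule.bootstrap("sentence-transformers/all-MiniLM-L6-v2")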

def save(self, model_path: str, *args, **kwargs):
"""Save model using config in model_path
@@ -875,21 +886,39 @@ def __init__(
model_name_or_path: Optional[str] = None,
modules: Optional[Iterable[nn.Module]] = None,
device: Optional[str] = None,
prompts: Optional[Dict[str, str]] = None,
default_prompt_name: Optional[str] = None,
similarity_fn_name: Optional[Union[str, SimilarityFunction]] = None,
cache_folder: Optional[str] = None,
trust_remote_code: bool = False,
revision: Optional[str] = None,
local_files_only: bool = False,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
model_card_data: Optional[SentenceTransformerModelCardData] = None,
):
super().__init__(
model_name_or_path,
modules,
device,
prompts,
default_prompt_name,
similarity_fn_name,
cache_folder,
trust_remote_code,
revision,
local_files_only,
token,
use_auth_token,
truncate_dim,
model_kwargs,
tokenizer_kwargs,
config_kwargs,
model_card_data,
)
self.tokenizers = {}
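
For readers wondering why the subclass signature grew: the arguments are forwarded to super().__init__ positionally, so the local parameter list has to track the sentence-transformers 3.0 constructor exactly. A minimal alternative sketch that forwards by keyword instead (illustrative only, not what this diff implements):

class SentenceTransformerWithTruncateSketch(SentenceTransformer):
    # Hypothetical variant: *args/**kwargs forwarding avoids re-declaring the
    # upstream parameter list, at the cost of a less explicit signature.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizers = {}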

@@ -1014,9 +1043,12 @@ def _get_tokenized(self, texts):
def encode(
self,
sentences: Union[str, List[str]],
prompt_name: Optional[str] = None,
prompt: Optional[str] = None,
batch_size: int = 32,
show_progress_bar: bool = None,
output_value: str = "sentence_embedding",
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
convert_to_numpy: bool = True,
convert_to_tensor: bool = False,
device: str = None,
@@ -1029,9 +1061,12 @@
Computes sentence embeddings

:param sentences: the sentences to embed
:param prompt_name: Ignored here. Added for compatibility with super API.
:param prompt: Ignored here. Added for compatibility with super API.
:param batch_size: the batch size used for the computation
:param show_progress_bar: Ignored here. Added for compatibility with super API.
:param output_value: Ignored here. Added for compatibility with super API.
:param precision: Ignored here. Added for compatibility with super API.
:param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list
of pytorch tensors.
:param convert_to_tensor: If true, you get one large tensor as return. Overwrites any
@@ -1057,8 +1092,11 @@

# These args are for API compatibility, but are currently ignored in our version of encode()
_ = (
prompt_name,
prompt,
show_progress_bar,
output_value,
precision,
normalize_embeddings,
)
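
A hedged usage sketch of the overridden encode(): the newly added prompt_name, prompt, and precision parameters are accepted only for signature compatibility with sentence-transformers 3.x and are ignored by this override, so passing them changes nothing here (variable names are placeholders):

# `model` stands for a SentenceTransformerWithTruncate instance.
vectors = model.encode(
    ["first sentence", "second sentence"],
    batch_size=32,
    convert_to_numpy=True,
    precision="float32",  # accepted but ignored by this override
    prompt_name=None,     # accepted but ignored by this override
)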

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -27,11 +27,11 @@ dependencies = [
"pandas>=1.5.0",
"scikit-learn>=1.1",
"scipy>=1.8.1",
"sentence-transformers>=2.3.1,<2.4.0",
"sentence-transformers>=3.0.0,<3.1.0",
"tokenizers>=0.13.3",
"torch>=2.3.1,<2.4.0",
"tqdm>=4.65.0",
"transformers>=4.32.0",
"transformers>=4.38.0,<4.44.0",
"peft==0.6.0",
]
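
The pins above move to sentence-transformers 3.0.x and cap transformers below 4.44. A small sketch (assuming the packages are installed) to confirm an environment matches the new ranges:

from importlib.metadata import version

# Print installed versions next to the ranges declared in pyproject.toml.
for pkg, spec in [
    ("sentence-transformers", ">=3.0.0,<3.1.0"),
    ("transformers", ">=4.38.0,<4.44.0"),
    ("torch", ">=2.3.1,<2.4.0"),
]:
    print(f"{pkg}: installed {version(pkg)}, required {spec}")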

2 changes: 2 additions & 0 deletions runtime_config.yaml
@@ -44,6 +44,8 @@ model_management:

# Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32
embedding:
# Allow models with remote code.
trust_remote_code: false
# Number of times to retry on error. Most deployments should use 0 retries.
retries: 0
  # Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
1 change: 1 addition & 0 deletions tox.ini
@@ -15,6 +15,7 @@ passenv =
LOG_FORMATTER
LOG_THREAD_ID
LOG_CHANNEL_WIDTH
PYTORCH_ENABLE_MPS_FALLBACK
commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests}

; Unclear: We probably want to test wheel packaging
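
PYTORCH_ENABLE_MPS_FALLBACK is a standard PyTorch environment variable: when set to 1, operators without a Metal (MPS) kernel fall back to CPU instead of raising an error, which helps when the test suite runs on Apple Silicon. A minimal sketch of the effect (assuming the tests can end up on the "mps" device):

import os

# Must be set before torch initializes the MPS backend.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.ones(2, 2, device=device)  # ops without MPS kernels now run on CPU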