Skip to content

Commit

Permalink
[fix] Fix loading pre-exported OV/ONNX model if export=False (#3036)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen authored Nov 6, 2024
1 parent 7d99ca9 commit cb81136
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _backend_should_export(
"""

export = model_args.pop("export", None)
if export is not None:
if export:
return export, model_args

file_name = model_args.get("file_name", target_file_name)
Expand Down Expand Up @@ -283,17 +283,19 @@ def _backend_should_export(
# First check if the expected file exists in the root of the model directory
# If it doesn't, check if it exists in the backend subfolder.
# If it does, set the subfolder to include the backend
export = primary_full_path not in model_file_names
if export and "subfolder" not in model_args:
export = secondary_full_path not in model_file_names
if not export:
model_found = primary_full_path in model_file_names
if not model_found and "subfolder" not in model_args:
model_found = secondary_full_path in model_file_names
if model_found:
if len(model_file_names) > 1 and "file_name" not in model_args:
logger.warning(
f"Multiple {backend_name} files found in {load_path.as_posix()!r}: {model_file_names}, defaulting to {secondary_full_path!r}. "
f'Please specify the desired file name via `model_kwargs={{"file_name": "<file_name>"}}`.'
)
model_args["subfolder"] = self.backend
model_args["file_name"] = file_name
if export is None:
export = not model_found

# If the file_name contains subfolders, set it as the subfolder instead
file_name_parts = Path(file_name).parts
Expand Down
35 changes: 35 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,38 @@ def test_openvino_backend() -> None:
), "OpenVINO saved model output differs from in-memory converted model"
del local_openvino_model
gc.collect()


def test_export_false_subfolder() -> None:
model_id = "sentence-transformers-testing/stsb-bert-tiny-openvino"

def from_pretrained_decorator(method):
def decorator(*args, **kwargs):
assert not kwargs["export"]
assert kwargs["subfolder"] == "openvino"
assert kwargs["file_name"] == "openvino_model.xml"
return method(*args, **kwargs)

return decorator

OVModelForFeatureExtraction.from_pretrained = from_pretrained_decorator(
OVModelForFeatureExtraction.from_pretrained
)
SentenceTransformer(model_id, backend="openvino", model_kwargs={"export": False})


def test_export_set_nested_filename() -> None:
model_id = "sentence-transformers-testing/stsb-bert-tiny-openvino"

def from_pretrained_decorator(method):
def decorator(*args, **kwargs):
assert kwargs["subfolder"] == "openvino"
assert kwargs["file_name"] == "openvino_model.xml"
return method(*args, **kwargs)

return decorator

OVModelForFeatureExtraction.from_pretrained = from_pretrained_decorator(
OVModelForFeatureExtraction.from_pretrained
)
SentenceTransformer(model_id, backend="openvino", model_kwargs={"file_name": "openvino/openvino_model.xml"})

0 comments on commit cb81136

Please sign in to comment.