Skip to content

Commit

Permalink
Pass arguments for pipeline creation properly to deepsparse.evaluate (#…
Browse files Browse the repository at this point in the history
…1624)

* initial commit

* fix tests

* Update src/deepsparse/evaluation/utils.py

* quality
  • Loading branch information
dbogunowicz authored Mar 7, 2024
1 parent acf190c commit 4e791ee
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
7 changes: 4 additions & 3 deletions src/deepsparse/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def evaluate(

# if target is a string, turn it into an appropriate pipeline
# otherwise assume it is a pipeline
pipeline = (
create_pipeline(model, engine_type) if isinstance(model, (Path, str)) else model
)
if isinstance(model, (Path, str)):
pipeline, kwargs = create_pipeline(model, engine_type, **kwargs)
else:
pipeline = model

eval_integration = EvaluationRegistry.resolve(pipeline, datasets, integration)

Expand Down
19 changes: 11 additions & 8 deletions src/deepsparse/evaluation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,17 @@ def create_pipeline(
text generation model from. This can be a local
or remote path to the model or a sparsezoo stub
:param engine_type: The engine type to initialize the model with.
:return: The initialized pipeline
:return: The initialized pipeline and the mutated
(potentially reduced number of) kwargs
"""
engine_type = engine_type or DEEPSPARSE_ENGINE
return Pipeline.create(
task=kwargs.pop("task", "text-generation"),
model_path=model_path,
sequence_length=kwargs.pop("sequence_length", 2048),
engine_type=engine_type,
batch_size=kwargs.pop("batch_size", 1),
**kwargs,
return (
Pipeline.create(
task=kwargs.pop("task", "text-generation"),
model_path=model_path,
sequence_length=kwargs.pop("sequence_length", 2048),
engine_type=engine_type,
batch_size=kwargs.pop("batch_size", 1),
),
kwargs,
)
4 changes: 2 additions & 2 deletions tests/deepsparse/evaluation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def pipeline_target():


def test_initialize_model_from_target_pipeline_onnx(pipeline_target):
model = create_pipeline(pipeline_target, "onnxruntime")
model, _ = create_pipeline(pipeline_target, "onnxruntime")
assert model.ops.get("single_engine")._engine_type == "onnxruntime"


def test_initialize_model_from_target_pipeline_with_kwargs(pipeline_target):
model = create_pipeline(pipeline_target, "deepsparse", sequence_length=64)
model, _ = create_pipeline(pipeline_target, "deepsparse", sequence_length=64)
assert model.ops.get("process_input").sequence_length == 64

0 comments on commit 4e791ee

Please sign in to comment.