Skip to content

Commit

Permalink
[TextGeneration] Add helper function to parse model path from args (#…
Browse files Browse the repository at this point in the history
…1583)

* add helper function to parse model path from args

* update model path

* revert cli changes

* remove empty args
  • Loading branch information
dsikka authored Feb 7, 2024
1 parent 8a83e24 commit c54461a
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/deepsparse/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,23 +702,26 @@ def text_generation_pipeline(*args, **kwargs) -> "Pipeline":
:return: text generation pipeline with the given args and
kwargs passed to Pipeline.create
"""
return Pipeline.create("text_generation", *args, **kwargs)
kwargs = _check_model_path_arg(*args, **kwargs)
return Pipeline.create("text_generation", **kwargs)


def code_generation_pipeline(*args, **kwargs) -> "Pipeline":
"""
:return: text generation pipeline with the given args and
kwargs passed to Pipeline.create
"""
return Pipeline.create("code_generation", *args, **kwargs)
kwargs = _check_model_path_arg(*args, **kwargs)
return Pipeline.create("code_generation", **kwargs)


def chat_pipeline(*args, **kwargs) -> "Pipeline":
"""
:return: text generation pipeline with the given args and
kwargs passed to Pipeline.create
"""
return Pipeline.create("chat", *args, **kwargs)
kwargs = _check_model_path_arg(*args, **kwargs)
return Pipeline.create("chat", **kwargs)


TextGeneration = text_generation_pipeline
Expand Down Expand Up @@ -802,3 +805,13 @@ def zero_shot_text_classification_pipeline(*args, **kwargs) -> "Pipeline":
is returned depends on the value of the passed model_scheme argument.
"""
return Pipeline.create("zero_shot_text_classification", *args, **kwargs)


def _check_model_path_arg(*args, **kwargs):
if args:
if len(args) > 1 or "model_path" in kwargs or "model" in kwargs:
raise ValueError(
"Only the model path can be provided as a non-kwarg argument"
)
kwargs["model_path"] = args[0]
return kwargs

0 comments on commit c54461a

Please sign in to comment.