From c54461a9ea0832b8c76366400b21fe73baebd3d4 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 7 Feb 2024 13:08:24 -0500
Subject: [PATCH 1/2] [TextGeneration] Add helper function to parse model path
 from args (#1583)

* add helper function to parse model path from args

* update model path

* revert cli changes

* remove empty args
---
 src/deepsparse/pipeline.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index aaa65409d8..23ff3a2810 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -702,7 +702,8 @@ def text_generation_pipeline(*args, **kwargs) -> "Pipeline":
     :return: text generation pipeline with the given args and kwargs passed
         to Pipeline.create
     """
-    return Pipeline.create("text_generation", *args, **kwargs)
+    kwargs = _check_model_path_arg(*args, **kwargs)
+    return Pipeline.create("text_generation", **kwargs)
 
 
 def code_generation_pipeline(*args, **kwargs) -> "Pipeline":
@@ -710,7 +711,8 @@ def code_generation_pipeline(*args, **kwargs) -> "Pipeline":
     :return: code generation pipeline with the given args and kwargs passed
         to Pipeline.create
     """
-    return Pipeline.create("code_generation", *args, **kwargs)
+    kwargs = _check_model_path_arg(*args, **kwargs)
+    return Pipeline.create("code_generation", **kwargs)
 
 
 def chat_pipeline(*args, **kwargs) -> "Pipeline":
@@ -718,7 +720,8 @@ def chat_pipeline(*args, **kwargs) -> "Pipeline":
     :return: chat pipeline with the given args and kwargs passed
         to Pipeline.create
     """
-    return Pipeline.create("chat", *args, **kwargs)
+    kwargs = _check_model_path_arg(*args, **kwargs)
+    return Pipeline.create("chat", **kwargs)
 
 
 TextGeneration = text_generation_pipeline
@@ -802,3 +805,13 @@ def zero_shot_text_classification_pipeline(*args, **kwargs) -> "Pipeline":
         is returned depends on the value of the passed model_scheme argument.
     """
     return Pipeline.create("zero_shot_text_classification", *args, **kwargs)
+
+
+def _check_model_path_arg(*args, **kwargs):
+    if args:
+        if len(args) > 1 or "model_path" in kwargs or "model" in kwargs:
+            raise ValueError(
+                "Only the model path may be passed as a positional argument"
+            )
+        kwargs["model_path"] = args[0]
+    return kwargs


From 2e33e673cf897ee93689a6ade070b40503ee8df3 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Wed, 7 Feb 2024 14:24:22 -0500
Subject: [PATCH 2/2] [server] Add `model` argument to server cli (#1584)

* update model path to be an argument; remove unused openai command pathway

* add model path arg and option
---
 src/deepsparse/server/cli.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)

diff --git a/src/deepsparse/server/cli.py b/src/deepsparse/server/cli.py
index 5eacc748a0..acd7b6897c 100644
--- a/src/deepsparse/server/cli.py
+++ b/src/deepsparse/server/cli.py
@@ -79,6 +79,7 @@
     ),
 )
 
+MODEL_ARG = click.argument("model", type=str, default=None, required=False)
 MODEL_OPTION = click.option(
     "--model_path",
     type=str,
@@ -152,6 +153,7 @@
 @PORT_OPTION
 @LOG_LEVEL_OPTION
 @HOT_RELOAD_OPTION
+@MODEL_ARG
 @MODEL_OPTION
 @BATCH_OPTION
 @CORES_OPTION
@@ -167,6 +169,7 @@ def main(
     log_level: str,
     hot_reload_config: bool,
     model_path: str,
+    model: str,
     batch_size: int,
     num_cores: int,
     num_workers: int,
@@ -216,6 +219,17 @@ def main(
         ...
     ```
     """
+    # The server CLI accepts the model either as the positional MODEL
+    # argument or through the --model_path option. The option takes
+    # precedence: the positional argument is used only when --model_path
+    # is left at its "default" value.
+    if model and model_path == "default":
+        model_path = model
+
+    # the OpenAI integration serves text generation models only
+    if integration == INTEGRATION_OPENAI:
+        task = "text_generation"
+
     if ctx.invoked_subcommand is not None:
         return
 
@@ -254,24 +268,6 @@ def main(
     server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)
 
 
-@main.command(
-    context_settings=dict(
-        token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
-    ),
-)
-@click.argument("config-file", type=str)
-@HOST_OPTION
-@PORT_OPTION
-@LOG_LEVEL_OPTION
-@HOT_RELOAD_OPTION
-def openai(
-    config_file: str, host: str, port: int, log_level: str, hot_reload_config: bool
-):
-
-    server = OpenAIServer(server_config=config_file)
-    server.start_server(host, port, log_level, hot_reload_config=hot_reload_config)
-
-
 @main.command(
     context_settings=dict(
         token_normalize_func=lambda x: x.replace("-", "_"), show_default=True
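
For reference, the positional-argument handling added in PATCH 1/2 can be exercised directly. This is a minimal sketch against the patched `src/deepsparse/pipeline.py`; the SparseZoo stubs are illustrative placeholders, not real model identifiers:

```python
from deepsparse.pipeline import _check_model_path_arg

# A lone positional argument is promoted to the model_path kwarg.
kwargs = _check_model_path_arg("zoo:illustrative/stub", sequence_length=256)
assert kwargs == {"model_path": "zoo:illustrative/stub", "sequence_length": 256}

# Keyword-only calls pass through untouched.
assert _check_model_path_arg(model_path="zoo:illustrative/stub") == {
    "model_path": "zoo:illustrative/stub"
}

# Supplying the model both positionally and as a kwarg is rejected.
try:
    _check_model_path_arg("zoo:illustrative/stub", model_path="zoo:other/stub")
except ValueError as err:
    print(err)  # Only the model path may be passed as a positional argument
```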
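
The precedence rule from PATCH 2/2 can likewise be read in isolation. The sketch below mirrors the lines added to `main` as a standalone function; the name `resolve_model_path` is hypothetical, while the "default" sentinel comes straight from the diff:

```python
from typing import Optional


def resolve_model_path(model: Optional[str], model_path: str = "default") -> str:
    """Hypothetical mirror of the CLI logic: --model_path wins, and the
    positional MODEL argument is used only while --model_path still holds
    its "default" sentinel value."""
    if model and model_path == "default":
        return model
    return model_path


# --model_path explicitly set: the positional argument is ignored.
assert resolve_model_path("zoo:stub-a", "zoo:stub-b") == "zoo:stub-b"
# Only the positional argument given: it is promoted to model_path.
assert resolve_model_path("zoo:stub-a") == "zoo:stub-a"
# Neither given: the "default" sentinel flows through unchanged.
assert resolve_model_path(None) == "default"
```

Under these assumptions, passing the model positionally or via `--model_path` becomes interchangeable at the command line, with the option winning whenever both are supplied.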