From 9ccf39c66f670d14c2a91861372e3c3d0ffe88ae Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 12 Feb 2024 21:29:41 +0000
Subject: [PATCH 1/4] benchmark server pipeline

---
 .../benchmark/benchmark_pipeline.py        | 61 +++++++++++++++++
 src/deepsparse/benchmark/config.py         | 34 ++++++++++
 src/deepsparse/middlewares/middleware.py   |  8 ++-
 src/deepsparse/pipeline.py                 |  2 +-
 src/deepsparse/pipeline_config.py          |  4 ++
 src/deepsparse/server/config.py            |  4 ++
 src/deepsparse/server/deepsparse_server.py | 67 +++++++++++++++++++
 src/deepsparse/server/server.py            | 16 +++++
 8 files changed, 194 insertions(+), 2 deletions(-)

diff --git a/src/deepsparse/benchmark/benchmark_pipeline.py b/src/deepsparse/benchmark/benchmark_pipeline.py
index b0453060ad..be9e209d79 100644
--- a/src/deepsparse/benchmark/benchmark_pipeline.py
+++ b/src/deepsparse/benchmark/benchmark_pipeline.py
@@ -401,6 +401,67 @@ def _get_statistics(batch_times):
     return sections, all_sections
 
 
+def benchmark_from_pipeline(
+    pipeline: Pipeline,
+    batch_size: int = 1,
+    seconds_to_run: int = 10,
+    warmup_time: int = 2,
+    thread_pinning: str = "core",
+    scenario: str = "sync",
+    num_streams: int = 1,
+    data_type: str = "dummy",
+    **kwargs,
+):
+    decide_thread_pinning(thread_pinning)
+    scenario = parse_scenario(scenario.lower())
+
+    input_type = data_type
+
+    config = PipelineBenchmarkConfig(
+        data_type=data_type,
+        **kwargs,
+    )
+    inputs = create_input_schema(pipeline, input_type, batch_size, config)
+
+    def _clear_measurements():
+        # Helper method to handle variations between v1 and v2 timers
+        if hasattr(pipeline.timer_manager, "clear"):
+            pipeline.timer_manager.clear()
+        else:
+            pipeline.timer_manager.measurements.clear()
+
+    if scenario == "singlestream":
+        singlestream_benchmark(pipeline, inputs, warmup_time)
+        _clear_measurements()
+        start_time = time.perf_counter()
+        singlestream_benchmark(pipeline, inputs, seconds_to_run)
+    elif scenario == "multistream":
+        multistream_benchmark(pipeline, inputs, warmup_time, num_streams)
+        _clear_measurements()
+        start_time = time.perf_counter()
+        multistream_benchmark(pipeline, inputs, seconds_to_run, num_streams)
+    elif scenario == "elastic":
+        multistream_benchmark(pipeline, inputs, warmup_time, num_streams)
+        _clear_measurements()
+        start_time = time.perf_counter()
+        multistream_benchmark(pipeline, inputs, seconds_to_run, num_streams)
+    else:
+        raise Exception(f"Unknown scenario '{scenario}'")
+
+    end_time = time.perf_counter()
+    total_run_time = end_time - start_time
+    if hasattr(pipeline.timer_manager, "all_times"):
+        batch_times = pipeline.timer_manager.all_times
+    else:
+        batch_times = pipeline.timer_manager.measurements
+    if len(batch_times) == 0:
+        raise Exception(
+            "Generated no batch timings, try extending benchmark time with '--time'"
+        )
+
+    return batch_times, total_run_time, num_streams
+
+
 @click.command()
 @click.argument("task_name", type=str)
 @click.argument("model_path", type=str)
diff --git a/src/deepsparse/benchmark/config.py b/src/deepsparse/benchmark/config.py
index e1cf9c1972..73173a02c0 100644
--- a/src/deepsparse/benchmark/config.py
+++ b/src/deepsparse/benchmark/config.py
@@ -83,3 +83,37 @@ class PipelineBenchmarkConfig(BaseModel):
         default={},
         description=("Additional arguments passed to input schema creations "),
     )
+
+
+class PipelineBenchmarkServerConfig(PipelineBenchmarkConfig):
+    batch_size: int = Field(
+        default=1,
+        description="The batch size of the inputs to be used with the engine",
+    )
+    seconds_to_run: int = Field(
+        default=10,
+        description="The number of seconds to run the benchmark for",
+    )
+    warmup_time: int = Field(
+        default=2,
+        description="The number of seconds to warm up the pipeline",
+    )
+    thread_pinning: str = Field(
+        default="core",
+        description="Enable binding threads to cores",
+    )
+    scenario: str = Field(
+        default="sync",
+        description=(
+            "`BenchmarkScenario` object with specification for running "
+            "benchmark on an onnx model"
+        ),
+    )
+    num_streams: int = Field(
+        default=1,
+        description=(
+            "The max number of requests the model can handle "
+            "concurrently. None or 0 implies a scheduler-defined default value; "
+            "default 1"
+        ),
+    )
diff --git a/src/deepsparse/middlewares/middleware.py b/src/deepsparse/middlewares/middleware.py
index ee1dc57eaa..0adf2cdd37 100644
--- a/src/deepsparse/middlewares/middleware.py
+++ b/src/deepsparse/middlewares/middleware.py
@@ -112,7 +112,9 @@ class MiddlewareManager:
     :param _lock: lock for the state
     """
 
-    def __init__(self, middleware: Optional[Sequence[MiddlewareSpec]], *args, **kwargs):
+    def __init__(
+        self, middleware: Optional[Sequence[MiddlewareSpec]] = None, *args, **kwargs
+    ):
 
         self.middleware: Optional[
             Sequence[MiddlewareSpec]
@@ -172,3 +174,7 @@ def _update_middleware_spec_send(
             next_middleware.send = self.recieve
 
         self.middleware.append(MiddlewareSpec(next_middleware, **init_args))
+
+    @property
+    def middlewares(self):
+        return [middleware.cls for middleware in self.middleware]
diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index e2a1beeab1..f2992beef2 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -91,7 +91,7 @@ def __init__(
         self.schedulers = schedulers
         self.pipeline_state = pipeline_state
         self._continuous_batching_scheduler = continuous_batching_scheduler
-        self.middleware_manager = middleware_manager
+        self.middleware_manager = middleware_manager or MiddlewareManager()
         self.timer_manager = timer_manager or TimerManager()
 
         self.validate()
diff --git a/src/deepsparse/pipeline_config.py b/src/deepsparse/pipeline_config.py
index 2966f99d7d..3af9a82dd8 100644
--- a/src/deepsparse/pipeline_config.py
+++ b/src/deepsparse/pipeline_config.py
@@ -73,6 +73,10 @@ class PipelineConfig(BaseModel):
Default is None" ), ) + middlewares: Optional[List[str]] = Field( + default=None, + description="Middlewares to use", + ) kwargs: Optional[Dict[str, Any]] = Field( default={}, description=( diff --git a/src/deepsparse/server/config.py b/src/deepsparse/server/config.py index a42eb00059..d628120faf 100644 --- a/src/deepsparse/server/config.py +++ b/src/deepsparse/server/config.py @@ -127,6 +127,9 @@ class EndpointConfig(BaseModel): "```\n" ), ) + middlewares: Optional[List[str]] = Field( + default=None, description=("Middleware to use") + ) kwargs: Dict[str, Any] = Field( default={}, description="Additional arguments to pass to the Pipeline" @@ -147,6 +150,7 @@ def to_pipeline_config(self) -> PipelineConfig: num_cores=None, # this will be set from Context alias=self.name, input_shapes=input_shapes, + middlewares=self.middlewares, kwargs=kwargs, ) diff --git a/src/deepsparse/server/deepsparse_server.py b/src/deepsparse/server/deepsparse_server.py index 8ffc7508cb..7e3ca2aec1 100644 --- a/src/deepsparse/server/deepsparse_server.py +++ b/src/deepsparse/server/deepsparse_server.py @@ -16,6 +16,7 @@ from functools import partial from deepsparse import Pipeline +from deepsparse.middlewares import MiddlewareSpec, TimerMiddleware from deepsparse.server.config import EndpointConfig from deepsparse.server.server import CheckReady, ModelMetaData, ProxyPipeline, Server from deepsparse.tasks import SupportedTasks @@ -86,6 +87,11 @@ def _add_endpoint( endpoint_config, pipeline, ) + self._add_benchmark_endpoints( + app, + endpoint_config, + pipeline, + ) self._add_status_and_metadata_endpoints(app, endpoint_config, pipeline) def _add_status_and_metadata_endpoints( @@ -180,3 +186,64 @@ def _add_inference_endpoints( methods=["POST"], tags=["model", "inference"], ) + + def _add_benchmark_endpoints( + self, + app: FastAPI, + endpoint_config: EndpointConfig, + pipeline: Pipeline, + ): + if TimerMiddleware not in pipeline.middleware_manager.middlewares: + pipeline.middleware_manager.add_middleware( + [MiddlewareSpec(TimerMiddleware)] + ) + + routes_and_fns = [] + if endpoint_config.route: + endpoint_config.route = self.clean_up_route(endpoint_config.route) + route = f"/v2/models{endpoint_config.route}/benchmark" + else: + route = f"/v2/models/{endpoint_config.name}/benchmark" + + routes_and_fns.append( + ( + route, + partial( + Server.benchmark, + ProxyPipeline(pipeline), + self.server_config.system_logging, + ), + ) + ) + + legacy_pipeline = not isinstance(pipeline, Pipeline) and hasattr( + pipeline.input_schema, "from_files" + ) + # New pipelines do not have to have an input_schema. Just checking task + # names for now but can keep a list of supported from_files tasks in + # SupportedTasks as more pipelines are migrated as well as output schemas. 
+        new_pipeline = SupportedTasks.is_image_classification(endpoint_config.task)
+
+        if legacy_pipeline or new_pipeline:
+            routes_and_fns.append(
+                (
+                    route + "/from_files",
+                    partial(
+                        Server.predict_from_files,
+                        ProxyPipeline(pipeline),
+                        self.server_config.system_logging,
+                    ),
+                )
+            )
+        if isinstance(pipeline, Pipeline):
+            response_model = None
+        else:
+            response_model = pipeline.output_schema
+
+        self._update_routes(
+            app=app,
+            routes_and_fns=routes_and_fns,
+            response_model=response_model,
+            methods=["POST"],
+            tags=["model", "pipeline"],
+        )
diff --git a/src/deepsparse/server/server.py b/src/deepsparse/server/server.py
index 3c1cb053f7..175d643190 100644
--- a/src/deepsparse/server/server.py
+++ b/src/deepsparse/server/server.py
@@ -24,6 +24,8 @@
 from pydantic import BaseModel
 
 import uvicorn
+from deepsparse.benchmark.benchmark_pipeline import benchmark_from_pipeline
+from deepsparse.benchmark.config import PipelineBenchmarkConfig
 from deepsparse.engine import Context
 from deepsparse.pipeline import Pipeline
 from deepsparse.server.config import ServerConfig, SystemLoggingConfig
@@ -274,6 +276,20 @@ async def format_response():
 
         return prep_for_serialization(pipeline_outputs)
 
+    @staticmethod
+    async def benchmark(
+        proxy_pipeline: ProxyPipeline,
+        system_logging_config: SystemLoggingConfig,
+        raw_request: Request,
+    ):
+        json_params = await raw_request.json()
+        benchmark_config = PipelineBenchmarkConfig(**json_params)
+        results = benchmark_from_pipeline(
+            pipeline=proxy_pipeline.pipeline, **benchmark_config.dict()
+        )
+
+        return results
+
     @staticmethod
     async def predict_from_files(
         proxy_pipeline: ProxyPipeline,

From 0a6e5d42ae906d2f70e0efdffa9276d157b6c95d Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Tue, 13 Feb 2024 18:05:41 +0000
Subject: [PATCH 2/4] pass tests

---
 src/deepsparse/pipeline.py                 | 2 +-
 src/deepsparse/server/deepsparse_server.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index f2992beef2..e2a1beeab1 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -91,7 +91,7 @@ def __init__(
         self.schedulers = schedulers
         self.pipeline_state = pipeline_state
         self._continuous_batching_scheduler = continuous_batching_scheduler
-        self.middleware_manager = middleware_manager or MiddlewareManager()
+        self.middleware_manager = middleware_manager
         self.timer_manager = timer_manager or TimerManager()
 
         self.validate()
diff --git a/src/deepsparse/server/deepsparse_server.py b/src/deepsparse/server/deepsparse_server.py
index 7e3ca2aec1..0fadd96630 100644
--- a/src/deepsparse/server/deepsparse_server.py
+++ b/src/deepsparse/server/deepsparse_server.py
@@ -16,7 +16,7 @@
 from functools import partial
 
 from deepsparse import Pipeline
-from deepsparse.middlewares import MiddlewareSpec, TimerMiddleware
+from deepsparse.middlewares import MiddlewareManager, MiddlewareSpec, TimerMiddleware
 from deepsparse.server.config import EndpointConfig
 from deepsparse.server.server import CheckReady, ModelMetaData, ProxyPipeline, Server
 from deepsparse.tasks import SupportedTasks
@@ -193,6 +193,8 @@ def _add_benchmark_endpoints(
         endpoint_config: EndpointConfig,
         pipeline: Pipeline,
     ):
+        if not hasattr(pipeline, "middleware_manager"):
+            pipeline.middleware_manager = MiddlewareManager()
         if TimerMiddleware not in pipeline.middleware_manager.middlewares:
             pipeline.middleware_manager.add_middleware(
                 [MiddlewareSpec(TimerMiddleware)]

From 628d4f18cbbbb4d64b09f544102c38d6ad5e02f5 Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 4 Mar 2024 15:28:25 +0000
Subject: [PATCH 3/4] comments

---
 .../benchmark/benchmark_pipeline.py | 71 +++++++---------
 src/deepsparse/utils/imports.py     | 42 ++++++++++
 tests/server/test_benchmark.py      | 83 +++++++++++++++++++
 3 files changed, 157 insertions(+), 39 deletions(-)
 create mode 100644 src/deepsparse/utils/imports.py
 create mode 100644 tests/server/test_benchmark.py

diff --git a/src/deepsparse/benchmark/benchmark_pipeline.py b/src/deepsparse/benchmark/benchmark_pipeline.py
index be9e209d79..4dcc623292 100644
--- a/src/deepsparse/benchmark/benchmark_pipeline.py
+++ b/src/deepsparse/benchmark/benchmark_pipeline.py
@@ -308,45 +308,16 @@ def benchmark_pipeline(
         num_streams=num_streams,
         **kwargs,
     )
-    inputs = create_input_schema(pipeline, input_type, batch_size, config)
-
-    def _clear_measurements():
-        # Helper method to handle variations between v1 and v2 timers
-        if hasattr(pipeline.timer_manager, "clear"):
-            pipeline.timer_manager.clear()
-        else:
-            pipeline.timer_manager.measurements.clear()
-
-    if scenario == "singlestream":
-        singlestream_benchmark(pipeline, inputs, warmup_time)
-        _clear_measurements()
-        start_time = time.perf_counter()
-        singlestream_benchmark(pipeline, inputs, seconds_to_run)
-    elif scenario == "multistream":
-        multistream_benchmark(pipeline, inputs, warmup_time, num_streams)
-        _clear_measurements()
-        start_time = time.perf_counter()
-        multistream_benchmark(pipeline, inputs, seconds_to_run, num_streams)
-    elif scenario == "elastic":
-        multistream_benchmark(pipeline, inputs, warmup_time, num_streams)
-        _clear_measurements()
-        start_time = time.perf_counter()
-        multistream_benchmark(pipeline, inputs, seconds_to_run, num_streams)
-    else:
-        raise Exception(f"Unknown scenario '{scenario}'")
-
-    end_time = time.perf_counter()
-    total_run_time = end_time - start_time
-    if hasattr(pipeline.timer_manager, "all_times"):
-        batch_times = pipeline.timer_manager.all_times
-    else:
-        batch_times = pipeline.timer_manager.measurements
-    if len(batch_times) == 0:
-        raise Exception(
-            "Generated no batch timings, try extending benchmark time with '--time'"
-        )
-
-    return batch_times, total_run_time, num_streams
+    return run(
+        pipeline,
+        input_type,
+        batch_size,
+        config,
+        warmup_time,
+        seconds_to_run,
+        num_streams,
+        scenario,
+    )
 
 
 def calculate_statistics(
@@ -421,6 +392,28 @@ def benchmark_from_pipeline(
         data_type=data_type,
         **kwargs,
     )
+    return run(
+        pipeline,
+        input_type,
+        batch_size,
+        config,
+        warmup_time,
+        seconds_to_run,
+        num_streams,
+        scenario,
+    )
+
+
+def run(
+    pipeline: Pipeline,
+    input_type: str,
+    batch_size: int,
+    config: PipelineBenchmarkConfig,
+    warmup_time: int,
+    seconds_to_run: int,
+    num_streams: int,
+    scenario: str,
+):
     inputs = create_input_schema(pipeline, input_type, batch_size, config)
 
     def _clear_measurements():
diff --git a/src/deepsparse/utils/imports.py b/src/deepsparse/utils/imports.py
new file mode 100644
index 0000000000..70400dde86
--- /dev/null
+++ b/src/deepsparse/utils/imports.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import re
+from typing import Any, Type
+
+
+def import_from_path(path: str) -> Type[Any]:
+    """
+    Import the module and the name of the function/class separated by :
+    Examples:
+        path = "/path/to/file.py:func_name"
+        path = "/path/to/file:class_name"
+    :param path: path including the file path and object name
+    :return Function or class object
+    """
+    path, class_name = path.split(":")
+    _path = path
+
+    path = path.split(".py")[0]
+    path = re.sub(r"/+", ".", path)
+    try:
+        module = importlib.import_module(path)
+    except ImportError:
+        raise ImportError(f"Cannot find module with path {_path}")
+
+    try:
+        return getattr(module, class_name)
+    except AttributeError:
+        raise AttributeError(f"Cannot find {class_name} in {_path}")
diff --git a/tests/server/test_benchmark.py b/tests/server/test_benchmark.py
new file mode 100644
index 0000000000..8d10399566
--- /dev/null
+++ b/tests/server/test_benchmark.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List
+
+import pytest
+from deepsparse.server.config import EndpointConfig, ServerConfig
+from deepsparse.server.deepsparse_server import DeepsparseServer
+from fastapi.testclient import TestClient
+
+
+TEST_MODEL_ID = "hf:mgoin/TinyStories-1M-ds"
+
+
+@pytest.fixture(scope="module")
+def endpoint_config():
+    endpoint = EndpointConfig(
+        task="text_generation", model=TEST_MODEL_ID, middlewares=["TimerMiddleware"]
+    )
+    return endpoint
+
+
+@pytest.fixture(scope="module")
+def server_config(endpoint_config):
+    server_config = ServerConfig(
+        num_cores=1, num_workers=1, endpoints=[endpoint_config], loggers={}
+    )
+
+    return server_config
+
+
+@pytest.fixture(scope="module")
+def server(server_config):
+    server = DeepsparseServer(server_config=server_config)
+    return server
+
+
+@pytest.fixture(scope="module")
+def app(server):
+    app = server._build_app()
+    return app
+
+
+@pytest.fixture(scope="module")
+def client(app):
+    return TestClient(app)
+
+
+def test_benchmark_pipeline(client):
+    url = "v2/models/text_generation-0/benchmark"
+    obj = {
+        "data_type": "dummy",
+        "gen_sequence_length": 100,
+        "pipeline_kwargs": {},
+        "input_schema_kwargs": {},
+    }
+    response = client.post(url, json=obj)
+    response.raise_for_status()
+
+    response_json: List[Dict] = response.json()
+
+    # iterate over all benchmarks that are Lists
+    # The final layer is a dict, where
+    # key is the name of the Operator, value is the List of timings
+    # Ex. {'CompileGeneratedTokens': [1.7404556274414062e-05, ... }
+    timings = response_json[0]
+    for timing in timings:
+        for key, values in timing.items():
+            assert isinstance(key, str)
+            assert isinstance(values, List)
+            for run_time in values:
+                assert isinstance(run_time, float)

From a425828544aea7a9a9273e9668f14a69b091226f Mon Sep 17 00:00:00 2001
From: George Ohashi
Date: Mon, 4 Mar 2024 15:56:36 +0000
Subject: [PATCH 4/4] raise warning using middleware + cont.batching.sched

---
 src/deepsparse/pipeline.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index 1c0c324f24..fe74798d0a 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -499,6 +499,16 @@ def validate(self):
         elif isinstance(router_validation, str):
             raise ValueError(f"Invalid Router for operators: {router_validation}")
 
+        if (
+            self.middleware_manager is not None
+            and self._continuous_batching_scheduler is not None
+        ):
+            _LOGGER.warning(
+                "Middleware is not yet supported with the continuous batching "
+                "scheduler. Either remove the middleware or the continuous "
+                "batching scheduler when instantiating the Pipeline class"
+            )
+
     def run_func(
         self,
         *args,
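
For reference, a minimal sketch of exercising the benchmark route added by this series once a server is running. It is illustrative only and not part of the patch: the host and port (localhost:5543) and the endpoint name text_generation-0 are assumptions carried over from the server defaults and the test configuration above, and the payload fields mirror PipelineBenchmarkConfig, which Server.benchmark uses to validate the request body.

    # Hypothetical usage sketch: assumes a deepsparse server is already running
    # on localhost:5543 with the text_generation endpoint configured as above.
    import requests

    url = "http://localhost:5543/v2/models/text_generation-0/benchmark"
    payload = {
        "data_type": "dummy",        # generate dummy inputs for the pipeline
        "gen_sequence_length": 100,  # forwarded to input schema creation
        "pipeline_kwargs": {},
        "input_schema_kwargs": {},
    }

    response = requests.post(url, json=payload)
    response.raise_for_status()

    # benchmark_from_pipeline returns (batch_times, total_run_time, num_streams),
    # serialized as a three-element JSON array; batch_times is a list of dicts
    # mapping operator names to lists of timings, as shown in the test above.
    batch_times, total_run_time, num_streams = response.json()
    for timing in batch_times:
        for operator, values in timing.items():
            print(operator, len(values), sum(values) / len(values))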