From cb46dcf8368803ac57a0edd2b1f05ec353509372 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 19 Jun 2023 20:05:18 +0000 Subject: [PATCH 01/14] initial refactor --- src/deepsparse/pipeline.py | 595 ++++++++++++++++++++----------------- 1 file changed, 329 insertions(+), 266 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 4f21ab54aa..9fb756d1dd 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -49,6 +49,7 @@ "DEEPSPARSE_ENGINE", "ORT_ENGINE", "SUPPORTED_PIPELINE_ENGINES", + "BasePipeline", "Pipeline", "PipelineConfig", "question_answering_pipeline", @@ -70,7 +71,284 @@ _REGISTERED_PIPELINES = {} -class Pipeline(ABC): +class BasePipeline(ABC): + """ + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param logger: An optional item that can be either a DeepSparse Logger object, + or an object that can be transformed into one. Those object can be either + a path to the logging config, or yaml string representation the logging + config. If logger provided (in any form), the pipeline will log inference + metrics to the logger. Default is None + + """ + + def __init__( + self, + alias: Optional[str] = None, + logger: Optional[Union[BaseLogger, str]] = None, + ): + + self._alias = alias + self.logger = ( + logger + if isinstance(logger, BaseLogger) + else ( + logger_from_config( + config=logger, pipeline_identifier=self._identifier() + ) + if isinstance(logger, str) + else None + ) + ) + + @staticmethod + def _get_task_constructor(task: str) -> Type["BasePipeline"]: + """ + This function retrieves the class previously registered via `Pipeline.register` + for `task`. + + If `task` starts with "import:", it is treated as a module to be imported, + and retrieves the task via the `TASK` attribute of the imported module. + + If `task` starts with "custom", then it is mapped to the "custom" task. + + :param task: The task name to get the constructor for + :return: The class registered to `task` + :raises ValueError: if `task` was not registered via `Pipeline.register`. + """ + if task.startswith("import:"): + # dynamically import the task from a file + task = dynamic_import_task(module_or_path=task.replace("import:", "")) + elif task.startswith("custom"): + # support any task that has "custom" at the beginning via the "custom" task + task = "custom" + else: + task = task.lower().replace("-", "_") + + # extra step to register pipelines for a given task domain + # for cases where imports should only happen once a user specifies + # that domain is to be used. (ie deepsparse.transformers will auto + # install extra packages so should only import and register once a + # transformers task is specified) + SupportedTasks.check_register_task(task, _REGISTERED_PIPELINES.keys()) + + if task not in _REGISTERED_PIPELINES: + raise ValueError( + f"Unknown Pipeline task {task}. Pipeline tasks should be " + "must be declared with the Pipeline.register decorator. Currently " + f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}" + ) + + return _REGISTERED_PIPELINES[task] + + @staticmethod + def create( + task: str, + **kwargs, + ) -> "BasePipeline": + """ + :param task: name of task to create a pipeline for. Use "custom" for + custom tasks (see `CustomTaskPipeline`). + :param model_path: path on local system or SparseZoo stub to load the model + from. Some tasks may have a default model path + :param engine_type: inference engine to use. 
Currently supported values + include 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' + :param batch_size: static batch size to use for inference. Default is 1 + :param num_cores: number of CPU cores to allocate for inference engine. None + specifies all available cores. Default is None + :param scheduler: (deepsparse only) kind of scheduler to execute with. + Pass None for the default + :param input_shapes: list of shapes to set ONNX the inputs to. Pass None + to use model as-is. Default is None + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param context: Optional Context object to use for creating instances of + MultiModelEngine. The Context contains a shared scheduler along with + other runtime information that will be used across instances of the + MultiModelEngine to provide optimal performance when running + multiple models concurrently + :param kwargs: extra task specific kwargs to be passed to task Pipeline + implementation + :return: pipeline object initialized for the given task + """ + pipeline_constructor = BasePipeline._get_task_constructor(task) + model_path = kwargs.get("model_path", None) + + if issubclass(pipeline_constructor, Pipeline): + if ( + (model_path is None or model_path == "default") + and hasattr(pipeline_constructor, "default_model_path") + and pipeline_constructor.default_model_path + ): + model_path = pipeline_constructor.default_model_path + + if model_path is None: + raise ValueError( + f"No model_path provided for pipeline {pipeline_constructor}. Must " + "provide a model path for pipelines that do not have a default " + "defined" + ) + + kwargs["model_path"] = model_path + + if issubclass( + pipeline_constructor, Bucketable + ) and pipeline_constructor.should_bucket(**kwargs): + if kwargs.get("input_shape", None): + raise ValueError( + "Overriding input shapes not supported with Bucketing enabled" + ) + if not kwargs.get("context", None): + context = Context(num_cores=kwargs["num_cores"]) + kwargs["context"] = context + buckets = pipeline_constructor.create_pipeline_buckets( + task=task, + **kwargs, + ) + return BucketingPipeline(pipelines=buckets) + + return pipeline_constructor(**kwargs) + + @classmethod + def register( + cls, + task: str, + task_aliases: Optional[List[str]] = None, + default_model_path: Optional[str] = None, + ): + """ + Pipeline implementer class decorator that registers the pipeline + task name and its aliases as valid tasks that can be used to load + the pipeline through `BasePipeline.create()`. + + Multiple pipelines may not have the same task name. An error will + be raised if two different pipelines attempt to register the same task name + + :param task: main task name of this pipeline + :param task_aliases: list of extra task names that may be used to reference + this pipeline. Default is None + :param default_model_path: path (ie zoo stub) to use as default for this + task if None is provided + """ + task_names = [task] + if task_aliases: + task_names.extend(task_aliases) + + task_names = [task_name.lower().replace("-", "_") for task_name in task_names] + + def _register_task(task_name, pipeline_class): + if task_name in _REGISTERED_PIPELINES and ( + pipeline_class is not _REGISTERED_PIPELINES[task_name] + ): + raise RuntimeError( + f"task {task_name} already registered by BasePipeline.register. 
" + f"attempting to register pipeline: {pipeline_class}, but" + f"pipeline: {_REGISTERED_PIPELINES[task_name]}, already registered" + ) + _REGISTERED_PIPELINES[task_name] = pipeline_class + + def _register_pipeline_tasks_decorator( + pipeline_class: Union[BasePipeline, Pipeline] + ): + if not issubclass(pipeline_class, cls): + raise RuntimeError( + f"Attempting to register pipeline {pipeline_class}. " + f"Registered pipelines must inherit from {cls}" + ) + for task_name in task_names: + _register_task(task_name, pipeline_class) + + # set task and task_aliases as class level property + # leave default_model_path for now as is optional? + pipeline_class.task = task + pipeline_class.task_aliases = task_aliases + pipeline_class.default_model_path = default_model_path + + return pipeline_class + + return _register_pipeline_tasks_decorator + + @property + @abstractmethod + def input_schema(self) -> Type[BaseModel]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + raise NotImplementedError() + + @property + @abstractmethod + def output_schema(self) -> Type[BaseModel]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + raise NotImplementedError() + + @property + def alias(self) -> str: + """ + :return: optional name to give this pipeline instance, useful when + inferencing with multiple models + """ + return self._alias + + def log( + self, + identifier: str, + value: Any, + category: str, + ): + """ + Pass the logged data to the DeepSparse logger object (if present). + + :param identifier: The string name assigned to the logged value + :param value: The logged data structure + :param category: The metric category that the log belongs to + """ + if not self.logger: + return + + identifier = f"{self._identifier()}/{identifier}" + validate_identifier(identifier) + self.logger.log( + identifier=identifier, + value=value, + category=category, + pipeline_name=self._identifier(), + ) + return + + def parse_inputs(self, *args, **kwargs) -> BaseModel: + """ + :param args: ordered arguments to pipeline, only an input_schema object + is supported as an arg for this function + :param kwargs: keyword arguments to pipeline + :return: pipeline arguments parsed into the given `input_schema` + schema if necessary. If an instance of the `input_schema` is provided + it will be returned + """ + # passed input_schema schema directly + if len(args) == 1 and isinstance(args[0], self.input_schema) and not kwargs: + return args[0] + + if args: + raise ValueError( + f"pipeline {self.__class__} only supports either only a " + f"{self.input_schema} object. or keyword arguments to be construct " + f"one. Found {len(args)} args and {len(kwargs)} kwargs" + ) + + return self.input_schema(**kwargs) + + def _identifier(self): + # get pipeline identifier; used in the context of logging + if not hasattr(self, "task"): + self.task = None + return f"{self.alias or self.task or 'unknown_pipeline'}" + + +class Pipeline(BasePipeline): """ Generic Pipeline abstract class meant to wrap inference engine objects to include data pre/post-processing. Inputs and outputs of pipelines should be serialized @@ -128,8 +406,6 @@ class PipelineImplementation(Pipeline): Pass None for the default :param input_shapes: list of shapes to set ONNX the inputs to. Pass None to use model as-is. Default is None - :param alias: optional name to give this pipeline instance, useful when - inferencing with multiple models. 
Default is None :param context: Optional Context object to use for creating instances of MultiModelEngine. The Context contains a shared scheduler along with other runtime information that will be used across instances of the @@ -143,11 +419,6 @@ class PipelineImplementation(Pipeline): synchronous execution - if running in dynamic batch mode a default ThreadPoolExecutor with default workers equal to the number of available cores / 2 - :param logger: An optional item that can be either a DeepSparse Logger object, - or an object that can be transformed into one. Those object can be either - a path to the logging config, or yaml string representation the logging - config. If logger provided (in any form), the pipeline will log inference - metrics to the logger. Default is None """ def __init__( @@ -158,32 +429,18 @@ def __init__( num_cores: int = None, scheduler: Scheduler = None, input_shapes: List[List[int]] = None, - alias: Optional[str] = None, context: Optional[Context] = None, executor: Optional[Union[ThreadPoolExecutor, int]] = None, - logger: Optional[Union[BaseLogger, str]] = None, benchmark: bool = False, - _delay_engine_initialize: bool = False, # internal use only + _delay_engine_initialize: bool = False, ): self._benchmark = benchmark self._model_path_orig = model_path self._model_path = model_path self._engine_type = engine_type self._batch_size = batch_size - self._alias = alias self._timer_manager = TimerManager(enabled=True, multi=benchmark) self.context = context - self.logger = ( - logger - if isinstance(logger, BaseLogger) - else ( - logger_from_config( - config=logger, pipeline_identifier=self._identifier() - ) - if isinstance(logger, str) - else None - ) - ) self.executor, self._num_async_workers = _initialize_executor_and_workers( batch_size=batch_size, @@ -312,185 +569,69 @@ def __call__(self, *args, **kwargs) -> BaseModel: return pipeline_outputs @staticmethod - def _get_task_constructor(task: str) -> Type["Pipeline"]: - """ - This function retrieves the class previously registered via `Pipeline.register` - for `task`. - - If `task` starts with "import:", it is treated as a module to be imported, - and retrieves the task via the `TASK` attribute of the imported module. - - If `task` starts with "custom", then it is mapped to the "custom" task. - - :param task: The task name to get the constructor for - :return: The class registered to `task` - :raises ValueError: if `task` was not registered via `Pipeline.register`. + def split_engine_inputs( + items: List[numpy.ndarray], batch_size: int + ) -> List[List[numpy.ndarray]]: """ - if task.startswith("import:"): - # dynamically import the task from a file - task = dynamic_import_task(module_or_path=task.replace("import:", "")) - elif task.startswith("custom"): - # support any task that has "custom" at the beginning via the "custom" task - task = "custom" - else: - task = task.lower().replace("-", "_") + Splits each item into numpy arrays with the first dimension == `batch_size`. - # extra step to register pipelines for a given task domain - # for cases where imports should only happen once a user specifies - # that domain is to be used. 
(ie deepsparse.transformers will auto - # install extra packages so should only import and register once a - # transformers task is specified) - SupportedTasks.check_register_task(task, _REGISTERED_PIPELINES.keys()) + For example, if `items` has three numpy arrays with the following + shapes: `[(4, 32, 32), (4, 64, 64), (4, 128, 128)]` - if task not in _REGISTERED_PIPELINES: - raise ValueError( - f"Unknown Pipeline task {task}. Pipeline tasks should be " - "must be declared with the Pipeline.register decorator. Currently " - f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}" - ) + Then with `batch_size==4` the output would be: + ``` + [[(4, 32, 32), (4, 64, 64), (4, 128, 128)]] + ``` - return _REGISTERED_PIPELINES[task] + Then with `batch_size==2` the output would be: + ``` + [ + [(2, 32, 32), (2, 64, 64), (2, 128, 128)], + [(2, 32, 32), (2, 64, 64), (2, 128, 128)], + ] + ``` - @staticmethod - def create( - task: str, - model_path: str = None, - engine_type: str = DEEPSPARSE_ENGINE, - batch_size: int = 1, - num_cores: int = None, - scheduler: Scheduler = None, - input_shapes: List[List[int]] = None, - alias: Optional[str] = None, - context: Optional[Context] = None, - **kwargs, - ) -> "Pipeline": - """ - :param task: name of task to create a pipeline for. Use "custom" for - custom tasks (see `CustomTaskPipeline`). - :param model_path: path on local system or SparseZoo stub to load the model - from. Some tasks may have a default model path - :param engine_type: inference engine to use. Currently supported values - include 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' - :param batch_size: static batch size to use for inference. Default is 1 - :param num_cores: number of CPU cores to allocate for inference engine. None - specifies all available cores. Default is None - :param scheduler: (deepsparse only) kind of scheduler to execute with. - Pass None for the default - :param input_shapes: list of shapes to set ONNX the inputs to. Pass None - to use model as-is. Default is None - :param alias: optional name to give this pipeline instance, useful when - inferencing with multiple models. Default is None - :param context: Optional Context object to use for creating instances of - MultiModelEngine. The Context contains a shared scheduler along with - other runtime information that will be used across instances of the - MultiModelEngine to provide optimal performance when running - multiple models concurrently - :param kwargs: extra task specific kwargs to be passed to task Pipeline - implementation - :return: pipeline object initialized for the given task + Then with `batch_size==1` the output would be: + ``` + [ + [(1, 32, 32), (1, 64, 64), (1, 128, 128)], + [(1, 32, 32), (1, 64, 64), (1, 128, 128)], + [(1, 32, 32), (1, 64, 64), (1, 128, 128)], + [(1, 32, 32), (1, 64, 64), (1, 128, 128)], + ] + ``` """ - pipeline_constructor = Pipeline._get_task_constructor(task) - - if ( - (model_path is None or model_path == "default") - and hasattr(pipeline_constructor, "default_model_path") - and pipeline_constructor.default_model_path - ): - model_path = pipeline_constructor.default_model_path + # if not all items here are numpy arrays, there's an internal + # but in the processing code + assert all(isinstance(item, numpy.ndarray) for item in items) - if model_path is None: - raise ValueError( - f"No model_path provided for pipeline {pipeline_constructor}. 
Must " - "provide a model path for pipelines that do not have a default defined" - ) + # if not all items have the same batch size, there's an + # internal bug in the processing code + total_batch_size = items[0].shape[0] + assert all(item.shape[0] == total_batch_size for item in items) - if issubclass( - pipeline_constructor, Bucketable - ) and pipeline_constructor.should_bucket(**kwargs): - if input_shapes: - raise ValueError( - "Overriding input shapes not supported with Bucketing enabled" - ) - if not context: - context = Context(num_cores=num_cores) - buckets = pipeline_constructor.create_pipeline_buckets( - task=task, - model_path=model_path, - engine_type=engine_type, - batch_size=batch_size, - alias=alias, - context=context, - **kwargs, + if total_batch_size % batch_size != 0: + raise RuntimeError( + f"batch size of {total_batch_size} passed into pipeline " + f"is not divisible by model batch size of {batch_size}" ) - return BucketingPipeline(pipelines=buckets) - return pipeline_constructor( - model_path=model_path, - engine_type=engine_type, - batch_size=batch_size, - num_cores=num_cores, - scheduler=scheduler, - input_shapes=input_shapes, - alias=alias, - context=context, - **kwargs, - ) + batches = [] + for i_batch in range(total_batch_size // batch_size): + start = i_batch * batch_size + batches.append([item[start : start + batch_size] for item in items]) + return batches - @classmethod - def register( - cls, - task: str, - task_aliases: Optional[List[str]] = None, - default_model_path: Optional[str] = None, - ): + @staticmethod + def join_engine_outputs( + batch_outputs: List[List[numpy.ndarray]], + ) -> List[numpy.ndarray]: """ - Pipeline implementer class decorator that registers the pipeline - task name and its aliases as valid tasks that can be used to load - the pipeline through `Pipeline.create()`. - - Multiple pipelines may not have the same task name. An error will - be raised if two different pipelines attempt to register the same task name + Joins list of engine outputs together into one list using `numpy.concatenate`. - :param task: main task name of this pipeline - :param task_aliases: list of extra task names that may be used to reference - this pipeline. Default is None - :param default_model_path: path (ie zoo stub) to use as default for this - task if None is provided + This is the opposite of `Pipeline.split_engine_inputs`. """ - task_names = [task] - if task_aliases: - task_names.extend(task_aliases) - - task_names = [task_name.lower().replace("-", "_") for task_name in task_names] - - def _register_task(task_name, pipeline_class): - if task_name in _REGISTERED_PIPELINES and ( - pipeline_class is not _REGISTERED_PIPELINES[task_name] - ): - raise RuntimeError( - f"task {task_name} already registered by Pipeline.register. " - f"attempting to register pipeline: {pipeline_class}, but" - f"pipeline: {_REGISTERED_PIPELINES[task_name]}, already registered" - ) - _REGISTERED_PIPELINES[task_name] = pipeline_class - - def _register_pipeline_tasks_decorator(pipeline_class: Pipeline): - if not issubclass(pipeline_class, cls): - raise RuntimeError( - f"Attempting to register pipeline {pipeline_class}. 
" - f"Registered pipelines must inherit from {cls}" - ) - for task_name in task_names: - _register_task(task_name, pipeline_class) - - # set task and task_aliases as class level property - pipeline_class.task = task - pipeline_class.task_aliases = task_aliases - pipeline_class.default_model_path = default_model_path - - return pipeline_class - - return _register_pipeline_tasks_decorator + return list(map(numpy.concatenate, zip(*batch_outputs))) @classmethod def from_config( @@ -575,30 +716,6 @@ def process_engine_outputs( """ raise NotImplementedError() - @property - @abstractmethod - def input_schema(self) -> Type[BaseModel]: - """ - :return: pydantic model class that inputs to this pipeline must comply to - """ - raise NotImplementedError() - - @property - @abstractmethod - def output_schema(self) -> Type[BaseModel]: - """ - :return: pydantic model class that outputs of this pipeline must comply to - """ - raise NotImplementedError() - - @property - def alias(self) -> str: - """ - :return: optional name to give this pipeline instance, useful when - inferencing with multiple models - """ - return self._alias - @property def model_path_orig(self) -> str: """ @@ -685,54 +802,6 @@ def to_config(self) -> "PipelineConfig": kwargs=kwargs, ) - def log( - self, - identifier: str, - value: Any, - category: Union[str, MetricCategories], - ): - """ - Pass the logged data to the DeepSparse logger object (if present). - - :param identifier: The string name assigned to the logged value - :param value: The logged data structure - :param category: The metric category that the log belongs to - """ - if not self.logger: - return - - identifier = f"{self._identifier()}/{identifier}" - validate_identifier(identifier) - self.logger.log( - identifier=identifier, - value=value, - category=category, - pipeline_name=self._identifier(), - ) - return - - def parse_inputs(self, *args, **kwargs) -> BaseModel: - """ - :param args: ordered arguments to pipeline, only an input_schema object - is supported as an arg for this function - :param kwargs: keyword arguments to pipeline - :return: pipeline arguments parsed into the given `input_schema` - schema if necessary. If an instance of the `input_schema` is provided - it will be returned - """ - # passed input_schema schema directly - if len(args) == 1 and isinstance(args[0], self.input_schema) and not kwargs: - return args[0] - - if args: - raise ValueError( - f"pipeline {self.__class__} only supports either only a " - f"{self.input_schema} object. or keyword arguments to be construct " - f"one. 
Found {len(args)} args and {len(kwargs)} kwargs" - ) - - return self.input_schema(**kwargs) - def engine_forward(self, engine_inputs: List[numpy.ndarray]) -> List[numpy.ndarray]: """ :param engine_inputs: list of numpy inputs to Pipeline engine forward @@ -759,12 +828,6 @@ def _initialize_engine(self) -> Union[Engine, MultiModelEngine, ORTEngine]: self.onnx_file_path, self.engine_type, self._engine_args, self.context ) - def _identifier(self): - # get pipeline identifier; used in the context of logging - if not hasattr(self, "task"): - self.task = None - return f"{self.alias or self.task or 'unknown_pipeline'}" - class PipelineConfig(BaseModel): """ From 76db4da262b3382b25a98c8386ab640f7a414567 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 20 Jun 2023 20:46:38 +0000 Subject: [PATCH 02/14] update test; finish off initial refactoring changes post local testing --- src/deepsparse/pipeline.py | 83 ++++++++++++++----- .../pipelines/test_custom_pipeline.py | 2 +- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 9fb756d1dd..eb3fd2362e 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -102,11 +102,15 @@ def __init__( ) ) + @abstractmethod + def __call__(self, *args, **kwargs) -> BaseModel: + raise NotImplementedError() + @staticmethod def _get_task_constructor(task: str) -> Type["BasePipeline"]: """ - This function retrieves the class previously registered via `Pipeline.register` - for `task`. + This function retrieves the class previously registered via + `BasePipeline.register` for `task`. If `task` starts with "import:", it is treated as a module to be imported, and retrieves the task via the `TASK` attribute of the imported module. @@ -150,24 +154,6 @@ def create( """ :param task: name of task to create a pipeline for. Use "custom" for custom tasks (see `CustomTaskPipeline`). - :param model_path: path on local system or SparseZoo stub to load the model - from. Some tasks may have a default model path - :param engine_type: inference engine to use. Currently supported values - include 'deepsparse' and 'onnxruntime'. Default is 'deepsparse' - :param batch_size: static batch size to use for inference. Default is 1 - :param num_cores: number of CPU cores to allocate for inference engine. None - specifies all available cores. Default is None - :param scheduler: (deepsparse only) kind of scheduler to execute with. - Pass None for the default - :param input_shapes: list of shapes to set ONNX the inputs to. Pass None - to use model as-is. Default is None - :param alias: optional name to give this pipeline instance, useful when - inferencing with multiple models. Default is None - :param context: Optional Context object to use for creating instances of - MultiModelEngine. 
The Context contains a shared scheduler along with - other runtime information that will be used across instances of the - MultiModelEngine to provide optimal performance when running - multiple models concurrently :param kwargs: extra task specific kwargs to be passed to task Pipeline implementation :return: pipeline object initialized for the given task @@ -269,6 +255,35 @@ def _register_pipeline_tasks_decorator( return _register_pipeline_tasks_decorator + @classmethod + def from_config( + cls, + config: Union["PipelineConfig", str, Path], + logger: Optional[BaseLogger] = None, + ) -> "Pipeline": + """ + :param config: PipelineConfig object, filepath to a json serialized + PipelineConfig, or raw string of a json serialized PipelineConfig + :param logger: An optional DeepSparse Logger object for inference + logging. Default is None + :return: loaded Pipeline object from the config + """ + if isinstance(config, Path) or ( + isinstance(config, str) and os.path.exists(config) + ): + if isinstance(config, str): + config = Path(config) + config = PipelineConfig.parse_file(config) + if isinstance(config, str): + config = PipelineConfig.parse_raw(config) + + return cls.create( + task=config.task, + alias=config.alias, + logger=logger, + **config.kwargs, + ) + @property @abstractmethod def input_schema(self) -> Type[BaseModel]: @@ -293,6 +308,31 @@ def alias(self) -> str: """ return self._alias + def to_config(self) -> "PipelineConfig": + """ + :return: PipelineConfig that can be used to reload this object + """ + + if not hasattr(self, "task"): + raise RuntimeError( + f"{self.__class__} instance has no attribute task. Pipeline objects " + "must have a task to be serialized to a config. Pipeline objects " + "must be declared with the Pipeline.register object to be assigned a " + "task" + ) + + # parse any additional properties as kwargs + kwargs = {} + for attr_name, attr in self.__class__.__dict__.items(): + if isinstance(attr, property) and attr_name not in dir(PipelineConfig): + kwargs[attr_name] = getattr(self, attr_name) + + return PipelineConfig( + task=self.task, + alias=self.alias, + kwargs=kwargs, + ) + def log( self, identifier: str, @@ -433,6 +473,7 @@ def __init__( executor: Optional[Union[ThreadPoolExecutor, int]] = None, benchmark: bool = False, _delay_engine_initialize: bool = False, + **kwargs, ): self._benchmark = benchmark self._model_path_orig = model_path @@ -441,6 +482,7 @@ def __init__( self._batch_size = batch_size self._timer_manager = TimerManager(enabled=True, multi=benchmark) self.context = context + super().__init__(**kwargs) self.executor, self._num_async_workers = _initialize_executor_and_workers( batch_size=batch_size, @@ -842,6 +884,7 @@ class PipelineConfig(BaseModel): description="name of task to create a pipeline for", ) model_path: str = Field( + default=None, description="path on local system or SparseZoo stub to load the model from", ) engine_type: str = Field( diff --git a/tests/deepsparse/pipelines/test_custom_pipeline.py b/tests/deepsparse/pipelines/test_custom_pipeline.py index b718ac4782..061b59ae03 100644 --- a/tests/deepsparse/pipelines/test_custom_pipeline.py +++ b/tests/deepsparse/pipelines/test_custom_pipeline.py @@ -113,7 +113,7 @@ def postprocess(outputs, **kwargs): pipeline = Pipeline.create( "custom", - model_path, + model_path=model_path, input_schema=ImageClassificationInput, output_schema=ImageClassificationOutput, process_inputs_fn=preprocess, From de7645d614ed8d4a762ee7ce49132ae4a7da9b70 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: 
Tue, 20 Jun 2023 21:01:25 +0000 Subject: [PATCH 03/14] test fix --- tests/deepsparse/image_classification/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/deepsparse/image_classification/test_pipelines.py b/tests/deepsparse/image_classification/test_pipelines.py index c86fd69764..25f7fcf813 100644 --- a/tests/deepsparse/image_classification/test_pipelines.py +++ b/tests/deepsparse/image_classification/test_pipelines.py @@ -56,7 +56,7 @@ def test_image_classification_pipeline_preprocessing( ] ) - ic_pipeline = Pipeline.create("image_classification", zoo_stub) + ic_pipeline = Pipeline.create("image_classification", model_path=zoo_stub) zoo_model = Model(zoo_stub) data_originals_path = None if zoo_model.sample_originals is not None: From 8b602a2cd532f7b68bd46fd532e2bce4a8fa3931 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 20 Jun 2023 21:25:00 +0000 Subject: [PATCH 04/14] Test BasePipeline register and create --- src/deepsparse/pipeline.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index eb3fd2362e..a0069529a3 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -249,6 +249,8 @@ def _register_pipeline_tasks_decorator( # leave default_model_path for now as is optional? pipeline_class.task = task pipeline_class.task_aliases = task_aliases + # Mandatory? Seems to be only be used in create? Won't be used at all + # BasePipeline pipeline_class.default_model_path = default_model_path return pipeline_class From 0c4ba2070d88810580944239ccbe68970ca95f81 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 21 Jun 2023 15:39:46 +0000 Subject: [PATCH 05/14] add tests for BasePipeline --- tests/deepsparse/pipelines/test_pipeline.py | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/deepsparse/pipelines/test_pipeline.py b/tests/deepsparse/pipelines/test_pipeline.py index 66db456d0a..971a33afd9 100644 --- a/tests/deepsparse/pipelines/test_pipeline.py +++ b/tests/deepsparse/pipelines/test_pipeline.py @@ -18,6 +18,8 @@ import numpy +import pytest +from deepsparse.base_pipeline import BasePipeline from deepsparse.pipeline import Pipeline, _initialize_executor_and_workers from tests.utils import mock_engine @@ -55,6 +57,48 @@ def test_split_interaction_with_forward_batch_size_2(engine_forward): assert engine_forward.call_count == 4 +@pytest.fixture +def base_pipeline_example(): + @BasePipeline.register(task="base_example") + class BasePipelineExample(BasePipeline): + def __init__(self, base_specific, **kwargs): + self._base_specific = base_specific + super().__init__(**kwargs) + + def __call__(self, *args, **kwargs): + pass + + def input_schema(self): + pass + + def output_schema(self): + pass + + @property + def base_specific(self): + return self._base_specific + + kwargs = {"alias": "base_alias", "base_specific": "base_specific"} + base_pipeline = BasePipeline.create(task="base_example", **kwargs) + return base_pipeline, BasePipelineExample, kwargs + + +def test_base_pipeline(base_pipeline_example): + base_pipeline = base_pipeline_example[0] + pipeline = base_pipeline_example[1] + kwargs = base_pipeline_example[-1] + + assert base_pipeline.alias == kwargs["alias"] + assert base_pipeline.base_specific == kwargs["base_specific"] + + cls = BasePipeline._get_task_constructor("base_example") + assert cls == pipeline + + config = base_pipeline.to_config() + assert isinstance(config, PipelineConfig) + assert config.kwargs["base_specific"] == 
base_pipeline.base_specific + + def test_pipeline_executor_num_workers(): executor, _ = _initialize_executor_and_workers(2, None) assert executor._max_workers == 1 From 18bee125df9c01f4091accf42b7c6508d2808c78 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 21 Jun 2023 18:05:01 +0000 Subject: [PATCH 06/14] add and update docstring --- src/deepsparse/pipeline.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index a0069529a3..26f34d192d 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -73,6 +73,31 @@ class BasePipeline(ABC): """ + Generic BasePipeline abstract class meant to wrap inference objects to include + model-specific Pipelines objects. Any pipeline inherited from Pipeline objects + should handle all model-specific input/output pre/post processing while BasePipeline + is meant to serve as a generic wrapper. Inputs and outputs of BasePipelines should + be serialized as pydantic Models. + + BasePipelines should not be instantiated by their constructors, but rather the + `BasePipeline.create()` method. The task name given to `create` will be used to + load the appropriate pipeline. The pipeline should inherit from `BasePipeline` and + implement the `__call__`, `input_schema`, and `output_schema` abstract methods. + + Finally, the class definition should be decorated by the `BasePipeline.register` + function. This defines the task name and task aliases for the pipeline and + ensures that it will be accessible by `BasePipeline.create`. The implemented + `BasePipeline` subclass must be imported at runtime to be accessible. + + Example: + @BasePipeline.register(task="base_example") + class BasePipelineExample(BasePipeline): + def __init__(self, base_specific, **kwargs): + self._base_specific = base_specific + self.model_pipeline = Pipeline.create(task="..") + super().__init__(**kwargs) + # implementation of abstract methods + :param alias: optional name to give this pipeline instance, useful when inferencing with multiple models. Default is None :param logger: An optional item that can be either a DeepSparse Logger object, @@ -104,13 +129,19 @@ def __init__( @abstractmethod def __call__(self, *args, **kwargs) -> BaseModel: + """ + Runner function needed to stitch together any parsing, preprocessing, engine, + and post-processing steps. + + :returns: pydantic model class that outputs of this pipeline must comply to + """ raise NotImplementedError() @staticmethod def _get_task_constructor(task: str) -> Type["BasePipeline"]: """ This function retrieves the class previously registered via - `BasePipeline.register` for `task`. + `BasePipeline.register` or `Pipeline.register` for `task`. If `task` starts with "import:", it is treated as a module to be imported, and retrieves the task via the `TASK` attribute of the imported module. @@ -206,7 +237,7 @@ def register( """ Pipeline implementer class decorator that registers the pipeline task name and its aliases as valid tasks that can be used to load - the pipeline through `BasePipeline.create()`. + the pipeline through `BasePipeline.create()` or `Pipeline.create()` Multiple pipelines may not have the same task name. An error will be raised if two different pipelines attempt to register the same task name @@ -394,7 +425,8 @@ class Pipeline(BasePipeline): """ Generic Pipeline abstract class meant to wrap inference engine objects to include data pre/post-processing. 
Inputs and outputs of pipelines should be serialized - as pydantic Models. + as pydantic Models. See the BasePipeline above for additional parameters provided + during inference. Pipelines should not be instantiated by their constructors, but rather the `Pipeline.create()` method. The task name given to `create` will be used to From a51e22a582de01eb1734b6377bb408e8b2d3ebbe Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 26 Jun 2023 16:14:29 +0000 Subject: [PATCH 07/14] update test --- tests/deepsparse/pipelines/test_pipeline.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/deepsparse/pipelines/test_pipeline.py b/tests/deepsparse/pipelines/test_pipeline.py index 971a33afd9..a7abdf5dde 100644 --- a/tests/deepsparse/pipelines/test_pipeline.py +++ b/tests/deepsparse/pipelines/test_pipeline.py @@ -78,8 +78,10 @@ def output_schema(self): def base_specific(self): return self._base_specific - kwargs = {"alias": "base_alias", "base_specific": "base_specific"} - base_pipeline = BasePipeline.create(task="base_example", **kwargs) + kwargs = {"base_specific": "base_specific"} + base_pipeline = BasePipeline.create( + task="base_example", alias="base_alias", **kwargs + ) return base_pipeline, BasePipelineExample, kwargs @@ -88,7 +90,6 @@ def test_base_pipeline(base_pipeline_example): pipeline = base_pipeline_example[1] kwargs = base_pipeline_example[-1] - assert base_pipeline.alias == kwargs["alias"] assert base_pipeline.base_specific == kwargs["base_specific"] cls = BasePipeline._get_task_constructor("base_example") From e52b58a84d0ee85860dba74ba5867ce710482579 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 27 Jun 2023 01:44:00 +0000 Subject: [PATCH 08/14] move BasePipeline to a new file --- src/deepsparse/__init__.py | 1 + src/deepsparse/base_pipeline.py | 382 ++++++++++++++++++++ src/deepsparse/pipeline.py | 370 +------------------ tests/deepsparse/pipelines/test_pipeline.py | 6 +- 4 files changed, 391 insertions(+), 368 deletions(-) create mode 100644 src/deepsparse/base_pipeline.py diff --git a/src/deepsparse/__init__.py b/src/deepsparse/__init__.py index 83fc4d9632..6c7d0f1cac 100644 --- a/src/deepsparse/__init__.py +++ b/src/deepsparse/__init__.py @@ -33,6 +33,7 @@ from .engine import * from .tasks import * from .pipeline import * +from .base_pipeline import * from .loggers import * from .version import __version__, is_release from .analytics import deepsparse_analytics as _analytics diff --git a/src/deepsparse/base_pipeline.py b/src/deepsparse/base_pipeline.py new file mode 100644 index 0000000000..20a4558103 --- /dev/null +++ b/src/deepsparse/base_pipeline.py @@ -0,0 +1,382 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, List, Optional, Type, Union + +from pydantic import BaseModel + +from deepsparse import Context +from deepsparse.loggers.base_logger import BaseLogger +from deepsparse.loggers.build_logger import logger_from_config +from deepsparse.loggers.constants import validate_identifier +from deepsparse.tasks import SupportedTasks, dynamic_import_task + + +__all__ = [ + "BasePipeline", +] + +_REGISTERED_PIPELINES = {} + + +class BasePipeline(ABC): + """ + Generic BasePipeline abstract class meant to wrap inference objects to include + model-specific Pipelines objects. Any pipeline inherited from Pipeline objects + should handle all model-specific input/output pre/post processing while BasePipeline + is meant to serve as a generic wrapper. Inputs and outputs of BasePipelines should + be serialized as pydantic Models. + + BasePipelines should not be instantiated by their constructors, but rather the + `BasePipeline.create()` method. The task name given to `create` will be used to + load the appropriate pipeline. The pipeline should inherit from `BasePipeline` and + implement the `__call__`, `input_schema`, and `output_schema` abstract methods. + + Finally, the class definition should be decorated by the `BasePipeline.register` + function. This defines the task name and task aliases for the pipeline and + ensures that it will be accessible by `BasePipeline.create`. The implemented + `BasePipeline` subclass must be imported at runtime to be accessible. + + Example: + @BasePipeline.register(task="base_example") + class BasePipelineExample(BasePipeline): + def __init__(self, base_specific, **kwargs): + self._base_specific = base_specific + self.model_pipeline = Pipeline.create(task="..") + super().__init__(**kwargs) + # implementation of abstract methods + + :param alias: optional name to give this pipeline instance, useful when + inferencing with multiple models. Default is None + :param logger: An optional item that can be either a DeepSparse Logger object, + or an object that can be transformed into one. Those object can be either + a path to the logging config, or yaml string representation the logging + config. If logger provided (in any form), the pipeline will log inference + metrics to the logger. Default is None + + """ + + def __init__( + self, + alias: Optional[str] = None, + logger: Optional[Union[BaseLogger, str]] = None, + ): + + self._alias = alias + self.logger = ( + logger + if isinstance(logger, BaseLogger) + else ( + logger_from_config( + config=logger, pipeline_identifier=self._identifier() + ) + if isinstance(logger, str) + else None + ) + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> BaseModel: + """ + Runner function needed to stitch together any parsing, preprocessing, engine, + and post-processing steps. 
+ + :returns: pydantic model class that outputs of this pipeline must comply to + """ + raise NotImplementedError() + + @property + @abstractmethod + def input_schema(self) -> Type[BaseModel]: + """ + :return: pydantic model class that inputs to this pipeline must comply to + """ + raise NotImplementedError() + + @property + @abstractmethod + def output_schema(self) -> Type[BaseModel]: + """ + :return: pydantic model class that outputs of this pipeline must comply to + """ + raise NotImplementedError() + + @staticmethod + def _get_task_constructor(task: str) -> Type["BasePipeline"]: + """ + This function retrieves the class previously registered via + `BasePipeline.register` or `Pipeline.register` for `task`. + + If `task` starts with "import:", it is treated as a module to be imported, + and retrieves the task via the `TASK` attribute of the imported module. + + If `task` starts with "custom", then it is mapped to the "custom" task. + + :param task: The task name to get the constructor for + :return: The class registered to `task` + :raises ValueError: if `task` was not registered via `Pipeline.register`. + """ + if task.startswith("import:"): + # dynamically import the task from a file + task = dynamic_import_task(module_or_path=task.replace("import:", "")) + elif task.startswith("custom"): + # support any task that has "custom" at the beginning via the "custom" task + task = "custom" + else: + task = task.lower().replace("-", "_") + + # extra step to register pipelines for a given task domain + # for cases where imports should only happen once a user specifies + # that domain is to be used. (ie deepsparse.transformers will auto + # install extra packages so should only import and register once a + # transformers task is specified) + SupportedTasks.check_register_task(task, _REGISTERED_PIPELINES.keys()) + + if task not in _REGISTERED_PIPELINES: + raise ValueError( + f"Unknown Pipeline task {task}. Pipeline tasks should be " + "must be declared with the Pipeline.register decorator. Currently " + f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}" + ) + + return _REGISTERED_PIPELINES[task] + + @staticmethod + def create( + task: str, + **kwargs, + ) -> "BasePipeline": + """ + :param task: name of task to create a pipeline for. Use "custom" for + custom tasks (see `CustomTaskPipeline`). + :param kwargs: extra task specific kwargs to be passed to task Pipeline + implementation + :return: pipeline object initialized for the given task + """ + from deepsparse.pipeline import Bucketable, BucketingPipeline, Pipeline + + pipeline_constructor = BasePipeline._get_task_constructor(task) + model_path = kwargs.get("model_path", None) + + if issubclass(pipeline_constructor, Pipeline): + if ( + (model_path is None or model_path == "default") + and hasattr(pipeline_constructor, "default_model_path") + and pipeline_constructor.default_model_path + ): + model_path = pipeline_constructor.default_model_path + + if model_path is None: + raise ValueError( + f"No model_path provided for pipeline {pipeline_constructor}. 
Must " + "provide a model path for pipelines that do not have a default " + "defined" + ) + + kwargs["model_path"] = model_path + + if issubclass( + pipeline_constructor, Bucketable + ) and pipeline_constructor.should_bucket(**kwargs): + if kwargs.get("input_shape", None): + raise ValueError( + "Overriding input shapes not supported with Bucketing enabled" + ) + if not kwargs.get("context", None): + context = Context(num_cores=kwargs["num_cores"]) + kwargs["context"] = context + buckets = pipeline_constructor.create_pipeline_buckets( + task=task, + **kwargs, + ) + return BucketingPipeline(pipelines=buckets) + + return pipeline_constructor(**kwargs) + + @classmethod + def register( + cls, + task: str, + task_aliases: Optional[List[str]] = None, + default_model_path: Optional[str] = None, + ): + """ + Pipeline implementer class decorator that registers the pipeline + task name and its aliases as valid tasks that can be used to load + the pipeline through `BasePipeline.create()` or `Pipeline.create()` + + Multiple pipelines may not have the same task name. An error will + be raised if two different pipelines attempt to register the same task name + + :param task: main task name of this pipeline + :param task_aliases: list of extra task names that may be used to reference + this pipeline. Default is None + :param default_model_path: path (ie zoo stub) to use as default for this + task if None is provided + """ + task_names = [task] + if task_aliases: + task_names.extend(task_aliases) + + task_names = [task_name.lower().replace("-", "_") for task_name in task_names] + + def _register_task(task_name, pipeline_class): + if task_name in _REGISTERED_PIPELINES and ( + pipeline_class is not _REGISTERED_PIPELINES[task_name] + ): + raise RuntimeError( + f"task {task_name} already registered by BasePipeline.register. " + f"attempting to register pipeline: {pipeline_class}, but" + f"pipeline: {_REGISTERED_PIPELINES[task_name]}, already registered" + ) + _REGISTERED_PIPELINES[task_name] = pipeline_class + + def _register_pipeline_tasks_decorator(pipeline_class: BasePipeline): + if not issubclass(pipeline_class, cls): + raise RuntimeError( + f"Attempting to register pipeline {pipeline_class}. " + f"Registered pipelines must inherit from {cls}" + ) + for task_name in task_names: + _register_task(task_name, pipeline_class) + + # set task and task_aliases as class level property + pipeline_class.task = task + pipeline_class.task_aliases = task_aliases + pipeline_class.default_model_path = default_model_path + + return pipeline_class + + return _register_pipeline_tasks_decorator + + @classmethod + def from_config( + cls, + config: Union["PipelineConfig", str, Path], # noqa: F821 + logger: Optional[BaseLogger] = None, + ) -> "BasePipeline": + """ + :param config: PipelineConfig object, filepath to a json serialized + PipelineConfig, or raw string of a json serialized PipelineConfig + :param logger: An optional DeepSparse Logger object for inference + logging. 
Default is None + :return: loaded Pipeline object from the config + """ + from deepsparse.pipeline import PipelineConfig + + if isinstance(config, Path) or ( + isinstance(config, str) and os.path.exists(config) + ): + if isinstance(config, str): + config = Path(config) + config = PipelineConfig.parse_file(config) + if isinstance(config, str): + config = PipelineConfig.parse_raw(config) + + return cls.create( + task=config.task, + alias=config.alias, + logger=logger, + **config.kwargs, + ) + + @property + def alias(self) -> str: + """ + :return: optional name to give this pipeline instance, useful when + inferencing with multiple models + """ + return self._alias + + def to_config(self) -> "PipelineConfig": # noqa: F821 + """ + :return: PipelineConfig that can be used to reload this object + """ + from deepsparse.pipeline import PipelineConfig + + if not hasattr(self, "task"): + raise RuntimeError( + f"{self.__class__} instance has no attribute task. Pipeline objects " + "must have a task to be serialized to a config. Pipeline objects " + "must be declared with the Pipeline.register object to be assigned a " + "task" + ) + + # parse any additional properties as kwargs + kwargs = {} + for attr_name, attr in self.__class__.__dict__.items(): + if isinstance(attr, property) and attr_name not in dir(PipelineConfig): + kwargs[attr_name] = getattr(self, attr_name) + + return PipelineConfig( + task=self.task, + alias=self.alias, + kwargs=kwargs, + ) + + def log( + self, + identifier: str, + value: Any, + category: str, + ): + """ + Pass the logged data to the DeepSparse logger object (if present). + + :param identifier: The string name assigned to the logged value + :param value: The logged data structure + :param category: The metric category that the log belongs to + """ + if not self.logger: + return + + identifier = f"{self._identifier()}/{identifier}" + validate_identifier(identifier) + self.logger.log( + identifier=identifier, + value=value, + category=category, + pipeline_name=self._identifier(), + ) + return + + def parse_inputs(self, *args, **kwargs) -> BaseModel: + """ + :param args: ordered arguments to pipeline, only an input_schema object + is supported as an arg for this function + :param kwargs: keyword arguments to pipeline + :return: pipeline arguments parsed into the given `input_schema` + schema if necessary. If an instance of the `input_schema` is provided + it will be returned + """ + # passed input_schema schema directly + if len(args) == 1 and isinstance(args[0], self.input_schema) and not kwargs: + return args[0] + + if args: + raise ValueError( + f"pipeline {self.__class__} only supports either only a " + f"{self.input_schema} object. or keyword arguments to be construct " + f"one. 
Found {len(args)} args and {len(kwargs)} kwargs" + ) + + return self.input_schema(**kwargs) + + def _identifier(self): + # get pipeline identifier; used in the context of logging + if not hasattr(self, "task"): + self.task = None + return f"{self.alias or self.task or 'unknown_pipeline'}" diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 26f34d192d..7ab90fff3d 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -26,30 +26,18 @@ from pydantic import BaseModel, Field from deepsparse import Context, Engine, MultiModelEngine, Scheduler +from deepsparse.base_pipeline import BasePipeline from deepsparse.benchmark import ORTEngine from deepsparse.cpu import cpu_details from deepsparse.loggers.base_logger import BaseLogger -from deepsparse.loggers.build_logger import logger_from_config -from deepsparse.loggers.constants import ( - MetricCategories, - SystemGroups, - validate_identifier, -) -from deepsparse.tasks import SupportedTasks, dynamic_import_task -from deepsparse.utils import ( - InferenceStages, - StagedTimer, - TimerManager, - join_engine_outputs, - split_engine_inputs, -) +from deepsparse.loggers.constants import MetricCategories, SystemGroups +from deepsparse.utils import InferenceStages, StagedTimer, TimerManager __all__ = [ "DEEPSPARSE_ENGINE", "ORT_ENGINE", "SUPPORTED_PIPELINE_ENGINES", - "BasePipeline", "Pipeline", "PipelineConfig", "question_answering_pipeline", @@ -68,358 +56,6 @@ SUPPORTED_PIPELINE_ENGINES = [DEEPSPARSE_ENGINE, ORT_ENGINE] -_REGISTERED_PIPELINES = {} - - -class BasePipeline(ABC): - """ - Generic BasePipeline abstract class meant to wrap inference objects to include - model-specific Pipelines objects. Any pipeline inherited from Pipeline objects - should handle all model-specific input/output pre/post processing while BasePipeline - is meant to serve as a generic wrapper. Inputs and outputs of BasePipelines should - be serialized as pydantic Models. - - BasePipelines should not be instantiated by their constructors, but rather the - `BasePipeline.create()` method. The task name given to `create` will be used to - load the appropriate pipeline. The pipeline should inherit from `BasePipeline` and - implement the `__call__`, `input_schema`, and `output_schema` abstract methods. - - Finally, the class definition should be decorated by the `BasePipeline.register` - function. This defines the task name and task aliases for the pipeline and - ensures that it will be accessible by `BasePipeline.create`. The implemented - `BasePipeline` subclass must be imported at runtime to be accessible. - - Example: - @BasePipeline.register(task="base_example") - class BasePipelineExample(BasePipeline): - def __init__(self, base_specific, **kwargs): - self._base_specific = base_specific - self.model_pipeline = Pipeline.create(task="..") - super().__init__(**kwargs) - # implementation of abstract methods - - :param alias: optional name to give this pipeline instance, useful when - inferencing with multiple models. Default is None - :param logger: An optional item that can be either a DeepSparse Logger object, - or an object that can be transformed into one. Those object can be either - a path to the logging config, or yaml string representation the logging - config. If logger provided (in any form), the pipeline will log inference - metrics to the logger. 
Default is None - - """ - - def __init__( - self, - alias: Optional[str] = None, - logger: Optional[Union[BaseLogger, str]] = None, - ): - - self._alias = alias - self.logger = ( - logger - if isinstance(logger, BaseLogger) - else ( - logger_from_config( - config=logger, pipeline_identifier=self._identifier() - ) - if isinstance(logger, str) - else None - ) - ) - - @abstractmethod - def __call__(self, *args, **kwargs) -> BaseModel: - """ - Runner function needed to stitch together any parsing, preprocessing, engine, - and post-processing steps. - - :returns: pydantic model class that outputs of this pipeline must comply to - """ - raise NotImplementedError() - - @staticmethod - def _get_task_constructor(task: str) -> Type["BasePipeline"]: - """ - This function retrieves the class previously registered via - `BasePipeline.register` or `Pipeline.register` for `task`. - - If `task` starts with "import:", it is treated as a module to be imported, - and retrieves the task via the `TASK` attribute of the imported module. - - If `task` starts with "custom", then it is mapped to the "custom" task. - - :param task: The task name to get the constructor for - :return: The class registered to `task` - :raises ValueError: if `task` was not registered via `Pipeline.register`. - """ - if task.startswith("import:"): - # dynamically import the task from a file - task = dynamic_import_task(module_or_path=task.replace("import:", "")) - elif task.startswith("custom"): - # support any task that has "custom" at the beginning via the "custom" task - task = "custom" - else: - task = task.lower().replace("-", "_") - - # extra step to register pipelines for a given task domain - # for cases where imports should only happen once a user specifies - # that domain is to be used. (ie deepsparse.transformers will auto - # install extra packages so should only import and register once a - # transformers task is specified) - SupportedTasks.check_register_task(task, _REGISTERED_PIPELINES.keys()) - - if task not in _REGISTERED_PIPELINES: - raise ValueError( - f"Unknown Pipeline task {task}. Pipeline tasks should be " - "must be declared with the Pipeline.register decorator. Currently " - f"registered pipelines: {list(_REGISTERED_PIPELINES.keys())}" - ) - - return _REGISTERED_PIPELINES[task] - - @staticmethod - def create( - task: str, - **kwargs, - ) -> "BasePipeline": - """ - :param task: name of task to create a pipeline for. Use "custom" for - custom tasks (see `CustomTaskPipeline`). - :param kwargs: extra task specific kwargs to be passed to task Pipeline - implementation - :return: pipeline object initialized for the given task - """ - pipeline_constructor = BasePipeline._get_task_constructor(task) - model_path = kwargs.get("model_path", None) - - if issubclass(pipeline_constructor, Pipeline): - if ( - (model_path is None or model_path == "default") - and hasattr(pipeline_constructor, "default_model_path") - and pipeline_constructor.default_model_path - ): - model_path = pipeline_constructor.default_model_path - - if model_path is None: - raise ValueError( - f"No model_path provided for pipeline {pipeline_constructor}. 
Must " - "provide a model path for pipelines that do not have a default " - "defined" - ) - - kwargs["model_path"] = model_path - - if issubclass( - pipeline_constructor, Bucketable - ) and pipeline_constructor.should_bucket(**kwargs): - if kwargs.get("input_shape", None): - raise ValueError( - "Overriding input shapes not supported with Bucketing enabled" - ) - if not kwargs.get("context", None): - context = Context(num_cores=kwargs["num_cores"]) - kwargs["context"] = context - buckets = pipeline_constructor.create_pipeline_buckets( - task=task, - **kwargs, - ) - return BucketingPipeline(pipelines=buckets) - - return pipeline_constructor(**kwargs) - - @classmethod - def register( - cls, - task: str, - task_aliases: Optional[List[str]] = None, - default_model_path: Optional[str] = None, - ): - """ - Pipeline implementer class decorator that registers the pipeline - task name and its aliases as valid tasks that can be used to load - the pipeline through `BasePipeline.create()` or `Pipeline.create()` - - Multiple pipelines may not have the same task name. An error will - be raised if two different pipelines attempt to register the same task name - - :param task: main task name of this pipeline - :param task_aliases: list of extra task names that may be used to reference - this pipeline. Default is None - :param default_model_path: path (ie zoo stub) to use as default for this - task if None is provided - """ - task_names = [task] - if task_aliases: - task_names.extend(task_aliases) - - task_names = [task_name.lower().replace("-", "_") for task_name in task_names] - - def _register_task(task_name, pipeline_class): - if task_name in _REGISTERED_PIPELINES and ( - pipeline_class is not _REGISTERED_PIPELINES[task_name] - ): - raise RuntimeError( - f"task {task_name} already registered by BasePipeline.register. " - f"attempting to register pipeline: {pipeline_class}, but" - f"pipeline: {_REGISTERED_PIPELINES[task_name]}, already registered" - ) - _REGISTERED_PIPELINES[task_name] = pipeline_class - - def _register_pipeline_tasks_decorator( - pipeline_class: Union[BasePipeline, Pipeline] - ): - if not issubclass(pipeline_class, cls): - raise RuntimeError( - f"Attempting to register pipeline {pipeline_class}. " - f"Registered pipelines must inherit from {cls}" - ) - for task_name in task_names: - _register_task(task_name, pipeline_class) - - # set task and task_aliases as class level property - # leave default_model_path for now as is optional? - pipeline_class.task = task - pipeline_class.task_aliases = task_aliases - # Mandatory? Seems to be only be used in create? Won't be used at all - # BasePipeline - pipeline_class.default_model_path = default_model_path - - return pipeline_class - - return _register_pipeline_tasks_decorator - - @classmethod - def from_config( - cls, - config: Union["PipelineConfig", str, Path], - logger: Optional[BaseLogger] = None, - ) -> "Pipeline": - """ - :param config: PipelineConfig object, filepath to a json serialized - PipelineConfig, or raw string of a json serialized PipelineConfig - :param logger: An optional DeepSparse Logger object for inference - logging. 
Default is None - :return: loaded Pipeline object from the config - """ - if isinstance(config, Path) or ( - isinstance(config, str) and os.path.exists(config) - ): - if isinstance(config, str): - config = Path(config) - config = PipelineConfig.parse_file(config) - if isinstance(config, str): - config = PipelineConfig.parse_raw(config) - - return cls.create( - task=config.task, - alias=config.alias, - logger=logger, - **config.kwargs, - ) - - @property - @abstractmethod - def input_schema(self) -> Type[BaseModel]: - """ - :return: pydantic model class that inputs to this pipeline must comply to - """ - raise NotImplementedError() - - @property - @abstractmethod - def output_schema(self) -> Type[BaseModel]: - """ - :return: pydantic model class that outputs of this pipeline must comply to - """ - raise NotImplementedError() - - @property - def alias(self) -> str: - """ - :return: optional name to give this pipeline instance, useful when - inferencing with multiple models - """ - return self._alias - - def to_config(self) -> "PipelineConfig": - """ - :return: PipelineConfig that can be used to reload this object - """ - - if not hasattr(self, "task"): - raise RuntimeError( - f"{self.__class__} instance has no attribute task. Pipeline objects " - "must have a task to be serialized to a config. Pipeline objects " - "must be declared with the Pipeline.register object to be assigned a " - "task" - ) - - # parse any additional properties as kwargs - kwargs = {} - for attr_name, attr in self.__class__.__dict__.items(): - if isinstance(attr, property) and attr_name not in dir(PipelineConfig): - kwargs[attr_name] = getattr(self, attr_name) - - return PipelineConfig( - task=self.task, - alias=self.alias, - kwargs=kwargs, - ) - - def log( - self, - identifier: str, - value: Any, - category: str, - ): - """ - Pass the logged data to the DeepSparse logger object (if present). - - :param identifier: The string name assigned to the logged value - :param value: The logged data structure - :param category: The metric category that the log belongs to - """ - if not self.logger: - return - - identifier = f"{self._identifier()}/{identifier}" - validate_identifier(identifier) - self.logger.log( - identifier=identifier, - value=value, - category=category, - pipeline_name=self._identifier(), - ) - return - - def parse_inputs(self, *args, **kwargs) -> BaseModel: - """ - :param args: ordered arguments to pipeline, only an input_schema object - is supported as an arg for this function - :param kwargs: keyword arguments to pipeline - :return: pipeline arguments parsed into the given `input_schema` - schema if necessary. If an instance of the `input_schema` is provided - it will be returned - """ - # passed input_schema schema directly - if len(args) == 1 and isinstance(args[0], self.input_schema) and not kwargs: - return args[0] - - if args: - raise ValueError( - f"pipeline {self.__class__} only supports either only a " - f"{self.input_schema} object. or keyword arguments to be construct " - f"one. 
Found {len(args)} args and {len(kwargs)} kwargs" - ) - - return self.input_schema(**kwargs) - - def _identifier(self): - # get pipeline identifier; used in the context of logging - if not hasattr(self, "task"): - self.task = None - return f"{self.alias or self.task or 'unknown_pipeline'}" - class Pipeline(BasePipeline): """ diff --git a/tests/deepsparse/pipelines/test_pipeline.py b/tests/deepsparse/pipelines/test_pipeline.py index a7abdf5dde..139e579616 100644 --- a/tests/deepsparse/pipelines/test_pipeline.py +++ b/tests/deepsparse/pipelines/test_pipeline.py @@ -20,7 +20,11 @@ import pytest from deepsparse.base_pipeline import BasePipeline -from deepsparse.pipeline import Pipeline, _initialize_executor_and_workers +from deepsparse.pipeline import ( + Pipeline, + PipelineConfig, + _initialize_executor_and_workers, +) from tests.utils import mock_engine From a621a620dae7edccf645facff9bcdfafb8c0a988 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 27 Jun 2023 01:51:43 +0000 Subject: [PATCH 09/14] test fix --- src/deepsparse/server/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/server/cli.py b/src/deepsparse/server/cli.py index 29cbc9afb0..32c139d7f5 100644 --- a/src/deepsparse/server/cli.py +++ b/src/deepsparse/server/cli.py @@ -27,7 +27,7 @@ import click import yaml -from deepsparse.pipeline import SupportedTasks +from deepsparse.base_pipline import SupportedTasks from deepsparse.server.config import EndpointConfig, ServerConfig from deepsparse.server.server import start_server From baf8e7e1986a593d0714bf1f36cc46975d1ada09 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 27 Jun 2023 02:09:31 +0000 Subject: [PATCH 10/14] anothe test fix --- src/deepsparse/server/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/server/cli.py b/src/deepsparse/server/cli.py index 32c139d7f5..cd2ce6c367 100644 --- a/src/deepsparse/server/cli.py +++ b/src/deepsparse/server/cli.py @@ -27,9 +27,9 @@ import click import yaml -from deepsparse.base_pipline import SupportedTasks from deepsparse.server.config import EndpointConfig, ServerConfig from deepsparse.server.server import start_server +from deepsparse.tasks import SupportedTasks HOST_OPTION = click.option( From 3933fc4ba20475dc651134714bde12f297120292 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 27 Jun 2023 02:16:19 +0000 Subject: [PATCH 11/14] fix import --- tests/deepsparse/pipelines/test_dynamic_import.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/deepsparse/pipelines/test_dynamic_import.py b/tests/deepsparse/pipelines/test_dynamic_import.py index 63096e2365..f6c5ff1e3e 100644 --- a/tests/deepsparse/pipelines/test_dynamic_import.py +++ b/tests/deepsparse/pipelines/test_dynamic_import.py @@ -15,7 +15,8 @@ import os import pytest -from deepsparse.pipeline import _REGISTERED_PIPELINES, Pipeline +from deepsparse.base_pipeline import _REGISTERED_PIPELINES +from deepsparse.pipeline import Pipeline from deepsparse.tasks import _split_dir_and_name, dynamic_import_task From e34249e789c847005911849a76e734ede56fd8ae Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 27 Jun 2023 02:32:10 +0000 Subject: [PATCH 12/14] revert --- src/deepsparse/server/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deepsparse/server/cli.py b/src/deepsparse/server/cli.py index cd2ce6c367..d14358b0d5 100644 --- a/src/deepsparse/server/cli.py +++ b/src/deepsparse/server/cli.py @@ -27,9 +27,9 @@ import click import yaml +from 
deepsparse.base_pipeline import SupportedTasks from deepsparse.server.config import EndpointConfig, ServerConfig from deepsparse.server.server import start_server -from deepsparse.tasks import SupportedTasks HOST_OPTION = click.option( From fd49585e4b440b1c35f1c1b2ec30c4ae562b5935 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Thu, 13 Jul 2023 15:52:35 +0000 Subject: [PATCH 13/14] rebase fix --- src/deepsparse/pipeline.py | 73 ++++---------------------------------- 1 file changed, 7 insertions(+), 66 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 7ab90fff3d..7b6f83bcf7 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -31,7 +31,13 @@ from deepsparse.cpu import cpu_details from deepsparse.loggers.base_logger import BaseLogger from deepsparse.loggers.constants import MetricCategories, SystemGroups -from deepsparse.utils import InferenceStages, StagedTimer, TimerManager +from deepsparse.utils import ( + InferenceStages, + StagedTimer, + TimerManager, + join_engine_outputs, + split_engine_inputs, +) __all__ = [ @@ -280,71 +286,6 @@ def __call__(self, *args, **kwargs) -> BaseModel: return pipeline_outputs - @staticmethod - def split_engine_inputs( - items: List[numpy.ndarray], batch_size: int - ) -> List[List[numpy.ndarray]]: - """ - Splits each item into numpy arrays with the first dimension == `batch_size`. - - For example, if `items` has three numpy arrays with the following - shapes: `[(4, 32, 32), (4, 64, 64), (4, 128, 128)]` - - Then with `batch_size==4` the output would be: - ``` - [[(4, 32, 32), (4, 64, 64), (4, 128, 128)]] - ``` - - Then with `batch_size==2` the output would be: - ``` - [ - [(2, 32, 32), (2, 64, 64), (2, 128, 128)], - [(2, 32, 32), (2, 64, 64), (2, 128, 128)], - ] - ``` - - Then with `batch_size==1` the output would be: - ``` - [ - [(1, 32, 32), (1, 64, 64), (1, 128, 128)], - [(1, 32, 32), (1, 64, 64), (1, 128, 128)], - [(1, 32, 32), (1, 64, 64), (1, 128, 128)], - [(1, 32, 32), (1, 64, 64), (1, 128, 128)], - ] - ``` - """ - # if not all items here are numpy arrays, there's an internal - # but in the processing code - assert all(isinstance(item, numpy.ndarray) for item in items) - - # if not all items have the same batch size, there's an - # internal bug in the processing code - total_batch_size = items[0].shape[0] - assert all(item.shape[0] == total_batch_size for item in items) - - if total_batch_size % batch_size != 0: - raise RuntimeError( - f"batch size of {total_batch_size} passed into pipeline " - f"is not divisible by model batch size of {batch_size}" - ) - - batches = [] - for i_batch in range(total_batch_size // batch_size): - start = i_batch * batch_size - batches.append([item[start : start + batch_size] for item in items]) - return batches - - @staticmethod - def join_engine_outputs( - batch_outputs: List[List[numpy.ndarray]], - ) -> List[numpy.ndarray]: - """ - Joins list of engine outputs together into one list using `numpy.concatenate`. - - This is the opposite of `Pipeline.split_engine_inputs`. 
- """ - return list(map(numpy.concatenate, zip(*batch_outputs))) - @classmethod def from_config( cls, From b94937d2cdbbc03ef420fce76ba42d1f489b4f98 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Mon, 17 Jul 2023 15:48:14 +0000 Subject: [PATCH 14/14] expose functions through pipeline.py --- src/deepsparse/pipeline.py | 5 ++++- src/deepsparse/server/cli.py | 2 +- tests/deepsparse/pipelines/test_dynamic_import.py | 3 +-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py index 7b6f83bcf7..1c55441a41 100644 --- a/src/deepsparse/pipeline.py +++ b/src/deepsparse/pipeline.py @@ -26,7 +26,7 @@ from pydantic import BaseModel, Field from deepsparse import Context, Engine, MultiModelEngine, Scheduler -from deepsparse.base_pipeline import BasePipeline +from deepsparse.base_pipeline import _REGISTERED_PIPELINES, BasePipeline, SupportedTasks from deepsparse.benchmark import ORTEngine from deepsparse.cpu import cpu_details from deepsparse.loggers.base_logger import BaseLogger @@ -45,6 +45,9 @@ "ORT_ENGINE", "SUPPORTED_PIPELINE_ENGINES", "Pipeline", + "BasePipeline", + "SupportedTasks", + "_REGISTERED_PIPELINES", "PipelineConfig", "question_answering_pipeline", "text_classification_pipeline", diff --git a/src/deepsparse/server/cli.py b/src/deepsparse/server/cli.py index d14358b0d5..29cbc9afb0 100644 --- a/src/deepsparse/server/cli.py +++ b/src/deepsparse/server/cli.py @@ -27,7 +27,7 @@ import click import yaml -from deepsparse.base_pipeline import SupportedTasks +from deepsparse.pipeline import SupportedTasks from deepsparse.server.config import EndpointConfig, ServerConfig from deepsparse.server.server import start_server diff --git a/tests/deepsparse/pipelines/test_dynamic_import.py b/tests/deepsparse/pipelines/test_dynamic_import.py index f6c5ff1e3e..63096e2365 100644 --- a/tests/deepsparse/pipelines/test_dynamic_import.py +++ b/tests/deepsparse/pipelines/test_dynamic_import.py @@ -15,8 +15,7 @@ import os import pytest -from deepsparse.base_pipeline import _REGISTERED_PIPELINES -from deepsparse.pipeline import Pipeline +from deepsparse.pipeline import _REGISTERED_PIPELINES, Pipeline from deepsparse.tasks import _split_dir_and_name, dynamic_import_task
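
For reference, a minimal sketch of the registration flow that this final patch is meant to keep working, adapted from the `BasePipeline` docstring example removed earlier in the series. The task name "base_example", the pydantic schemas, and the echo behavior are illustrative assumptions rather than a real deepsparse task; the only API used is what the patches themselves define (`BasePipeline.register`, `BasePipeline.create`, `parse_inputs`), imported through the names re-exported from `deepsparse.pipeline` in PATCH 14/14:

    from pydantic import BaseModel

    from deepsparse.pipeline import BasePipeline  # re-exported again by PATCH 14/14


    class ExampleInput(BaseModel):
        text: str


    class ExampleOutput(BaseModel):
        text: str


    @BasePipeline.register(task="base_example")
    class BasePipelineExample(BasePipeline):
        # the three abstract members required by BasePipeline
        def __call__(self, *args, **kwargs) -> BaseModel:
            # parse_inputs accepts either an ExampleInput instance or raw kwargs
            pipeline_inputs = self.parse_inputs(*args, **kwargs)
            return ExampleOutput(text=pipeline_inputs.text)

        @property
        def input_schema(self):
            return ExampleInput

        @property
        def output_schema(self):
            return ExampleOutput


    # resolved through _REGISTERED_PIPELINES; no model_path is required because
    # the class is not a Pipeline (engine-backed) subclass
    pipeline = BasePipeline.create(task="base_example")
    print(pipeline(text="sample").text)

Re-exporting `BasePipeline`, `SupportedTasks`, and `_REGISTERED_PIPELINES` through `deepsparse.pipeline` keeps imports like the server CLI's `from deepsparse.pipeline import SupportedTasks` working unchanged, while the same names also remain reachable from `deepsparse.base_pipeline` for callers that import the new module directly.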