Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Dec 22, 2024
1 parent 4daa408 commit 4b82c3c
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class OnnxConfig(ExportConfig, ABC):
"text2text-generation": OrderedDict({"logits": {0: "batch_size", 1: "decoder_sequence_length"}}),
"text-classification": OrderedDict({"logits": {0: "batch_size"}}),
"text-generation": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"time-series-forecasting": OrderedDict({"prediction_outputs": {0: "batch_size"}}),
"token-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"visual-question-answering": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"zero-shot-image-classification": OrderedDict(
Expand Down
44 changes: 44 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedTimeSeriesForecastingConfig,
NormalizedVisionConfig,
check_if_diffusers_greater,
check_if_transformers_greater,
Expand Down Expand Up @@ -2499,3 +2500,46 @@ class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig

DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.




class TimesFMDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ("inputs",)

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
**kwargs,
):
self.task = task
self.normalized_config = normalized_config

self.batch_size = batch_size
self.context_len = normalized_config.context_len

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
return self.random_float_tensor(
shape=[self.batch_size, self.context_len],
min_value=-1,
max_value=1,
framework=framework,
dtype=float_dtype,
)


class TimesFMOnnxConfig(OnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTimeSeriesForecastingConfig
MIN_TRANSFORMERS_VERSION = version.parse("4.47.0")
DUMMY_INPUT_GENERATOR_CLASSES = (TimesFMDummyInputGenerator,)


@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"inputs": {0: "batch_size", 1: "sequence_length"}}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return super().outputs
5 changes: 5 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ class TasksManager:
("pt", "visual-bert", "question-answering"): ("transformers", "VisualBertForQuestionAnswering"),
# VisionEncoderDecoderModel is not registered in AutoModelForDocumentQuestionAnswering
("pt", "vision-encoder-decoder", "document-question-answering"): ("transformers", "VisionEncoderDecoderModel"),
("pt", "timesfm", "time-series-forecasting"): ("transformers", "TimesFMModelForPrediction"),
}

_ENCODER_DECODER_TASKS = (
Expand Down Expand Up @@ -939,6 +940,10 @@ class TasksManager:
"text-classification",
onnx="Qwen2OnnxConfig",
),
"timesfm": supported_tasks_mapping(
"time-series-forecasting",
onnx="TimesFMOnnxConfig",
),
"llama": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,6 @@
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedTimeSeriesForecastingConfig,
NormalizedVisionConfig,
)
4 changes: 4 additions & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ class NormalizedSeq2SeqConfig(NormalizedTextConfig):
DECODER_NUM_ATTENTION_HEADS = NormalizedTextConfig.NUM_ATTENTION_HEADS


class NormalizedTimeSeriesForecastingConfig(NormalizedConfig):
CONTEXT_LEN = "context_len"


class NormalizedVisionConfig(NormalizedConfig):
IMAGE_SIZE = "image_size"
NUM_CHANNELS = "num_channels"
Expand Down

0 comments on commit 4b82c3c

Please sign in to comment.