Skip to content

Commit

Permalink
fix tests/test_pipeline_benchmark.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Apr 26, 2024
1 parent 7f903a8 commit 6e18775
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions src/deepsparse/benchmark/data_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import random
import string
from os import path
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, get_args

import numpy

Expand Down Expand Up @@ -58,15 +58,11 @@ def get_input_schema_type(pipeline: Pipeline) -> str:
if SchemaType.TEXT_SEQ in input_schema_requirements:
if input_schema_fields.get(SchemaType.TEXT_SEQ).alias == SchemaType.TEXT_PROMPT:
return SchemaType.TEXT_PROMPT
sequence_types = [
f.outer_type_ for f in input_schema_fields[SchemaType.TEXT_SEQ].sub_fields
]
sequence_types = get_args(input_schema_fields[SchemaType.TEXT_SEQ].annotation)
if List[str] in sequence_types:
return SchemaType.TEXT_SEQ
elif SchemaType.TEXT_INPUT in input_schema_requirements:
sequence_types = [
f.outer_type_ for f in input_schema_fields[SchemaType.TEXT_INPUT].sub_fields
]
sequence_types = get_args(input_schema_fields[SchemaType.TEXT_INPUT].annotation)
if List[str] in sequence_types:
return SchemaType.TEXT_INPUT
elif SchemaType.QUESTION in input_schema_requirements:
Expand Down

0 comments on commit 6e18775

Please sign in to comment.