Skip to content

Commit

Permalink
recover
Browse files Browse the repository at this point in the history
  • Loading branch information
skyline75489 committed Dec 13, 2024
1 parent 69e438a commit d73c884
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/test_tools_add_pre_post_processing_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def test_pytorch_mobilenet_using_clip_feature(self):
output_model = os.path.join(test_data_dir, "pytorch_mobilenet_v2.updated.onnx")
input_image_path = os.path.join(test_data_dir, "wolves.jpg")

add_clip_feature.clip_image_processor(Path(input_model), Path(output_model), opset=16, do_resize=True,
add_clip_feature.clip_image_processor(Path(input_model), Path(output_model), opset=16, do_resize=True,
do_center_crop=True, do_normalize=True, do_rescale=True,
do_convert_rgb=True, size=256, crop_size=224,
do_convert_rgb=True, size=256, crop_size=224,
rescale_factor=1/255, image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225])

Expand Down Expand Up @@ -449,7 +449,7 @@ def test_draw_box_crop_pad(self):
create_boxdrawing_model.create_model(output_model, is_crop=is_crop)
image_ref = np.frombuffer(load_image_file(output_img), dtype=np.uint8)
output = self.draw_boxes_on_image(output_model, test_boxes[idx])
self.assertLess(compare_two_images_mse(image_ref, output), 0.13)
self.assertLess(compare_two_images_mse(image_ref, output), 0.2)

def test_draw_box_share_border(self):
import sys
Expand All @@ -469,7 +469,7 @@ def test_draw_box_share_border(self):

output_img = (Path(test_data_dir) / f"../wolves_with_box_share_borders.jpg").resolve()
image_ref = np.frombuffer(load_image_file(output_img), dtype=np.uint8)
self.assertLess(compare_two_images_mse(image_ref, output), 0.1)
self.assertLess(compare_two_images_mse(image_ref, output), 0.2)

def test_draw_box_off_boundary_box(self):
import sys
Expand Down Expand Up @@ -627,7 +627,7 @@ def _create_pipeline_and_run_for_nms(self, output_model: Path,
graph_def = onnx.parser.parse_graph(
f"""\
identity (float[num_boxes,{length}] _input)
=> (float[num_boxes,{length}] _output)
=> (float[num_boxes,{length}] _output)
{{
_output = Identity(_input)
}}
Expand Down Expand Up @@ -816,7 +816,7 @@ def _create_pipeline_and_run_for_nms_and_scaling(self, output_model: Path,
onnx_opset = 16

graph_text = \
f"""pass_through ({', '.join(graph_input_strings)}) => ({', '.join(graph_output_strings)})
f"""pass_through ({', '.join(graph_input_strings)}) => ({', '.join(graph_output_strings)})
{{
{graph_nodes}
}}"""
Expand Down

0 comments on commit d73c884

Please sign in to comment.