Skip to content

Commit

Permalink
plt update net instance, test=model
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeref996 committed Aug 30, 2024
1 parent b59f018 commit 430e1ea
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions framework/e2e/PaddleLT_new/strategy/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1
exc_dict=exc_dict,
)
else:
Logger("PLT_compare").get_log().info(f"expect有 {k}, 但是result没有 {k}, 所以跳过 {k} 精度对比")
Logger("PLT_compare").get_log().info(f"{exp_name}{k}, 但是 {res_name} 没有 {k}, 所以跳过 {k} 精度对比")
elif isinstance(expect, list) or isinstance(expect, tuple):
for i, element in enumerate(expect):
if isinstance(result, (np.generic, np.ndarray)) or isinstance(result, eval(f"{framework}.Tensor")):
Expand Down Expand Up @@ -136,7 +136,11 @@ def base_compare(result, expect, res_name, exp_name, logger, delta=1e-10, rtol=1
)
elif isinstance(expect, (bool, int, float)):
assert expect == result
elif expect is None:
elif expect is None or result is None:
if expect is None:
Logger("PLT_compare").get_log().info(f"{exp_name} 结果为None, 所以跳过 {exp_name}{res_name} 精度对比")
if result is None:
Logger("PLT_compare").get_log().info(f"{res_name} 结果为None, 所以跳过 {exp_name}{res_name} 精度对比")
pass
else:
raise Exception("expect is unknown data struction in compare_tool!!!")
Expand Down

0 comments on commit 430e1ea

Please sign in to comment.