Skip to content

Commit

Permalink
plt update net instance, test=model
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeref996 committed Aug 29, 2024
1 parent 83accc5 commit 0bc98a2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
6 changes: 6 additions & 0 deletions framework/e2e/PaddleLT_new/engine/paddle_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def jit_save(self):

# paddle.jit.save(net, path=os.path.join(self.path, self.case))
paddle.jit.save(st_net, path=os.path.join(self.path, self.layername, "jit_save"))
return {"res": None}

def jit_save_inputspec(self):
"""jit.save(layer)"""
Expand All @@ -95,6 +96,7 @@ def jit_save_inputspec(self):

# paddle.jit.save(net, path=os.path.join(self.path, self.case))
paddle.jit.save(st_net, path=os.path.join(self.path, self.layername, "jit_save_inputspec"))
return {"res": None}

def jit_save_static_inputspec(self):
"""jit.save(layer)"""
Expand All @@ -108,6 +110,7 @@ def jit_save_static_inputspec(self):

# paddle.jit.save(net, path=os.path.join(self.path, self.case))
paddle.jit.save(st_net, path=os.path.join(self.path, self.layername, "jit_save_static_inputspec"))
return {"res": None}

def jit_save_cinn(self):
"""jit.save(layer)"""
Expand All @@ -122,6 +125,7 @@ def jit_save_cinn(self):

# paddle.jit.save(net, path=os.path.join(self.path, self.case))
paddle.jit.save(cinn_net, path=os.path.join(self.path, self.layername, "jit_save_cinn"))
return {"res": None}

def jit_save_cinn_inputspec(self):
"""jit.save(layer)"""
Expand All @@ -137,6 +141,7 @@ def jit_save_cinn_inputspec(self):

# paddle.jit.save(net, path=os.path.join(self.path, self.case))
paddle.jit.save(cinn_net, path=os.path.join(self.path, self.layername, "jit_save_cinn_inputspec"))
return {"res": None}

def jit_save_cinn_static_inputspec(self):
"""jit.save(layer)"""
Expand All @@ -152,3 +157,4 @@ def jit_save_cinn_static_inputspec(self):

# paddle.jit.save(net, path=os.path.join(self.path, self.case))
paddle.jit.save(cinn_net, path=os.path.join(self.path, self.layername, "jit_save_cinn_static_inputspec"))
return {"res": None}
10 changes: 5 additions & 5 deletions framework/e2e/PaddleLT_new/engine/paddle_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def paddle_infer_gpu(self):
else:
output_handle = predictor.get_output_handle(output_names[0])
infer_res = output_handle.copy_to_cpu()
return {"logit": infer_res}
return {"res": {"logit": infer_res}}

def paddle_infer_cpu(self):
"""infer load (layer)"""
Expand Down Expand Up @@ -92,7 +92,7 @@ def paddle_infer_cpu(self):
else:
output_handle = predictor.get_output_handle(output_names[0])
infer_res = output_handle.copy_to_cpu()
return {"logit": infer_res}
return {"res": {"logit": infer_res}}

def paddle_infer_mkldnn(self):
"""infer load (layer)"""
Expand Down Expand Up @@ -123,7 +123,7 @@ def paddle_infer_mkldnn(self):
else:
output_handle = predictor.get_output_handle(output_names[0])
infer_res = output_handle.copy_to_cpu()
return {"logit": infer_res}
return {"res": {"logit": infer_res}}

def paddle_infer_ort(self):
"""infer load (layer)"""
Expand Down Expand Up @@ -153,7 +153,7 @@ def paddle_infer_ort(self):
else:
output_handle = predictor.get_output_handle(output_names[0])
infer_res = output_handle.copy_to_cpu()
return {"logit": infer_res}
return {"res": {"logit": infer_res}}

def paddle_infer_new_exc_pir(self):
"""infer load (layer)"""
Expand Down Expand Up @@ -186,4 +186,4 @@ def paddle_infer_new_exc_pir(self):
else:
output_handle = predictor.get_output_handle(output_names[0])
infer_res = output_handle.copy_to_cpu()
return {"logit": infer_res}
return {"res": {"logit": infer_res}}
2 changes: 1 addition & 1 deletion framework/e2e/PaddleLT_new/layertest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _case_run(self):
testing=testing, layerfile=self.layerfile, device_place_id=self.device_place_id, upstream_net=net
)
res_dict[testing] = res["res"]
net = res["net"]
net = res.get("net", None)
if os.environ.get("PLT_SAVE_GT") == "True": # 开启gt保存
gt_path = os.path.join("plt_gt", os.environ.get("PLT_SET_DEVICE"), testing)
if not os.path.exists(gt_path):
Expand Down

0 comments on commit 0bc98a2

Please sign in to comment.