Skip to content

Commit

Permalink
No grad during the TorchFX model validation
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Dec 3, 2024
1 parent 9e90443 commit 840d9ca
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions tests/post_training/pipelines/image_classification_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,19 @@ def _validate_torch_compile(

print(f"Qunatize ops num: {q_num}")

if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
compiled_model = torch.compile(self.compressed_model)
else:
compiled_model = torch.compile(self.compressed_model, backend="openvino")

for i, (images, target) in enumerate(val_loader):
# W/A for memory leaks when using torch DataLoader and OpenVINO
pred = compiled_model(images)
pred = torch.argmax(pred, dim=1)
predictions[i] = pred.numpy()
references[i] = target.numpy()
return predictions, references
with torch.no_grad():
if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
compiled_model = torch.compile(self.compressed_model)
else:
compiled_model = torch.compile(self.compressed_model, backend="openvino")

for i, (images, target) in enumerate(val_loader):
# W/A for memory leaks when using torch DataLoader and OpenVINO
pred = compiled_model(images)
pred = torch.argmax(pred, dim=1)
predictions[i] = pred.numpy()
references[i] = target.numpy()
return predictions, references

def _validate(self):
val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
Expand Down

0 comments on commit 840d9ca

Please sign in to comment.