From 840d9cafcc597c3ceead07a80b57c969a8b8a85e Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Tue, 3 Dec 2024 15:19:07 +0100
Subject: [PATCH] No grad during the TorchFX model validation

---
 .../pipelines/image_classification_base.py   | 25 ++++++++++---------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/tests/post_training/pipelines/image_classification_base.py b/tests/post_training/pipelines/image_classification_base.py
index c28c2762474..9a686bcd3ce 100644
--- a/tests/post_training/pipelines/image_classification_base.py
+++ b/tests/post_training/pipelines/image_classification_base.py
@@ -95,18 +95,19 @@ def _validate_torch_compile(
         print(f"Qunatize ops num: {q_num}")

-        if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
-            compiled_model = torch.compile(self.compressed_model)
-        else:
-            compiled_model = torch.compile(self.compressed_model, backend="openvino")
-
-        for i, (images, target) in enumerate(val_loader):
-            # W/A for memory leaks when using torch DataLoader and OpenVINO
-            pred = compiled_model(images)
-            pred = torch.argmax(pred, dim=1)
-            predictions[i] = pred.numpy()
-            references[i] = target.numpy()
-        return predictions, references
+        with torch.no_grad():
+            if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
+                compiled_model = torch.compile(self.compressed_model)
+            else:
+                compiled_model = torch.compile(self.compressed_model, backend="openvino")
+
+            for i, (images, target) in enumerate(val_loader):
+                # W/A for memory leaks when using torch DataLoader and OpenVINO
+                pred = compiled_model(images)
+                pred = torch.argmax(pred, dim=1)
+                predictions[i] = pred.numpy()
+                references[i] = target.numpy()
+            return predictions, references

     def _validate(self):
         val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
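
Note: for reference, below is a minimal, standalone sketch of the pattern this patch introduces: running validation inference on a torch.compile'd model inside a torch.no_grad() block so no autograd graph is built. It is not part of the patch; the resnet18 model, the dummy val_loader, and the tensor shapes are illustrative assumptions, not taken from the test suite.

    # Sketch only: validation loop under torch.no_grad(), mirroring the patched code path.
    import torch
    import torchvision.models as models

    # Any eval-mode model works here; resnet18 is just a stand-in for self.compressed_model.
    model = models.resnet18(weights=None).eval()
    compiled_model = torch.compile(model)  # the patch may also use backend="openvino"

    # Dummy stand-in for the real ImageNet val_loader.
    val_loader = [(torch.randn(1, 3, 224, 224), torch.tensor([0])) for _ in range(4)]

    predictions = [None] * len(val_loader)
    references = [None] * len(val_loader)

    with torch.no_grad():  # disables gradient tracking for the whole validation loop
        for i, (images, target) in enumerate(val_loader):
            pred = compiled_model(images)
            pred = torch.argmax(pred, dim=1)
            predictions[i] = pred.numpy()
            references[i] = target.numpy()

Wrapping the loop (rather than each forward call) in torch.no_grad() keeps both the compilation trigger and every inference step free of gradient bookkeeping, which lowers memory use during validation.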