No grad during the TorchFX model validation
daniil-lyakhov committed Dec 5, 2024
1 parent ef1a7fd commit 8e83db5
Showing 1 changed file with 34 additions and 21 deletions.
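The change is mechanical: every place that runs the TorchFX model for validation or calibration is wrapped in nncf.torch.disable_patching() and torch.no_grad(). A minimal sketch of the pattern follows; the run_validation_step, model, and example_input names are placeholders for illustration and are not part of this commit, and disable_patching is assumed to behave as its name suggests, suspending NNCF's patching of torch operators for the duration of the block.

import torch
from nncf.torch import disable_patching

def run_validation_step(model, example_input):
    # Inference-only validation needs no autograd bookkeeping, so
    # torch.no_grad() avoids building a graph and saves memory.
    # disable_patching() suspends NNCF's torch operator patching, which
    # could otherwise interfere with torch.compile / TorchFX tracing.
    with disable_patching():
        with torch.no_grad():
            compiled_model = torch.compile(model, backend="openvino")
            return compiled_model(example_input)

The same nesting order, disable_patching() outermost and torch.no_grad() inside, is used consistently in every hunk below.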
55 changes: 34 additions & 21 deletions tests/post_training/pipelines/image_classification_base.py
@@ -30,6 +30,7 @@
 from nncf.common.logging.track_progress import track
 from nncf.experimental.common.quantization.algorithms.quantizer.openvino_quantizer import OpenVINOQuantizer
 from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e
+from nncf.torch import disable_patching
 from tests.post_training.pipelines.base import DEFAULT_VAL_THREADS
 from tests.post_training.pipelines.base import FX_BACKENDS
 from tests.post_training.pipelines.base import BackendType
@@ -95,18 +96,20 @@ def _validate_torch_compile(
 
         print(f"Quantize ops num: {q_num}")
 
-        if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
-            compiled_model = torch.compile(self.compressed_model)
-        else:
-            compiled_model = torch.compile(self.compressed_model, backend="openvino")
-
-        for i, (images, target) in enumerate(val_loader):
-            # Workaround for memory leaks when using torch DataLoader and OpenVINO
-            pred = compiled_model(images)
-            pred = torch.argmax(pred, dim=1)
-            predictions[i] = pred.numpy()
-            references[i] = target.numpy()
-        return predictions, references
+        with disable_patching():
+            with torch.no_grad():
+                if self.backend in [BackendType.X86_QUANTIZER_AO, BackendType.X86_QUANTIZER_NNCF]:
+                    compiled_model = torch.compile(self.compressed_model)
+                else:
+                    compiled_model = torch.compile(self.compressed_model, backend="openvino")
+
+                for i, (images, target) in enumerate(val_loader):
+                    # Workaround for memory leaks when using torch DataLoader and OpenVINO
+                    pred = compiled_model(images)
+                    pred = torch.argmax(pred, dim=1)
+                    predictions[i] = pred.numpy()
+                    references[i] = target.numpy()
+                return predictions, references
 
     def _validate(self):
         val_dataset = datasets.ImageFolder(root=self.data_dir / "imagenet" / "val", transform=self.transform)
@@ -130,11 +133,13 @@ def _validate(self):
 
     def _compress_torch_ao(self, quantizer):
 
-        prepared_model = prepare_pt2e(self.model, quantizer)
-        subset_size = self.compression_params.get("subset_size", 300)
-        for data in islice(self.calibration_dataset.get_inference_data(), subset_size):
-            prepared_model(data)
-        self.compressed_model = convert_pt2e(prepared_model)
+        with disable_patching():
+            with torch.no_grad():
+                prepared_model = prepare_pt2e(self.model, quantizer)
+                subset_size = self.compression_params.get("subset_size", 300)
+                for data in islice(self.calibration_dataset.get_inference_data(), subset_size):
+                    prepared_model(data)
+                self.compressed_model = convert_pt2e(prepared_model)
 
     def _compress_nncf_pt2e(self, quantizer):
         pt2e_kwargs = {}
@@ -152,17 +157,25 @@ def _compress_nncf_pt2e(self, quantizer):
         smooth_quant = False
         if self.compression_params.get("model_type", False):
             smooth_quant = self.compression_params["model_type"] == nncf.ModelType.TRANSFORMER
-        self.compressed_model = quantize_pt2e(
-            self.model, quantizer, self.calibration_dataset, smooth_quant=smooth_quant, fold_quantize=False
-        )
+        with disable_patching():
+            with torch.no_grad():
+                self.compressed_model = quantize_pt2e(
+                    self.model, quantizer, self.calibration_dataset, smooth_quant=smooth_quant, fold_quantize=False
+                )
 
     def _compress(self):
         """
         Quantize self.model
         """
-        if self.backend not in FX_BACKENDS or self.backend == BackendType.FX_TORCH:
+        if self.backend not in FX_BACKENDS:
             super()._compress()
+
             return
+        if self.backend == BackendType.FX_TORCH:
+            with disable_patching():
+                with torch.no_grad():
+                    super()._compress()
+            return
 
         if self.backend in [BackendType.OV_QUANTIZER_AO, BackendType.OV_QUANTIZER_NNCF]:
             quantizer_kwargs = {}
