From d1f824a5798262891dbbe583fb291e1cf9aa7d2a Mon Sep 17 00:00:00 2001
From: Alexander Riedel <54716527+alexriedel1@users.noreply.github.com>
Date: Tue, 16 Jul 2024 13:26:03 +0200
Subject: [PATCH] Fix normalization (#2130)

* fix normalization

* precommit config...

* reset normalization metrics on validation start

* fix model loading and saving normalization metrics

* Update src/anomalib/callbacks/normalization/min_max_normalization.py

* Update src/anomalib/callbacks/normalization/min_max_normalization.py

---------

Co-authored-by: Samet Akcay
---
 .../normalization/min_max_normalization.py    | 48 +++++++++++++------
 .../deploy/inferencers/base_inferencer.py     | 12 ++---
 .../models/components/base/anomaly_module.py  | 26 +++++++---
 .../models/image/csflow/lightning_model.py    |  6 +--
 .../models/image/csflow/torch_model.py        |  2 +-
 5 files changed, 64 insertions(+), 30 deletions(-)

diff --git a/src/anomalib/callbacks/normalization/min_max_normalization.py b/src/anomalib/callbacks/normalization/min_max_normalization.py
index 6eb5eaf7c0..cdd9760b42 100644
--- a/src/anomalib/callbacks/normalization/min_max_normalization.py
+++ b/src/anomalib/callbacks/normalization/min_max_normalization.py
@@ -8,6 +8,7 @@
 import torch
 from lightning.pytorch import Trainer
 from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torchmetrics import MetricCollection

 from anomalib.metrics import MinMax
 from anomalib.models.components import AnomalyModule
@@ -27,13 +28,26 @@ def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None =
         del trainer, stage  # These variables are not used.

         if not hasattr(pl_module, "normalization_metrics"):
-            pl_module.normalization_metrics = MinMax().cpu()
-        elif not isinstance(pl_module.normalization_metrics, MinMax):
-            msg = f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
-            raise AttributeError(
-                msg,
+            pl_module.normalization_metrics = MetricCollection(
+                {
+                    "anomaly_maps": MinMax().cpu(),
+                    "box_scores": MinMax().cpu(),
+                    "pred_scores": MinMax().cpu(),
+                },
             )
+        elif not isinstance(pl_module.normalization_metrics, MetricCollection):
+            msg = (
+                f"Expected normalization_metrics to be of type MetricCollection, "
+                f"got {type(pl_module.normalization_metrics)}"
+            )
+            raise TypeError(msg)
+
+        for name, metric in pl_module.normalization_metrics.items():
+            if not isinstance(metric, MinMax):
+                msg = f"Expected normalization_metric {name} to be of type MinMax, got {type(metric)}"
+                raise TypeError(msg)
+
     def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
         """Call when the test begins."""
         del trainer  # `trainer` variable is not used.
@@ -42,6 +56,13 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
             if metric is not None:
                 metric.set_threshold(0.5)

+    def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
+        """Call when the validation epoch begins."""
+        del trainer  # `trainer` variable is not used.
+
+        if hasattr(pl_module, "normalization_metrics"):
+            pl_module.normalization_metrics.reset()
+
     def on_validation_batch_end(
         self,
         trainer: Trainer,
@@ -55,14 +76,11 @@ def on_validation_batch_end(
         del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

         if "anomaly_maps" in outputs:
-            pl_module.normalization_metrics(outputs["anomaly_maps"])
-        elif "box_scores" in outputs:
-            pl_module.normalization_metrics(torch.cat(outputs["box_scores"]))
-        elif "pred_scores" in outputs:
-            pl_module.normalization_metrics(outputs["pred_scores"])
-        else:
-            msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores"
-            raise ValueError(msg)
+            pl_module.normalization_metrics["anomaly_maps"](outputs["anomaly_maps"])
+        if "box_scores" in outputs:
+            pl_module.normalization_metrics["box_scores"](torch.cat(outputs["box_scores"]))
+        if "pred_scores" in outputs:
+            pl_module.normalization_metrics["pred_scores"](outputs["pred_scores"])

     def on_test_batch_end(
         self,
@@ -97,12 +115,14 @@ def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None:  # noqa: A
     """Normalize a batch of predictions."""
     image_threshold = pl_module.image_threshold.value.cpu()
     pixel_threshold = pl_module.pixel_threshold.value.cpu()
-    stats = pl_module.normalization_metrics.cpu()
     if "pred_scores" in outputs:
+        stats = pl_module.normalization_metrics["pred_scores"].cpu()
         outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max)
     if "anomaly_maps" in outputs:
+        stats = pl_module.normalization_metrics["anomaly_maps"].cpu()
         outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max)
     if "box_scores" in outputs:
+        stats = pl_module.normalization_metrics["box_scores"].cpu()
         outputs["box_scores"] = [
             normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"]
         ]
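The change above keeps one MinMax instance per output type instead of sharing a single metric, so image scores are no longer squeezed by pixel-level statistics. A minimal, self-contained sketch of that layout follows; the MinMax class here is a stand-in re-implemented with torchmetrics for illustration (it only approximates anomalib.metrics.MinMax), and the input tensors are made up:

import torch
from torchmetrics import Metric, MetricCollection


class MinMax(Metric):
    """Track the running min and max of all values seen so far."""

    full_state_update = True  # update() reads its own accumulated state

    def __init__(self) -> None:
        super().__init__()
        self.add_state("min", default=torch.tensor(float("inf")), dist_reduce_fx="min")
        self.add_state("max", default=torch.tensor(float("-inf")), dist_reduce_fx="max")

    def update(self, values: torch.Tensor) -> None:
        self.min = torch.minimum(self.min, values.min())
        self.max = torch.maximum(self.max, values.max())

    def compute(self) -> tuple[torch.Tensor, torch.Tensor]:
        return self.min, self.max


# One independent metric per output type, the same layout the callback builds.
stats = MetricCollection(
    {
        "anomaly_maps": MinMax(),
        "pred_scores": MinMax(),
    },
)
stats["anomaly_maps"](torch.rand(2, 1, 8, 8))   # updates only the map statistics
stats["pred_scores"](torch.tensor([0.1, 0.9]))  # scores keep their own range
print(stats["anomaly_maps"].min, stats["pred_scores"].max)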
if "anomaly_maps" in outputs: - pl_module.normalization_metrics(outputs["anomaly_maps"]) - elif "box_scores" in outputs: - pl_module.normalization_metrics(torch.cat(outputs["box_scores"])) - elif "pred_scores" in outputs: - pl_module.normalization_metrics(outputs["pred_scores"]) - else: - msg = "No values found for normalization, provide anomaly maps, bbox scores, or image scores" - raise ValueError(msg) + pl_module.normalization_metrics["anomaly_maps"](outputs["anomaly_maps"]) + if "box_scores" in outputs: + pl_module.normalization_metrics["box_scores"](torch.cat(outputs["box_scores"])) + if "pred_scores" in outputs: + pl_module.normalization_metrics["pred_scores"](outputs["pred_scores"]) def on_test_batch_end( self, @@ -97,12 +115,14 @@ def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: A """Normalize a batch of predictions.""" image_threshold = pl_module.image_threshold.value.cpu() pixel_threshold = pl_module.pixel_threshold.value.cpu() - stats = pl_module.normalization_metrics.cpu() if "pred_scores" in outputs: + stats = pl_module.normalization_metrics["pred_scores"].cpu() outputs["pred_scores"] = normalize(outputs["pred_scores"], image_threshold, stats.min, stats.max) if "anomaly_maps" in outputs: + stats = pl_module.normalization_metrics["anomaly_maps"].cpu() outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pixel_threshold, stats.min, stats.max) if "box_scores" in outputs: + stats = pl_module.normalization_metrics["box_scores"].cpu() outputs["box_scores"] = [ normalize(scores, pixel_threshold, stats.min, stats.max) for scores in outputs["box_scores"] ] diff --git a/src/anomalib/deploy/inferencers/base_inferencer.py b/src/anomalib/deploy/inferencers/base_inferencer.py index 50c8721804..05f8d65ba0 100644 --- a/src/anomalib/deploy/inferencers/base_inferencer.py +++ b/src/anomalib/deploy/inferencers/base_inferencer.py @@ -101,19 +101,19 @@ def _normalize( visualized and predicted scores. 
""" # min max normalization - if "min" in metadata and "max" in metadata: - if anomaly_maps is not None: + if "pred_scores.min" in metadata and "pred_scores.max" in metadata: + if anomaly_maps is not None and "anomaly_maps.max" in metadata: anomaly_maps = normalize_min_max( anomaly_maps, metadata["pixel_threshold"], - metadata["min"], - metadata["max"], + metadata["anomaly_maps.min"], + metadata["anomaly_maps.max"], ) pred_scores = normalize_min_max( pred_scores, metadata["image_threshold"], - metadata["min"], - metadata["max"], + metadata["pred_scores.min"], + metadata["pred_scores.max"], ) return anomaly_maps, float(pred_scores) diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index affe858b43..7e2d9479cf 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -15,6 +15,7 @@ from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.types import STEP_OUTPUT from torch import nn +from torchmetrics import MetricCollection from torchvision.transforms.v2 import Compose, Normalize, Resize, Transform from anomalib import LearningType @@ -25,7 +26,6 @@ if TYPE_CHECKING: from lightning.pytorch.callbacks import Callback - from torchmetrics import Metric logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def __init__(self) -> None: self.image_threshold: BaseThreshold self.pixel_threshold: BaseThreshold - self.normalization_metrics: Metric + self.normalization_metrics: MetricCollection self.image_metrics: AnomalibMetricCollection self.pixel_metrics: AnomalibMetricCollection @@ -155,8 +155,9 @@ def _save_to_state_dict(self, destination: OrderedDict, prefix: str, keep_vars: f"{self.pixel_threshold.__class__.__module__}.{self.pixel_threshold.__class__.__name__}" ) if hasattr(self, "normalization_metrics"): - normalization_class = self.normalization_metrics.__class__ - destination["normalization_class"] = f"{normalization_class.__module__}.{normalization_class.__name__}" + for metric in self.normalization_metrics: + metric_class = self.normalization_metrics[metric].__class__ + destination[f"{metric}_normalization_class"] = f"{metric_class.__module__}.{metric_class.__name__}" return super()._save_to_state_dict(destination, prefix, keep_vars) @@ -166,8 +167,21 @@ def load_state_dict(self, state_dict: OrderedDict[str, Any], strict: bool = True self.image_threshold = self._get_instance(state_dict, "image_threshold_class") if "pixel_threshold_class" in state_dict: self.pixel_threshold = self._get_instance(state_dict, "pixel_threshold_class") - if "normalization_class" in state_dict: - self.normalization_metrics = self._get_instance(state_dict, "normalization_class") + + if "anomaly_maps_normalization_class" in state_dict: + self.anomaly_maps_normalization_metrics = self._get_instance(state_dict, "anomaly_maps_normalization_class") + if "box_scores_normalization_class" in state_dict: + self.box_scores_normalization_metrics = self._get_instance(state_dict, "box_scores_normalization_class") + if "pred_scores_normalization_class" in state_dict: + self.pred_scores_normalization_metrics = self._get_instance(state_dict, "pred_scores_normalization_class") + + self.normalization_metrics = MetricCollection( + { + "anomaly_maps": self.anomaly_maps_normalization_metrics, + "box_scores": self.box_scores_normalization_metrics, + "pred_scores": self.pred_scores_normalization_metrics, + }, + ) # Used to load metrics if there is any related data 
diff --git a/src/anomalib/models/image/csflow/lightning_model.py b/src/anomalib/models/image/csflow/lightning_model.py
index a759000aa2..e4b4dad2ef 100644
--- a/src/anomalib/models/image/csflow/lightning_model.py
+++ b/src/anomalib/models/image/csflow/lightning_model.py
@@ -100,9 +100,9 @@ def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs)
         """
         del args, kwargs  # These variables are not used.

-        anomaly_maps, anomaly_scores = self.model(batch["image"])
-        batch["anomaly_maps"] = anomaly_maps
-        batch["pred_scores"] = anomaly_scores
+        output = self.model(batch["image"])
+        batch["anomaly_maps"] = output["anomaly_map"]
+        batch["pred_scores"] = output["pred_score"]
         return batch

     @property
diff --git a/src/anomalib/models/image/csflow/torch_model.py b/src/anomalib/models/image/csflow/torch_model.py
index 14b819a771..08c9106a3c 100644
--- a/src/anomalib/models/image/csflow/torch_model.py
+++ b/src/anomalib/models/image/csflow/torch_model.py
@@ -588,7 +588,7 @@ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         z_dist, _ = self.graph(features)  # Ignore Jacobians
         anomaly_scores = self._compute_anomaly_scores(z_dist)
         anomaly_maps = self.anomaly_map_generator(z_dist)
-        output = anomaly_maps, anomaly_scores
+        output = {"anomaly_map": anomaly_maps, "pred_score": anomaly_scores}
         return output

     def _compute_anomaly_scores(self, z_dists: torch.Tensor) -> torch.Tensor:
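CsFlowModel.forward now returns a dict keyed by "anomaly_map" and "pred_score", and validation_step unpacks it by name, so adding outputs later cannot silently shift positional unpacking. A small sketch of the new contract; DummyModel is a made-up stand-in producing random tensors, since CsFlowModel itself is not constructed here:

import torch
from torch import nn


class DummyModel(nn.Module):
    """Stand-in returning the same keys as CsFlowModel.forward after the fix."""

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        batch_size = images.shape[0]
        return {
            "anomaly_map": torch.rand(batch_size, 1, *images.shape[-2:]),
            "pred_score": torch.rand(batch_size),
        }


model = DummyModel()
batch = {"image": torch.rand(4, 3, 256, 256)}
output = model(batch["image"])
batch["anomaly_maps"] = output["anomaly_map"]  # keyed access, as in validation_step
batch["pred_scores"] = output["pred_score"]
print(batch["pred_scores"].shape)  # torch.Size([4])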