
Commit 0c4066c
Refactor imports
Ashwin Vaidya committed Sep 11, 2023
1 parent 941afe2 commit 0c4066c
Showing 3 changed files with 20 additions and 20 deletions.
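In short, the commit replaces the aliased submodule import "from anomalib import trainer as t" (and, in the min-max callback, the lightning.pytorch "pl.Trainer" hint) with a plain top-level "import anomalib", and rewrites the trainer type hints on the callbacks as the string forward reference "anomalib.AnomalibTrainer". Below is a minimal sketch of the resulting import style; the module and callback names are hypothetical, and the usual motivation for this pattern, keeping callback modules from importing the trainer at import time and so avoiding a circular import, is an assumption rather than anything stated in the commit message.

    # example_callback.py: hypothetical sketch of the import style used by this commit.
    # It assumes the installed anomalib package exposes AnomalibTrainer at the top level.
    from __future__ import annotations

    import anomalib  # instead of `from anomalib import trainer as t`


    class ExampleCallback:
        """Hypothetical callback that refers to the trainer via a string annotation."""

        def setup(self, trainer: "anomalib.AnomalibTrainer", stage: str | None = None) -> None:
            # The quoted annotation is only resolved if something evaluates it
            # (for example typing.get_type_hints), so anomalib.AnomalibTrainer is
            # not looked up while this module is imported.
            del trainer, stage  # These variables are not used.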
@@ -11,7 +11,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.distributions import LogNormal

-from anomalib import trainer as t
+import anomalib
from anomalib.models import get_model
from anomalib.models.components import AnomalyModule
from anomalib.post_processing.normalization.cdf import normalize, standardize
@@ -27,7 +27,7 @@
self.image_dist: LogNormal | None = None
self.pixel_dist: LogNormal | None = None

-def setup(self, trainer: "t.AnomalibTrainer", pl_module: AnomalyModule, stage: str | None = None) -> None:
+def setup(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule, stage: str | None = None) -> None:
"""Adds training_distribution metrics to normalization metrics."""
del trainer, stage # These variabels are not used.

@@ -39,7 +39,7 @@
f" got {type(pl_module.normalization_metrics)}"
)

-def on_test_start(self, trainer: "t.AnomalibTrainer", pl_module: AnomalyModule) -> None:
+def on_test_start(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule) -> None:
"""Called when the test begins."""
del trainer # `trainer` variable is not used.

@@ -48,7 +48,7 @@
if pl_module.pixel_metrics is not None:
pl_module.pixel_metrics.set_threshold(0.5)

-def on_validation_epoch_start(self, trainer: "t.AnomalibTrainer", pl_module: AnomalyModule) -> None:
+def on_validation_epoch_start(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule) -> None:
"""Called when the validation starts after training.
Use the current model to compute the anomaly score distributions
@@ -60,7 +60,7 @@

def on_validation_batch_end(
self,
-trainer: "t.AnomalibTrainer",
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any,
@@ -74,7 +74,7 @@

def on_test_batch_end(
self,
-trainer: "t.AnomalibTrainer",
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any,
@@ -89,7 +89,7 @@

def on_predict_batch_end(
self,
-trainer: "t.AnomalibTrainer",
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: dict,
batch: Any,
@@ -103,14 +103,14 @@
self._normalize_batch(outputs, pl_module)
outputs["pred_labels"] = outputs["pred_scores"] >= 0.5

-def _collect_stats(self, trainer: "t.AnomalibTrainer", pl_module: AnomalyModule) -> None:
+def _collect_stats(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule) -> None:
"""Collect the statistics of the normal training data.
Create a trainer and use it to predict the anomaly maps and scores of the normal training data. Then
estimate the distribution of anomaly scores for normal data at the image and pixel level by computing
the mean and standard deviations. A dictionary containing the computed statistics is stored in self.stats.
"""
-predictions = t.AnomalibTrainer(
+predictions = anomalib.AnomalibTrainer(
accelerator=trainer.accelerator, devices=trainer.num_devices, normalizer="none"
).predict(model=self._create_inference_model(pl_module), dataloaders=trainer.datamodule.train_dataloader())
pl_module.normalization_metrics.reset()
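Judging from its imports, the file above is the CDF-normalization callback: it collects the mean and standard deviation of anomaly scores on normal training data, standardizes new scores with those statistics, and maps them to [0, 1] through a Gaussian CDF. The snippet below is only a rough, hypothetical illustration of that idea; the real helpers live in anomalib.post_processing.normalization.cdf and may differ in signature and detail.

    # Hypothetical sketch of CDF-based score normalization (not the library's implementation).
    import torch
    from torch.distributions import Normal


    def standardize(scores: torch.Tensor, mean: float, std: float) -> torch.Tensor:
        # Log anomaly scores are treated as normally distributed (compare the LogNormal
        # distributions the callback keeps), so standardize in log space.
        return (torch.log(scores) - mean) / std


    def normalize(standardized: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
        # Map standardized scores through the standard normal CDF, centered on the
        # (standardized) decision threshold, yielding values in [0, 1].
        return Normal(0.0, 1.0).cdf(standardized - threshold)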
@@ -6,11 +6,11 @@

from typing import Any

-import lightning.pytorch as pl
import torch
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT

+import anomalib
from anomalib.models.components import AnomalyModule
from anomalib.post_processing.normalization.min_max import normalize
from anomalib.utils.metrics import MinMax
@@ -19,7 +19,7 @@
class MinMaxNormalizationCallback(Callback):
"""Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""

-def setup(self, trainer: pl.Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
+def setup(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule, stage: str | None = None) -> None:
"""Adds min_max metrics to normalization metrics."""
del trainer, stage # These variables are not used.

@@ -30,7 +30,7 @@
f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
)

-def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
+def on_test_start(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule) -> None:
"""Called when the test begins."""
del trainer # `trainer` variable is not used.

@@ -40,7 +40,7 @@

def on_validation_batch_end(
self,
-trainer: pl.Trainer,
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: STEP_OUTPUT,
batch: Any,
@@ -61,7 +61,7 @@

def on_test_batch_end(
self,
-trainer: pl.Trainer,
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any,
@@ -75,7 +75,7 @@

def on_predict_batch_end(
self,
-trainer: pl.Trainer,
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: Any,
batch: Any,
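The second file is the MinMaxNormalizationCallback, which registers a MinMax metric in setup, tracks the minimum and maximum anomaly scores seen during validation, and uses them to rescale image- and pixel-level predictions into [0, 1]. A rough sketch of threshold-centered min-max normalization is given below; the helper name and signature are hypothetical and not the exact API of anomalib.post_processing.normalization.min_max.normalize.

    # Hypothetical sketch of threshold-centered min-max normalization of anomaly scores.
    import torch


    def min_max_normalize(
        scores: torch.Tensor, threshold: float, min_val: float, max_val: float
    ) -> torch.Tensor:
        # Rescale by the observed score range, shifting so the decision threshold
        # maps to 0.5, then clip to the [0, 1] interval.
        normalized = ((scores - threshold) / (max_val - min_val)) + 0.5
        return torch.clamp(normalized, min=0.0, max=1.0)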
10 changes: 5 additions & 5 deletions src/anomalib/trainer/callbacks/post_processor.py
@@ -11,7 +11,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import Tensor

-from anomalib import trainer as t
+import anomalib
from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.models import AnomalyModule

@@ -22,7 +22,7 @@

def on_validation_batch_end(
self,
-trainer: "t.AnomalibTrainer",
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any,
@@ -34,7 +34,7 @@

def on_test_batch_end(
self,
-trainer: "t.AnomalibTrainer",
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any,
@@ -46,7 +46,7 @@

def on_predict_batch_end(
self,
-trainer: "t.AnomalibTrainer",
+trainer: "anomalib.AnomalibTrainer",
pl_module: AnomalyModule,
outputs: Any,
batch: Any,
@@ -56,7 +56,7 @@
if outputs is not None:
self.post_process(trainer, pl_module, outputs)

-def post_process(self, trainer: "t.AnomalibTrainer", pl_module: AnomalyModule, outputs: STEP_OUTPUT):
+def post_process(self, trainer: "anomalib.AnomalibTrainer", pl_module: AnomalyModule, outputs: STEP_OUTPUT):
if isinstance(outputs, dict):
self._post_process(outputs)
if trainer.predicting or trainer.testing:
