From 3b0a7ae732433accbc59cd968b4b82295c366ef2 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 24 Jul 2024 16:22:32 +0200 Subject: [PATCH 1/2] Refactor create metrics Signed-off-by: Ashwin Vaidya --- src/anomalib/metrics/__init__.py | 21 +++++++---- tests/unit/metrics/test_create_metrics.py | 45 +++++++++++++++++++++++ 2 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 tests/unit/metrics/test_create_metrics.py diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 4c3eafa811..f135393d45 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -10,6 +10,7 @@ import torchmetrics from omegaconf import DictConfig, ListConfig +from torchmetrics import Metric from .anomaly_score_distribution import AnomalyScoreDistribution from .aupr import AUPR @@ -162,7 +163,7 @@ def metric_collection_from_dicts(metrics: dict[str, dict[str, Any]], prefix: str def create_metric_collection( - metrics: list[str] | dict[str, dict[str, Any]], + metrics: list[str] | dict[str, dict[str, Any]] | Metric | list[Metric], prefix: str | None = None, ) -> AnomalibMetricCollection: """Create a metric collection from a list of metric names or dictionaries. @@ -171,25 +172,31 @@ def create_metric_collection( - if list[str] (names of metrics): see `metric_collection_from_names` - if dict[str, dict[str, Any]] (path and init args of a class): see `metric_collection_from_dicts` + - if list[Metric] (metric objects): A collection is returned with those metrics. The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module, then in TorchMetrics package. Args: - metrics (list[str] | dict[str, dict[str, Any]]): List of metrics or dictionaries to create metric collection. + metrics (list[str] | dict[str, dict[str, Any]] | Metric | list[Metric]): List of metrics or dictionaries to + create metric collection. prefix (str | None): Prefix to assign to the metrics in the collection. 
Returns: AnomalibMetricCollection: Collection of metrics. """ - # fallback is using the names - if isinstance(metrics, ListConfig | list): - if not all(isinstance(metric, str) for metric in metrics): - msg = f"All metrics must be strings, found {metrics}" + if not ( + all(isinstance(metric, str) for metric in metrics) or all(isinstance(metric, Metric) for metric in metrics) + ): + msg = f"All metrics must be either string or Metric objects, found {metrics}" raise TypeError(msg) + if all(isinstance(metric, str) for metric in metrics): + return metric_collection_from_names(metrics, prefix) + return AnomalibMetricCollection(metrics, prefix) - return metric_collection_from_names(metrics, prefix) + if isinstance(metrics, Metric): + return AnomalibMetricCollection([metrics], prefix) if isinstance(metrics, DictConfig | dict): _validate_metrics_dict(metrics) diff --git a/tests/unit/metrics/test_create_metrics.py b/tests/unit/metrics/test_create_metrics.py new file mode 100644 index 0000000000..cf2ea7cac9 --- /dev/null +++ b/tests/unit/metrics/test_create_metrics.py @@ -0,0 +1,45 @@ +"""Test metrics collection creation.""" + +from torchmetrics.classification import Accuracy + +from anomalib.metrics import AUPRO, create_metric_collection + + +def test_string_initialization() -> None: + """Pass metrics as a list of string.""" + metrics_list = ["AUROC", "AUPR"] + collection = create_metric_collection(metrics_list, prefix=None) + assert len(collection) == 2 + assert "AUROC" in collection + assert "AUPR" in collection + + +def test_dict_initialization() -> None: + """Pass metrics as a dictionary.""" + metrics_dict = { + "PixelWiseAUROC": { + "class_path": "anomalib.metrics.AUROC", + "init_args": {}, + }, + "Precision": { + "class_path": "torchmetrics.Precision", + "init_args": {"task": "binary"}, + }, + } + collection = create_metric_collection(metrics_dict, prefix=None) + assert len(collection) == 2 + assert "PixelWiseAUROC" in collection + assert "Precision" in collection + + 
+def test_metric_object_initialization() -> None:
+    """Pass metrics as a list of metric objects."""
+    metrics_list = [AUPRO(), Accuracy(task="binary")]
+    collection = create_metric_collection(metrics_list, prefix=None)
+    assert len(collection) == 2
+    assert "AUPRO" in collection
+    assert "BinaryAccuracy" in collection
+
+    collection = create_metric_collection(AUPRO(), prefix=None)
+    assert len(collection) == 1
+    assert "AUPRO" in collection

From b1c0abe4cef5ad2c43d0e105fa5d5372bcb5b835 Mon Sep 17 00:00:00 2001
From: Ashwin Vaidya 
Date: Wed, 24 Jul 2024 16:28:35 +0200
Subject: [PATCH 2/2] update changelog

Signed-off-by: Ashwin Vaidya 
---
 CHANGELOG.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 362856e3ec..7c80893f5a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,7 +10,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
-- 🔨 Replace "./dtasets/BTech" to "./dtasets/BTech"
+- 🔨 Allow passing metrics objects directly to `create_metric_collection` by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2212
+- 🔨 Replace "./dtasets/BTech" to "./datasets/BTech" by @samet-akcay in https://github.com/openvinotoolkit/anomalib/pull/2180
 
 ### Deprecated