From 470bb860d8478298fdf3f10a530184f8cdc9c938 Mon Sep 17 00:00:00 2001 From: abc-125 <63813435+abc-125@users.noreply.github.com> Date: Mon, 8 Jan 2024 14:43:05 +0100 Subject: [PATCH] Refactor/extensions custom dataset (#1562) * Explanation how to use extension names in the config file * Added information about extensions to the error message and control of the user input * Easier to read code * Replacing assert with raise --- README.md | 2 +- src/anomalib/data/utils/path.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7fb4c3ea80..97c8e390be 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ dataset: normal_test_dir: null # name of the folder containing normal test images. task: segmentation # classification or segmentation mask: #optional - extensions: null + extensions: null # .ext or [.ext1, .ext2, ...] split_ratio: 0.2 # ratio of the normal images that will be used to create a test split image_size: 256 train_batch_size: 32 diff --git a/src/anomalib/data/utils/path.py b/src/anomalib/data/utils/path.py index 93784ba57b..663e201243 100644 --- a/src/anomalib/data/utils/path.py +++ b/src/anomalib/data/utils/path.py @@ -58,13 +58,16 @@ def _prepare_files_labels( if isinstance(extensions, str): extensions = (extensions,) + if not all(extension.startswith(".") for extension in extensions): + raise RuntimeError(f"All extensions {extensions} must start with the dot") + filenames = [ f for f in path.glob("**/*") if f.suffix in extensions and not f.is_dir() and not any(part.startswith(".") for part in f.parts) ] if not filenames: - raise RuntimeError(f"Found 0 {path_type} images in {path}") + raise RuntimeError(f"Found 0 {path_type} images in {path} with extensions {extensions}") labels = [path_type] * len(filenames)