Skip to content

Commit

Permalink
🛠 Fix mask filenames in folder dataset (#249)
Browse files Browse the repository at this point in the history
* 🛠  Fix  mask filenames in folder ddataset

* fix folder dataset (#252)

Co-authored-by: Samet Akcay <samet.akcay@intel.com>

Co-authored-by: Alexander Riedel <54716527+alexriedel1@users.noreply.github.com>
  • Loading branch information
samet-akcay and alexriedel1 authored Apr 22, 2022
1 parent b2dfdd2 commit 24ccb58
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions anomalib/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def _prepare_files_labels(
if extensions is None:
extensions = IMG_EXTENSIONS

filenames = [f for f in path.glob(r"**/*") if f.suffix in extensions]
if isinstance(extensions, str):
extensions = (extensions,)

filenames = [f for f in path.glob(r"**/*") if f.suffix in extensions and not f.is_dir()]
if len(filenames) == 0:
raise RuntimeError(f"Found 0 {path_type} images in {path}")

Expand Down Expand Up @@ -140,11 +143,10 @@ def make_dataset(
# If a path to mask is provided, add it to the sample dataframe.
if mask_dir is not None:
mask_dir = _check_and_convert_path(mask_dir)
normal_gt = ["" for f in samples.loc[samples.label_index == 0]["image_path"]]
abnormal_gt = [str(mask_dir / f.name) for f in samples.loc[samples.label_index == 1]["image_path"]]
gt_filenames = normal_gt + abnormal_gt

samples["mask_path"] = gt_filenames
samples["mask_path"] = ""
for index, row in samples.iterrows():
if row.label_index == 1:
samples["mask_path"][index] = str(mask_dir / row.image_path.name)

# Ensure the pathlib objects are converted to str.
# This is because torch dataloader doesn't like pathlib.
Expand Down Expand Up @@ -463,6 +465,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train_data = FolderDataset(
normal_dir=self.normal_dir,
abnormal_dir=self.abnormal_dir,
normal_test_dir=self.normal_test,
split="train",
split_ratio=self.split_ratio,
mask_dir=self.mask_dir,
Expand All @@ -477,6 +480,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.val_data = FolderDataset(
normal_dir=self.normal_dir,
abnormal_dir=self.abnormal_dir,
normal_test_dir=self.normal_test,
split="val",
split_ratio=self.split_ratio,
mask_dir=self.mask_dir,
Expand Down

0 comments on commit 24ccb58

Please sign in to comment.