Skip to content

Commit

Permalink
Fix tiling for Reverse Distillation and STFPM (#1319)
Browse files Browse the repository at this point in the history
* Fix tiling for stfpm

* Fix tiling for reverse distillation

* Add tiling to fastflow

* Make tiling in fastflow backwards compatible

* Rename image_size to anomaly_map_size.

* Reverse fastflow tiling

* Remove tiling calls in fastflow

---------

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
blaz-r and samet-akcay authored Sep 22, 2023
1 parent 5eb20c6 commit 2a04129
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
15 changes: 7 additions & 8 deletions src/anomalib/models/reverse_distillation/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ def __init__(
self.bottleneck = get_bottleneck_layer(backbone)
self.decoder = get_decoder(backbone)

if self.tiler:
image_size = (self.tiler.tile_size_h, self.tiler.tile_size_w)
else:
image_size = input_size

self.anomaly_map_generator = AnomalyMapGenerator(image_size=image_size, mode=anomaly_map_mode)
self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size, mode=anomaly_map_mode)

def forward(self, images: Tensor) -> Tensor | list[Tensor] | tuple[list[Tensor]]:
"""Forward-pass images to the network.
Expand All @@ -73,11 +68,15 @@ def forward(self, images: Tensor) -> Tensor | list[Tensor] | tuple[list[Tensor]]
encoder_features = list(encoder_features.values())
decoder_features = self.decoder(self.bottleneck(encoder_features))

if self.tiler:
for i, features in enumerate(encoder_features):
encoder_features[i] = self.tiler.untile(features)
for i, features in enumerate(decoder_features):
decoder_features[i] = self.tiler.untile(features)

if self.training:
output = encoder_features, decoder_features
else:
output = self.anomaly_map_generator(encoder_features, decoder_features)
if self.tiler:
output = self.tiler.untile(output)

return output
17 changes: 8 additions & 9 deletions src/anomalib/models/stfpm/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,7 @@ def __init__(
for parameters in self.teacher_model.parameters():
parameters.requires_grad = False

# Create the anomaly heatmap generator whether tiling is set.
# TODO: Check whether Tiler is properly initialized here.
if self.tiler:
image_size = (self.tiler.tile_size_h, self.tiler.tile_size_w)
else:
image_size = input_size
self.anomaly_map_generator = AnomalyMapGenerator(image_size=image_size)
self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)

def forward(self, images: Tensor) -> Tensor | dict[str, Tensor] | tuple[dict[str, Tensor]]:
"""Forward-pass images into the network.
Expand All @@ -64,11 +58,16 @@ def forward(self, images: Tensor) -> Tensor | dict[str, Tensor] | tuple[dict[str
images = self.tiler.tile(images)
teacher_features: dict[str, Tensor] = self.teacher_model(images)
student_features: dict[str, Tensor] = self.student_model(images)

if self.tiler:
for layer, data in teacher_features.items():
teacher_features[layer] = self.tiler.untile(data)
for layer, data in student_features.items():
student_features[layer] = self.tiler.untile(data)

if self.training:
output = teacher_features, student_features
else:
output = self.anomaly_map_generator(teacher_features=teacher_features, student_features=student_features)
if self.tiler:
output = self.tiler.untile(output)

return output

0 comments on commit 2a04129

Please sign in to comment.