Skip to content

Commit

Permalink
[BUG] fix `AttributeError: 'ExperimentWriter' object has no attribute 'add_figure'` (#1694)
Browse files Browse the repository at this point in the history

### Description

This PR is a bugfix for #1256.

At present, the `AttributeError` is raised (and the script exits) if the
logger used does not have the `add_figure` method.

This fix adds a simple check prior to calling `add_figure`.
  • Loading branch information
ewth authored Nov 11, 2024
1 parent c44cbb1 commit 45f1f6b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
19 changes: 18 additions & 1 deletion pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,18 @@ def log_interval(self) -> float:
else:
return self.hparams.log_val_interval

def _logger_supports(self, method: str) -> bool:
"""Whether logger supports method.
Returns
-------
supports_method : bool
True if attribute self.logger.experiment.method exists, False otherwise.
"""
if not hasattr(self, "logger") or not hasattr(self.logger, "experiment"):
return False
return hasattr(self.logger.experiment, method)

def log_prediction(
self, x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], batch_idx: int, **kwargs
) -> None:
Expand All @@ -981,6 +993,10 @@ def log_prediction(
if not mpl_available:
return None # don't log matplotlib plots if not available

# Don't log figures if add_figure is not available
if not self._logger_supports("add_figure"):
return None

for idx in log_indices:
fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
tag = f"{self.current_stage} prediction"
Expand Down Expand Up @@ -1156,7 +1172,8 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:

mpl_available = _check_matplotlib("log_gradient_flow", raise_error=False)

if not mpl_available:
# Don't log figures if matplotlib or add_figure is not available
if not mpl_available or not self._logger_supports("add_figure"):
return None

import matplotlib.pyplot as plt
Expand Down
3 changes: 2 additions & 1 deletion pytorch_forecasting/models/nbeats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def log_interpretation(self, x, out, batch_idx):
"""
mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
# Don't log figures if matplotlib or add_figure is not available
if not mpl_available or not self._logger_supports("add_figure"):
return None

label = ["val", "train"][self.training]
Expand Down
3 changes: 2 additions & 1 deletion pytorch_forecasting/models/nhits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,8 @@ def log_interpretation(self, x, out, batch_idx):
"""
mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
# Don't log figures if matplotlib or add_figure is not available
if not mpl_available or not self._logger_supports("add_figure"):
return None

label = ["val", "train"][self.training]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,8 @@ def log_interpretation(self, outputs):

mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
# Don't log figures if matplotlib or add_figure is not available
if not mpl_available or not self._logger_supports("add_figure"):
return None

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -885,6 +886,11 @@ def log_embeddings(self):
"""
Log embeddings to tensorboard
"""

# Don't log embeddings if add_embedding is not available
if not self._logger_supports("add_embedding"):
return None

for name, emb in self.input_embeddings.items():
labels = self.hparams.embedding_labels[name]
self.logger.experiment.add_embedding(
Expand Down

0 comments on commit 45f1f6b

Please sign in to comment.