From 45f1f6bf38bbcd8c90a950f28223c0a0f7f7b825 Mon Sep 17 00:00:00 2001
From: Ewan Thompson
Date: Tue, 12 Nov 2024 00:47:40 +0800
Subject: [PATCH] [BUG] fix `AttributeError: 'ExperimentWriter' object has no attribute 'add_figure'` (#1694)

### Description

This PR is a bugfix for #1256.

At present, an `AttributeError` is raised (and the script exits) if the logger in use does not have the `add_figure` method (e.g. `CSVLogger`, whose experiment object is an `ExperimentWriter`). This fix adds a simple check prior to calling `add_figure` (and, analogously, `add_embedding`), so that figure and embedding logging is skipped instead of crashing the run.

---
 pytorch_forecasting/models/base_model.py      | 19 ++++++++++++++++++-
 pytorch_forecasting/models/nbeats/__init__.py |  3 ++-
 pytorch_forecasting/models/nhits/__init__.py  |  3 ++-
 .../temporal_fusion_transformer/__init__.py   |  8 +++++++-
 4 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py
index 77d45197..3068f558 100644
--- a/pytorch_forecasting/models/base_model.py
+++ b/pytorch_forecasting/models/base_model.py
@@ -955,6 +955,18 @@ def log_interval(self) -> float:
         else:
             return self.hparams.log_val_interval
 
+    def _logger_supports(self, method: str) -> bool:
+        """Whether logger supports method.
+
+        Returns
+        -------
+        supports_method : bool
+            True if attribute self.logger.experiment.method exists, False otherwise.
+        """
+        if not hasattr(self, "logger") or not hasattr(self.logger, "experiment"):
+            return False
+        return hasattr(self.logger.experiment, method)
+
     def log_prediction(
         self, x: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor], batch_idx: int, **kwargs
     ) -> None:
@@ -981,6 +993,10 @@ def log_prediction(
         if not mpl_available:
             return None  # don't log matplotlib plots if not available
 
+        # Don't log figures if add_figure is not available
+        if not self._logger_supports("add_figure"):
+            return None
+
         for idx in log_indices:
             fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
             tag = f"{self.current_stage} prediction"
@@ -1156,7 +1172,8 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
 
         mpl_available = _check_matplotlib("log_gradient_flow", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         import matplotlib.pyplot as plt
diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py
index 1aeedcf2..8d00392c 100644
--- a/pytorch_forecasting/models/nbeats/__init__.py
+++ b/pytorch_forecasting/models/nbeats/__init__.py
@@ -277,7 +277,8 @@ def log_interpretation(self, x, out, batch_idx):
         """
         mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         label = ["val", "train"][self.training]
diff --git a/pytorch_forecasting/models/nhits/__init__.py b/pytorch_forecasting/models/nhits/__init__.py
index 53cf1892..6d790213 100644
--- a/pytorch_forecasting/models/nhits/__init__.py
+++ b/pytorch_forecasting/models/nhits/__init__.py
@@ -567,7 +567,8 @@ def log_interpretation(self, x, out, batch_idx):
         """
         mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         label = ["val", "train"][self.training]
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index e260b928..5a5ebeac 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -846,7 +846,8 @@ def log_interpretation(self, outputs):
 
         mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
 
-        if not mpl_available:
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
             return None
 
         import matplotlib.pyplot as plt
@@ -885,6 +886,11 @@ def log_embeddings(self):
         """
         Log embeddings to tensorboard
         """
+
+        # Don't log embeddings if add_embedding is not available
+        if not self._logger_supports("add_embedding"):
+            return None
+
         for name, emb in self.input_embeddings.items():
             labels = self.hparams.embedding_labels[name]
             self.logger.experiment.add_embedding(
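
For reference, a minimal, self-contained sketch of the guard pattern this patch introduces. The `FakeExperimentWriter`, `FakeCSVLogger`, and `FigureLoggingMixin` classes are illustrative stand-ins (not part of pytorch-forecasting or lightning); only `_logger_supports` mirrors the patched method:

```python
# Illustrative sketch only: the classes below are assumptions standing in for
# CSVLogger/ExperimentWriter and BaseModel; _logger_supports mirrors the patch.


class FakeExperimentWriter:
    """Mimics CSVLogger's ExperimentWriter: no add_figure / add_embedding."""


class FakeCSVLogger:
    def __init__(self):
        self.experiment = FakeExperimentWriter()


class FigureLoggingMixin:
    """Carries only the pieces relevant to the fix."""

    def __init__(self, logger):
        self.logger = logger

    def _logger_supports(self, method: str) -> bool:
        # Same attribute-based check as the patched BaseModel._logger_supports
        if not hasattr(self, "logger") or not hasattr(self.logger, "experiment"):
            return False
        return hasattr(self.logger.experiment, method)

    def log_prediction_figure(self, fig, global_step: int = 0) -> None:
        # Before the fix this call raised
        # AttributeError: 'ExperimentWriter' object has no attribute 'add_figure'
        if not self._logger_supports("add_figure"):
            return None
        self.logger.experiment.add_figure("prediction", fig, global_step=global_step)


model = FigureLoggingMixin(FakeCSVLogger())
model.log_prediction_figure(fig=None)  # skipped silently instead of crashing
```

Because the check is attribute-based rather than tied to a specific logger class, any logger whose experiment object exposes `add_figure` (e.g. `TensorBoardLogger`'s `SummaryWriter`) keeps its current behaviour.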