From 5904cbe9da67d3e98eaab0cebd501a2ad0ded7f3 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 25 Nov 2024 16:42:21 +0100 Subject: [PATCH 01/10] identified issue, cleanup next --- neural_lam/datastore/base.py | 9 ++++- neural_lam/datastore/mdp.py | 5 ++- neural_lam/models/ar_model.py | 46 ++++++++++++++++++++-- neural_lam/train_model.py | 2 +- neural_lam/vis.py | 73 +++++++++++++++++++++++++---------- 5 files changed, 107 insertions(+), 28 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 0317c2e5..b0055e39 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]: The extent of the x, y coordinates. """ - xy = self.get_xy(category, stacked=False) - extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()] + xy = self.get_xy(category, stacked=True) + extent = [ + xy[:, 0].min(), + xy[:, 0].max(), + xy[:, 1].min(), + xy[:, 1].max(), + ] return [float(v) for v in extent] @property diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 10593a82..0d1aac7b 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -1,4 +1,5 @@ # Standard library +import copy import warnings from functools import cached_property from pathlib import Path @@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection: class_name = projection_info["class_name"] ProjectionClass = getattr(ccrs, class_name) - kwargs = projection_info["kwargs"] + # need to copy otherwise we modify the dict stored in the dataclass + # in-place + kwargs = copy.deepcopy(projection_info["kwargs"]) globe_kwargs = kwargs.pop("globe", {}) if len(globe_kwargs) > 0: diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index bc4c6719..b55143f0 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -7,12 +7,14 @@ import pytorch_lightning as pl import torch import wandb +from loguru import 
logger # Local from .. import metrics, vis from ..config import NeuralLAMConfig from ..datastore import BaseDatastore from ..loss_weighting import get_state_feature_weighting +from ..weather_dataset import WeatherDataset class ARModel(pl.LightningModule): @@ -147,6 +149,14 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] + def _create_dataarray_from_tensor(self, tensor, time, split, category): + weather_dataset = WeatherDataset(datastore=self._datastore, split=split) + time = np.array(time, dtype="datetime64[ns]") + da = weather_dataset.create_dataarray_from_tensor( + tensor=tensor, time=time, category=category + ) + return da + def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) @@ -406,10 +416,13 @@ def test_step(self, batch, batch_idx): ) self.plot_examples( - batch, n_additional_examples, prediction=prediction + batch, + n_additional_examples, + prediction=prediction, + split="test", ) - def plot_examples(self, batch, n_examples, prediction=None): + def plot_examples(self, batch, n_examples, split, prediction=None): """ Plot the first n_examples forecasts from batch @@ -422,18 +435,34 @@ def plot_examples(self, batch, n_examples, prediction=None): prediction, target, _, _ = self.common_step(batch) target = batch[1] + time = batch[3] # Rescale to original data scale prediction_rescaled = prediction * self.state_std + self.state_mean target_rescaled = target * self.state_std + self.state_mean # Iterate over the examples - for pred_slice, target_slice in zip( - prediction_rescaled[:n_examples], target_rescaled[:n_examples] + for pred_slice, target_slice, time_slice in zip( + prediction_rescaled[:n_examples], + target_rescaled[:n_examples], + time[:n_examples], ): # Each slice is (pred_steps, num_grid_nodes, d_f) self.plotted_examples += 1 # Increment already here + da_prediction = self._create_dataarray_from_tensor( + tensor=pred_slice, + time=time_slice, + 
split=split, + category="state", + ).unstack("grid_index") + da_target = self._create_dataarray_from_tensor( + tensor=target_slice, + time=time_slice, + split=split, + category="state", + ).unstack("grid_index") + var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -465,6 +494,10 @@ def plot_examples(self, batch, n_examples, prediction=None): title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, + da_prediction=da_prediction.isel( + state_feature=var_i + ).squeeze(), + da_target=da_target.isel(state_feature=var_i).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( @@ -476,6 +509,11 @@ def plot_examples(self, batch, n_examples, prediction=None): ] example_i = self.plotted_examples + for i, fig in enumerate(var_figs): + fn = f"example_{i}_{example_i}_t{t_i}.png" + fig.savefig(fn) + logger.info(f"Saved example plot to {fn}") + wandb.log( { f"{var_name}_example_{example_i}": wandb.Image(fig) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 74146c89..9d1d5039 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -23,7 +23,7 @@ } -@logger.catch +@logger.catch(reraise=True) def main(input_args=None): """Main function for training and evaluating models.""" parser = ArgumentParser( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index b9d18b39..357a8977 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -68,6 +68,8 @@ def plot_prediction( pred, target, datastore: BaseRegularGridDatastore, + da_prediction=None, + da_target=None, title=None, vrange=None, ): @@ -88,10 +90,8 @@ def plot_prediction( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) - mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + mask_values = np.invert(da_mask.values.astype(bool)).astype(float) + pixel_alpha = mask_values.clip(0.7, 1) # Faded 
border region fig, axes = plt.subplots( 1, @@ -100,29 +100,62 @@ def plot_prediction( subplot_kw={"projection": datastore.coords_projection}, ) + use_xarray = True + # Plot pred and target - for ax, data in zip(axes, (target, pred)): + + if not use_xarray: + for ax, data in zip(axes, (target, pred)): + ax.coastlines() # Add coastline outlines + data_grid = ( + data.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() + .numpy() + ) + im = ax.imshow( + data_grid, + origin="lower", + extent=extent, + alpha=pixel_alpha, + vmin=vmin, + vmax=vmax, + cmap="plasma", + ) + + cbar = fig.colorbar(im, aspect=30) + cbar.ax.tick_params(labelsize=10) + + x = da_target.x.values + y = da_target.y.values + extent = [x.min(), x.max(), y.min(), y.max()] + for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() - .numpy() - ) - im = ax.imshow( - data_grid, + im = da.plot.imshow( + ax=ax, origin="lower", + x="x", extent=extent, - alpha=pixel_alpha, + alpha=pixel_alpha.T, vmin=vmin, vmax=vmax, cmap="plasma", + transform=datastore.coords_projection, ) + # da.plot.pcolormesh( + # ax=ax, + # x="x", + # vmin=vmin, + # vmax=vmax, + # transform=datastore.coords_projection, + # cmap="plasma", + # ) + # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) if title: fig.suptitle(title, size=20) @@ -150,9 +183,7 @@ def plot_spatial_error( # Set up masking of border region da_mask = datastore.unstack_grid_coords(datastore.boundary_mask) mask_reshaped = da_mask.values - pixel_alpha = ( - mask_reshaped.clamp(0.7, 1).cpu().numpy() - ) # Faded border region + pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region fig, ax = plt.subplots( figsize=(5, 4.8), @@ -161,8 +192,10 @@ def plot_spatial_error( ax.coastlines() # 
Add coastline outlines error_grid = ( - error.reshape(list(datastore.grid_shape_state.values.values())) - .cpu() + error.reshape( + [datastore.grid_shape_state.x, datastore.grid_shape_state.y] + ) + .T.cpu() .numpy() ) From efe03027842a22139d6554d68ffee7b6ebe0ad73 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 13:46:05 +0100 Subject: [PATCH 02/10] use xarray plot only --- neural_lam/models/ar_model.py | 47 +++++++++++++++++++++++++++-------- neural_lam/vis.py | 43 +++----------------------------- 2 files changed, 39 insertions(+), 51 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index b55143f0..0af25367 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,5 +1,6 @@ # Standard library import os +from typing import List, Union # Third-party import matplotlib.pyplot as plt @@ -7,7 +8,7 @@ import pytorch_lightning as pl import torch import wandb -from loguru import logger +import xarray as xr # Local from .. import metrics, vis @@ -149,7 +150,35 @@ def __init__( # For storing spatial loss maps during evaluation self.spatial_loss_maps = [] - def _create_dataarray_from_tensor(self, tensor, time, split, category): + def _create_dataarray_from_tensor( + self, + tensor: torch.Tensor, + time: Union[int, List[int]], + split: str, + category: str, + ) -> xr.DataArray: + """ + Create an `xr.DataArray` from a tensor, with the correct dimensions and + coordinates to match the datastore used by the model. This function in + in effect is the inverse of what is returned by + `WeatherDataset.__getitem__`. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to convert to a `xr.DataArray` with dimensions [time, + grid_index, feature] + time : Union[int,List[int]] + The time index or indices for the data, given as integers or a list + of integers representing epoch time in nanoseconds. 
+ split : str + The split of the data, either 'train', 'val', or 'test' + category : str + The category of the data, either 'state' or 'forcing' + """ + # TODO: creating an instance of WeatherDataset here on every call is + # not how this should be done but whether WeatherDataset should be + # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) time = np.array(time, dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( @@ -482,14 +511,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None): var_vranges = list(zip(var_vmin, var_vmax)) # Iterate over prediction horizon time steps - for t_i, (pred_t, target_t) in enumerate( - zip(pred_slice, target_slice), start=1 - ): + for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1): # Create one figure per variable at this time step var_figs = [ vis.plot_prediction( - pred=pred_t[:, var_i], - target=target_t[:, var_i], datastore=self._datastore, title=f"{var_name} ({var_unit}), " f"t={t_i} ({self._datastore.step_length * t_i} h)", @@ -509,10 +534,10 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ] example_i = self.plotted_examples - for i, fig in enumerate(var_figs): - fn = f"example_{i}_{example_i}_t{t_i}.png" - fig.savefig(fn) - logger.info(f"Saved example plot to {fn}") + # for i, fig in enumerate(var_figs): + # fn = f"example_{i}_{example_i}_t{t_i}.png" + # fig.savefig(fn) + # logger.info(f"Saved example plot to {fn}") wandb.log( { diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 357a8977..47c68e4f 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -65,8 +65,6 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, - target, datastore: BaseRegularGridDatastore, da_prediction=None, da_target=None, @@ -81,8 +79,8 @@ def plot_prediction( """ # 
Get common scale for values if vrange is None: - vmin = min(vals.min().cpu().item() for vals in (pred, target)) - vmax = max(vals.max().cpu().item() for vals in (pred, target)) + vmin = min(da_prediction.min(), da_target.min()) + vmax = max(da_prediction.max(), da_target.max()) else: vmin, vmax = vrange @@ -100,39 +98,13 @@ def plot_prediction( subplot_kw={"projection": datastore.coords_projection}, ) - use_xarray = True - # Plot pred and target - - if not use_xarray: - for ax, data in zip(axes, (target, pred)): - ax.coastlines() # Add coastline outlines - data_grid = ( - data.reshape( - [datastore.grid_shape_state.x, datastore.grid_shape_state.y] - ) - .T.cpu() - .numpy() - ) - im = ax.imshow( - data_grid, - origin="lower", - extent=extent, - alpha=pixel_alpha, - vmin=vmin, - vmax=vmax, - cmap="plasma", - ) - - cbar = fig.colorbar(im, aspect=30) - cbar.ax.tick_params(labelsize=10) - x = da_target.x.values y = da_target.y.values extent = [x.min(), x.max(), y.min(), y.max()] for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines - im = da.plot.imshow( + da.plot.imshow( ax=ax, origin="lower", x="x", @@ -144,15 +116,6 @@ def plot_prediction( transform=datastore.coords_projection, ) - # da.plot.pcolormesh( - # ax=ax, - # x="x", - # vmin=vmin, - # vmax=vmax, - # transform=datastore.coords_projection, - # cmap="plasma", - # ) - # Ticks and labels axes[0].set_title("Ground Truth", size=15) axes[1].set_title("Prediction", size=15) From a489c2ed974397ea230d2e61b842d8d9384867dc Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 14:07:06 +0100 Subject: [PATCH 03/10] don't reraise --- neural_lam/train_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 9d1d5039..74146c89 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -23,7 +23,7 @@ } -@logger.catch(reraise=True) +@logger.catch def main(input_args=None): """Main 
function for training and evaluating models.""" parser = ArgumentParser( From 242d08bcb5374cdd90aecfd49f501ed233f1ce0c Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 14:50:03 +0100 Subject: [PATCH 04/10] remove debug plot --- neural_lam/models/ar_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0af25367..c875688b 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -534,10 +534,6 @@ def plot_examples(self, batch, n_examples, split, prediction=None): ] example_i = self.plotted_examples - # for i, fig in enumerate(var_figs): - # fn = f"example_{i}_{example_i}_t{t_i}.png" - # fig.savefig(fn) - # logger.info(f"Saved example plot to {fn}") wandb.log( { From c1f706c29542d770ed49e910f8b9bd5caff1fdec Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 26 Nov 2024 16:04:24 +0100 Subject: [PATCH 05/10] remove extent calc used in diagnosing issue --- neural_lam/vis.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 47c68e4f..c814aacf 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -99,9 +99,6 @@ def plot_prediction( ) # Plot pred and target - x = da_target.x.values - y = da_target.y.values - extent = [x.min(), x.max(), y.min(), y.max()] for ax, da in zip(axes, (da_target, da_prediction)): ax.coastlines() # Add coastline outlines da.plot.imshow( From cf8e3e4c1be93a6ec074368aaf6f91c8042b5278 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 14:51:36 +0100 Subject: [PATCH 06/10] add type annotation --- neural_lam/vis.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index c814aacf..d6b57f88 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import xarray as xr # Local from . 
import utils @@ -66,8 +67,8 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( datastore: BaseRegularGridDatastore, - da_prediction=None, - da_target=None, + da_prediction: xr.DataArray = None, + da_target: xr.DataArray = None, title=None, vrange=None, ): From 85160cecf13ecfc9fc6a589ac1a9e3542da45e23 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 15:03:06 +0100 Subject: [PATCH 07/10] ensure tensor copy to cpu mem before data-array creation --- neural_lam/models/ar_model.py | 10 ++++++---- neural_lam/weather_dataset.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index c875688b..0d8e6e3c 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -167,10 +167,12 @@ def _create_dataarray_from_tensor( ---------- tensor : torch.Tensor The tensor to convert to a `xr.DataArray` with dimensions [time, - grid_index, feature] + grid_index, feature]. The tensor will be copied to the CPU if it is + not already there. time : Union[int,List[int]] The time index or indices for the data, given as integers or a list - of integers representing epoch time in nanoseconds. + of integers representing epoch time in nanoseconds. The ints will be + copied to the CPU memory if they are not already there. 
split : str The split of the data, either 'train', 'val', or 'test' category : str @@ -180,9 +182,9 @@ def _create_dataarray_from_tensor( # not how this should be done but whether WeatherDataset should be # provided to ARModel or where to put plotting still needs discussion weather_dataset = WeatherDataset(datastore=self._datastore, split=split) - time = np.array(time, dtype="datetime64[ns]") + time = np.array(time.cpu(), dtype="datetime64[ns]") da = weather_dataset.create_dataarray_from_tensor( - tensor=tensor, time=time, category=category + tensor=tensor.cpu().numpy(), time=time, category=category ) return da diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 532e3c90..b5f85580 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -529,7 +529,8 @@ def create_dataarray_from_tensor( tensor : torch.Tensor The tensor to construct the DataArray from, this assumed to have the same dimension ordering as returned by the __getitem__ method - (i.e. time, grid_index, {category}_feature). + (i.e. time, grid_index, {category}_feature). The tensor will be + copied to the CPU before constructing the DataArray. time : datetime.datetime or list[datetime.datetime] The time or times of the tensor. 
category : str @@ -581,7 +582,7 @@ def _is_listlike(obj): coords["time"] = time da = xr.DataArray( - tensor.numpy(), + tensor.cpu().numpy(), dims=dims, coords=coords, ) From 52c452879f56c7f982cfd5d55a5259f37cb6b030 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Fri, 29 Nov 2024 15:05:36 +0100 Subject: [PATCH 08/10] apply time-indexing to support ar_steps_val > 1 --- neural_lam/models/ar_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 0d8e6e3c..44baf9c2 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -522,9 +522,11 @@ def plot_examples(self, batch, n_examples, split, prediction=None): f"t={t_i} ({self._datastore.step_length * t_i} h)", vrange=var_vrange, da_prediction=da_prediction.isel( - state_feature=var_i + state_feature=var_i, time=t_i - 1 + ).squeeze(), + da_target=da_target.isel( + state_feature=var_i, time=t_i - 1 ).squeeze(), - da_target=da_target.isel(state_feature=var_i).squeeze(), ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( From 6205dbd88f1b208118d93da6d12c0a1be672caef Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 2 Dec 2024 10:26:54 +0100 Subject: [PATCH 09/10] pin dataclass-wizard <0.31.0 to avoid bug in dataclass-wizard --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f0bc0851..fdcb7f3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "torch>=2.3.0", "torch-geometric==2.3.1", "parse>=1.20.2", - "dataclass-wizard>=0.22.3", + "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", ] requires-python = ">=3.9" From 2aa0a1aade7f86d2158b07a0ccc33cd560ac5166 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 3 Dec 2024 09:20:42 +0100 Subject: [PATCH 10/10] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12cf54f6..01d4cac9 
100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#66](https://github.com/mllam/neural-lam/pull/66) @leifdenby @sadamov +### Fixed + +- Fix bugs introduced with datastores functionality relating to visualization plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby + ## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0) ### Added