Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix evaluation example visualisation plots #91

Merged
merged 10 commits into from
Dec 4, 2024
9 changes: 7 additions & 2 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]:
The extent of the x, y coordinates.

"""
xy = self.get_xy(category, stacked=False)
extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
xy = self.get_xy(category, stacked=True)
extent = [
xy[:, 0].min(),
xy[:, 0].max(),
xy[:, 1].min(),
xy[:, 1].max(),
]
return [float(v) for v in extent]

@property
Expand Down
5 changes: 4 additions & 1 deletion neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard library
import copy
import warnings
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection:

class_name = projection_info["class_name"]
ProjectionClass = getattr(ccrs, class_name)
kwargs = projection_info["kwargs"]
# need to copy otherwise we modify the dict stored in the dataclass
# in-place
kwargs = copy.deepcopy(projection_info["kwargs"])

globe_kwargs = kwargs.pop("globe", {})
if len(globe_kwargs) > 0:
Expand Down
77 changes: 68 additions & 9 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
# Standard library
import os
from typing import List, Union

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
import xarray as xr

# Local
from .. import metrics, vis
from ..config import NeuralLAMConfig
from ..datastore import BaseDatastore
from ..loss_weighting import get_state_feature_weighting
from ..weather_dataset import WeatherDataset


class ARModel(pl.LightningModule):
Expand Down Expand Up @@ -147,6 +150,42 @@ def __init__(
# For storing spatial loss maps during evaluation
self.spatial_loss_maps = []

def _create_dataarray_from_tensor(
    self,
    tensor: torch.Tensor,
    time: Union[int, List[int], torch.Tensor],
    split: str,
    category: str,
) -> xr.DataArray:
    """
    Create an `xr.DataArray` from a tensor, with the correct dimensions and
    coordinates to match the datastore used by the model. This function is
    in effect the inverse of what is returned by
    `WeatherDataset.__getitem__`.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to convert to a `xr.DataArray` with dimensions [time,
        grid_index, feature]
    time : Union[int, List[int], torch.Tensor]
        The time index or indices for the data, given as an integer, a list
        of integers or an integer tensor representing epoch time in
        nanoseconds.
    split : str
        The split of the data, either 'train', 'val', or 'test'
    category : str
        The category of the data, either 'state' or 'forcing'

    Returns
    -------
    xr.DataArray
        The data array produced by
        `WeatherDataset.create_dataarray_from_tensor` for the given
        tensor, times and category.
    """
    # TODO: creating an instance of WeatherDataset here on every call is
    # not how this should be done but whether WeatherDataset should be
    # provided to ARModel or where to put plotting still needs discussion
    weather_dataset = WeatherDataset(datastore=self._datastore, split=split)

    if isinstance(time, torch.Tensor):
        # `plot_examples` passes slices of the batch here; NumPy can only
        # convert tensors that live in host memory, so move them first
        # (a no-op when the tensor is already on the CPU).
        time = time.detach().cpu().numpy()
    time = np.array(time, dtype="datetime64[ns]")

    return weather_dataset.create_dataarray_from_tensor(
        tensor=tensor, time=time, category=category
    )

khintz marked this conversation as resolved.
Show resolved Hide resolved
def configure_optimizers(self):
opt = torch.optim.AdamW(
self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
Expand Down Expand Up @@ -406,10 +445,13 @@ def test_step(self, batch, batch_idx):
)

self.plot_examples(
batch, n_additional_examples, prediction=prediction
batch,
n_additional_examples,
prediction=prediction,
split="test",
)

def plot_examples(self, batch, n_examples, prediction=None):
def plot_examples(self, batch, n_examples, split, prediction=None):
"""
Plot the first n_examples forecasts from batch

Expand All @@ -422,18 +464,34 @@ def plot_examples(self, batch, n_examples, prediction=None):
prediction, target, _, _ = self.common_step(batch)

target = batch[1]
time = batch[3]
khintz marked this conversation as resolved.
Show resolved Hide resolved

# Rescale to original data scale
prediction_rescaled = prediction * self.state_std + self.state_mean
target_rescaled = target * self.state_std + self.state_mean

# Iterate over the examples
for pred_slice, target_slice in zip(
prediction_rescaled[:n_examples], target_rescaled[:n_examples]
for pred_slice, target_slice, time_slice in zip(
prediction_rescaled[:n_examples],
target_rescaled[:n_examples],
time[:n_examples],
):
# Each slice is (pred_steps, num_grid_nodes, d_f)
self.plotted_examples += 1 # Increment already here

da_prediction = self._create_dataarray_from_tensor(
tensor=pred_slice,
time=time_slice,
split=split,
category="state",
).unstack("grid_index")
da_target = self._create_dataarray_from_tensor(
tensor=target_slice,
time=time_slice,
split=split,
category="state",
).unstack("grid_index")

var_vmin = (
torch.minimum(
pred_slice.flatten(0, 1).min(dim=0)[0],
Expand All @@ -453,18 +511,18 @@ def plot_examples(self, batch, n_examples, prediction=None):
var_vranges = list(zip(var_vmin, var_vmax))

# Iterate over prediction horizon time steps
for t_i, (pred_t, target_t) in enumerate(
zip(pred_slice, target_slice), start=1
):
for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
# Create one figure per variable at this time step
var_figs = [
vis.plot_prediction(
pred=pred_t[:, var_i],
target=target_t[:, var_i],
datastore=self._datastore,
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
da_prediction=da_prediction.isel(
state_feature=var_i
).squeeze(),
da_target=da_target.isel(state_feature=var_i).squeeze(),
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
Expand All @@ -476,6 +534,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
]

example_i = self.plotted_examples

wandb.log(
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
Expand Down
41 changes: 17 additions & 24 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
pred,
target,
datastore: BaseRegularGridDatastore,
da_prediction=None,
da_target=None,
khintz marked this conversation as resolved.
Show resolved Hide resolved
title=None,
vrange=None,
):
Expand All @@ -79,19 +79,17 @@ def plot_prediction(
"""
# Get common scale for values
if vrange is None:
vmin = min(vals.min().cpu().item() for vals in (pred, target))
vmax = max(vals.max().cpu().item() for vals in (pred, target))
vmin = min(da_prediction.min(), da_target.min())
vmax = max(da_prediction.max(), da_target.max())
else:
vmin, vmax = vrange

extent = datastore.get_xy_extent("state")

# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
mask_values = np.invert(da_mask.values.astype(bool)).astype(float)
pixel_alpha = mask_values.clip(0.7, 1) # Faded border region

fig, axes = plt.subplots(
1,
Expand All @@ -101,28 +99,23 @@ def plot_prediction(
)

# Plot pred and target
for ax, data in zip(axes, (target, pred)):
for ax, da in zip(axes, (da_target, da_prediction)):
ax.coastlines() # Add coastline outlines
data_grid = (
data.reshape(list(datastore.grid_shape_state.values.values()))
.cpu()
.numpy()
)
im = ax.imshow(
data_grid,
da.plot.imshow(
ax=ax,
origin="lower",
x="x",
extent=extent,
alpha=pixel_alpha,
alpha=pixel_alpha.T,
vmin=vmin,
vmax=vmax,
cmap="plasma",
transform=datastore.coords_projection,
)

# Ticks and labels
axes[0].set_title("Ground Truth", size=15)
axes[1].set_title("Prediction", size=15)
cbar = fig.colorbar(im, aspect=30)
cbar.ax.tick_params(labelsize=10)

if title:
fig.suptitle(title, size=20)
Expand Down Expand Up @@ -150,9 +143,7 @@ def plot_spatial_error(
# Set up masking of border region
da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
pixel_alpha = mask_reshaped.clip(0.7, 1) # Faded border region

fig, ax = plt.subplots(
figsize=(5, 4.8),
Expand All @@ -161,8 +152,10 @@ def plot_spatial_error(

ax.coastlines() # Add coastline outlines
error_grid = (
error.reshape(list(datastore.grid_shape_state.values.values()))
.cpu()
error.reshape(
[datastore.grid_shape_state.x, datastore.grid_shape_state.y]
)
.T.cpu()
.numpy()
)

Expand Down
Loading