diff --git a/README.md b/README.md
index 86630683..fe4f10f4 100644
--- a/README.md
+++ b/README.md
@@ -258,6 +258,22 @@ Some options specifically important for evaluation are:
 
 **Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this.
 
+## Plot model output
+The command-line tool `cli_plotting.py` plots and verifies inference results stored in `.npy` files.
+
+Arguments and options:
+
+ * `file_path`: Path to the `.npy` file containing the inferred values. This positional argument is required.
+ * `save_path`: Path prefix for the output plots (end it with `/` to write into a directory). This positional argument is required.
+ * `--feature_channel` (optional): Feature channel to plot. Defaults to 0.
+
+Example usage, plotting feature channel 10 of a `.npy` file located at `data/results/inference_output.npy` and saving the plots under `outputs/`:
+
+```
+python cli_plotting.py data/results/inference_output.npy outputs/ --feature_channel 10
+```
+
+
 # Repository Structure
 Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory.
 Model classes, including abstract base classes, are located in `neural_lam/models`.
diff --git a/cli_plotting.py b/cli_plotting.py
new file mode 100644
index 00000000..61239cea
--- /dev/null
+++ b/cli_plotting.py
@@ -0,0 +1,27 @@
+# Third-party
+import click
+
+# First-party
+from neural_lam.vis import verify_inference
+
+
+@click.command()
+@click.argument("file_path", type=click.Path(exists=True), required=True)
+@click.argument("save_path", type=click.Path(), required=True)
+@click.option(
+    "--feature_channel",
+    "-f",
+    default=0,
+    help="Feature channel to plot.",
+    type=int,
+    show_default=True,
+)
+def main(file_path: str, save_path: str, feature_channel: int) -> None:
+    """
+    Command line tool for verifying neural_lam inference results.
+    """
+    verify_inference(file_path, save_path, feature_channel)
+
+
+if __name__ == "__main__":
+    main()  # pylint: disable=no-value-for-parameter
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 56a225da..28c151ae 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -256,7 +256,7 @@ def fractional_plot_bundle(fraction):
     Get the tueplots bundle, but with figure width as a
     fraction of the page width.
""" - bundle = bundles.neurips2023(usetex=False, family="serif") + bundle = bundles.neurips2023(usetex=False, family="DejaVu Sans") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 6b3e4152..04873cf6 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,13 +1,18 @@ +# Standard library +import os + # Third-party import cartopy.feature as cf import matplotlib import matplotlib.pyplot as plt import numpy as np import xarray as xr +from tqdm import tqdm # First-party from neural_lam import constants, utils from neural_lam.rotate_grid import unrotate_latlon +from neural_lam.weather_dataset import WeatherDataModule @matplotlib.rc_context(utils.fractional_plot_bundle(1)) @@ -179,3 +184,111 @@ def plot_spatial_error(error, title=None, vrange=None): fig.suptitle(title, size=10) return fig + + +@matplotlib.rc_context(utils.fractional_plot_bundle(1)) +def verify_inference( + file_path: str, save_path: str, feature_channel: int, vrange=None +): + """ + Plot example prediction, forecast, and ground truth. + Each has shape (N_grid,) + """ + + # Load the inference dataset for plotting + predictions_data_module = WeatherDataModule( + "cosmo", + path_verif_file=file_path, + split="verif", + standardize=False, + subset=False, + batch_size=6, + num_workers=2, + ) + predictions_data_module.setup(stage="verif") + predictions_loader = predictions_data_module.verif_dataloader() + for predictions_batch in predictions_loader: + predictions = predictions_batch[0] # tensor + break + + # Verify that feature channel is within bounds + if not 0 <= feature_channel < predictions.shape[-1]: + raise ValueError( + f"feature_channel must be between 0 and " + f"{predictions.shape[-1]-1}, inclusive." + ) + + # get test data + data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) + lon, lat = unrotate_latlon(data_latlon) + + # Get common scale for values + total = predictions[0, :, :, feature_channel] + total_array = np.array(total) + if vrange is None: + vmin = total_array.min() + vmax = total_array.max() + else: + vmin, vmax = float(vrange[0].cpu().item()), float( + vrange[1].cpu().item() + ) + + # Plot + for i in tqdm( + range(constants.EVAL_HORIZON - 2), desc="Plotting predictions" + ): + feature_array = ( + predictions[0, i, :, feature_channel] + .reshape(*constants.GRID_SHAPE[::-1]) + .cpu() + .numpy() + ) + data_array = np.array(feature_array) + + fig, axes = plt.subplots( + 1, + 1, + figsize=constants.FIG_SIZE, + subplot_kw={"projection": constants.SELECTED_PROJ}, + ) + + contour_set = axes.contourf( + lon, + lat, + data_array, + transform=constants.SELECTED_PROJ, + cmap="plasma", + levels=np.linspace(vmin, vmax, num=100), + ) + axes.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") + axes.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") + axes.gridlines( + crs=constants.SELECTED_PROJ, + draw_labels=False, + linewidth=0.5, + alpha=0.5, + ) + + # Ticks and labels + axes.set_title("Predictions from model inference", size=15) + axes.text( + 0.5, + 1.05, + f"Feature channel {feature_channel}, time step {i}", + ha="center", + va="bottom", + transform=axes.transAxes, + fontsize=12, + ) + cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20) + cbar.ax.tick_params(labelsize=10) + + # Save the plot! 
+        directory = os.path.dirname(save_path)
+        if directory and not os.path.exists(directory):
+            os.makedirs(directory)
+        plt.savefig(
+            f"{save_path}feature_channel_{feature_channel}_{i}.png",
+            bbox_inches="tight",
+        )
+        plt.close()
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 29a2a5cd..adfd81f3 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -4,6 +4,7 @@
 from datetime import datetime, timedelta
 
 # Third-party
+import numpy as np
 import pytorch_lightning as pl
 import torch
 import xarray as xr
@@ -30,6 +31,7 @@ class WeatherDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         dataset_name,
+        path_verif_file=None,
         split="train",
         standardize=True,
         subset=False,
@@ -43,15 +45,21 @@ def __init__(
             "val",
             "test",
             "pred",
+            "verif",
         ), "Unknown dataset split"
 
         self.sample_dir_path = os.path.join(
             "data", dataset_name, "samples", split
         )
-        print(self.sample_dir_path)
         self.batch_size = batch_size
         self.batch_index = 0
         self.index_within_batch = 0
+        # The "verif" split reads a single .npy file of inference output
+        if split == "verif":
+            self.sample_dir_path = path_verif_file
+            self.np_files = np.load(self.sample_dir_path)
+            self.split = split
+            return
 
         self.zarr_files = sorted(
             glob.glob(os.path.join(self.sample_dir_path, "data*.zarr"))
@@ -135,7 +143,7 @@ def __init__(
             xr.open_zarr(file, consolidated=True)[variables_2d]
             .to_array()
             .pipe(
-                lambda ds: ds if "z_1" in ds.dims else ds.expand_dims(z_1=[0])
+                lambda ds: (ds if "z_1" in ds.dims else ds.expand_dims(z_1=[0]))
             )
             .stack(var=("variable", "z_1"))
             .transpose("time", "x_1", "y_1", "var")
@@ -152,7 +160,12 @@ def __init__(
         if standardize:
             ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
             if constants.GRID_FORCING_DIM > 0:
-                self.data_mean, self.data_std, self.flux_mean, self.flux_std = (
+                (
+                    self.data_mean,
+                    self.data_std,
+                    self.flux_mean,
+                    self.flux_std,
+                ) = (
                     ds_stats["data_mean"],
                     ds_stats["data_std"],
                     ds_stats["flux_mean"],
@@ -163,7 +176,6 @@ def __init__(
                     ds_stats["data_mean"],
                     ds_stats["data_std"],
                 )
-
         self.random_subsample = split == "train"
         self.split = split
 
@@ -173,10 +185,15 @@ def __len__(self):
             if self.split == "train"
             else constants.EVAL_HORIZON
         )
-        total_time = len(self.zarr_files) * constants.CHUNK_SIZE - num_steps
+        total_time = 1
+        if hasattr(self, "zarr_files"):
+            total_time = len(self.zarr_files) * constants.CHUNK_SIZE - num_steps
         return total_time
 
     def __getitem__(self, idx):
+        if self.split == "verif":
+            return self.np_files
+
         num_steps = (
             constants.TRAIN_HORIZON
             if self.split == "train"
@@ -190,7 +207,8 @@ def __getitem__(self, idx):
         idx_sample = idx % constants.CHUNK_SIZE
 
         sample_archive = xr.concat(
-            self.zarr_datasets[start_file_idx : end_file_idx + 1], dim="time"
+            self.zarr_datasets[start_file_idx : end_file_idx + 1],
+            dim="time",
         )
 
         sample_xr = sample_archive.isel(
@@ -220,6 +238,7 @@ def __init__(
         self,
         dataset_name,
         split="train",
+        path_verif_file=None,
         standardize=True,
         subset=False,
         batch_size=4,
@@ -227,6 +246,7 @@ def __init__(
     ):
         super().__init__()
         self.dataset_name = dataset_name
+        self.path_verif_file = path_verif_file
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.standardize = standardize
@@ -237,7 +257,7 @@ def prepare_data(self):
         pass
 
     def setup(self, stage=None):
-        # make assignments here (val/train/test/predict split)
+        # make assignments here (val/train/test/pred split)
         # called on every process in DDP
         if stage == "fit" or stage is None:
            self.train_dataset = WeatherDataset(
@@ -264,8 +284,18 @@
                 batch_size=self.batch_size,
             )
 
-        if stage == "predict" or stage is None:
-            self.predict_dataset = WeatherDataset(
+        if stage == "verif":
+            self.verif_dataset = WeatherDataset(
+                self.dataset_name,
+                self.path_verif_file,
+                split="verif",
+                standardize=False,
+                subset=False,
+                batch_size=self.batch_size,
+            )
+
+        if stage == "pred" or stage is None:
+            self.pred_dataset = WeatherDataset(
                 self.dataset_name,
                 split="pred",
                 standardize=self.standardize,
@@ -274,6 +304,7 @@
             )
 
     def train_dataloader(self):
+        """Load train dataset."""
         return torch.utils.data.DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
@@ -283,6 +314,7 @@ def train_dataloader(self):
         )
 
     def val_dataloader(self):
+        """Load validation dataset."""
         return torch.utils.data.DataLoader(
             self.val_dataset,
             batch_size=self.batch_size // self.batch_size,
@@ -292,6 +324,7 @@ def val_dataloader(self):
         )
 
     def test_dataloader(self):
+        """Load test dataset."""
         return torch.utils.data.DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
@@ -300,11 +333,21 @@ def test_dataloader(self):
             pin_memory=False,
         )
 
-    def predict_dataloader(self):
+    def pred_dataloader(self):
+        """Load prediction dataset."""
         return torch.utils.data.DataLoader(
-            self.predict_dataset,
+            self.pred_dataset,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
             shuffle=False,
             pin_memory=False,
         )
+
+    def verif_dataloader(self):
+        """Load inference output dataset."""
+        return torch.utils.data.DataLoader(
+            self.verif_dataset,
+            batch_size=1,
+            shuffle=False,
+            pin_memory=False,
+        )
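
For orientation, the `cli_plotting.py` entry point added above is a thin wrapper around `neural_lam.vis.verify_inference`. A minimal sketch of the equivalent programmatic call, using placeholder paths:

```
# Equivalent to: python cli_plotting.py data/results/inference_output.npy outputs/ -f 10
# The paths are placeholders. save_path is used as a filename prefix, so keep the
# trailing slash; plots are written as <save_path>feature_channel_<channel>_<step>.png.
from neural_lam.vis import verify_inference

verify_inference(
    file_path="data/results/inference_output.npy",
    save_path="outputs/",
    feature_channel=10,
)
```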
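
Internally, `verify_inference` routes the `.npy` file through the new `verif` split of `WeatherDataModule`. A minimal sketch of that data path, again with a placeholder file path:

```
# The "verif" stage wraps a single .npy file of predictions in a DataLoader
# (batch_size=1 in verif_dataloader), mirroring what verify_inference does.
from neural_lam.weather_dataset import WeatherDataModule

data_module = WeatherDataModule(
    "cosmo",
    path_verif_file="data/results/inference_output.npy",  # placeholder path
    split="verif",
    standardize=False,
    subset=False,
    batch_size=6,
    num_workers=2,
)
data_module.setup(stage="verif")
predictions = next(iter(data_module.verif_dataloader()))[0]  # stored prediction array
```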