From a697acff0478a296188b36baea2b2528be780bf2 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Wed, 10 Apr 2024 13:54:42 +0200 Subject: [PATCH 01/13] Adding a verification function for inference --- neural_lam/utils.py | 2 +- neural_lam/vis.py | 89 ++++++++++ neural_lam/weather_dataset.py | 313 +++++++++++++++++++--------------- 3 files changed, 265 insertions(+), 139 deletions(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index e21fe083..56a225da 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -256,7 +256,7 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=True, family="serif") + bundle = bundles.neurips2023(usetex=False, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 6b3e4152..a2cf4f58 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,4 +1,5 @@ # Third-party +import os import cartopy.feature as cf import matplotlib import matplotlib.pyplot as plt @@ -8,6 +9,7 @@ # First-party from neural_lam import constants, utils from neural_lam.rotate_grid import unrotate_latlon +from neural_lam.weather_dataset import WeatherDataModule @matplotlib.rc_context(utils.fractional_plot_bundle(1)) @@ -179,3 +181,90 @@ def plot_spatial_error(error, title=None, vrange=None): fig.suptitle(title, size=10) return fig + +@matplotlib.rc_context(utils.fractional_plot_bundle(1)) +def verify_inference(file_path:str, feature_channel:int, vrange=None, save_path=None): + """ + Plot example prediction, forecast, and ground truth. + Each has shape (N_grid,) + """ + + # Verify that feature channel is between 0 and 42 + if not 0 <= feature_channel <= 42: + raise ValueError("feature_channel must be between 0 and 42, inclusive.") + + # Load the inference dataset for plotting + predictions_data_module = WeatherDataModule( + "cosmo", + path_verif_file=file_path, + split="verification", + standardize=False, + subset=False, + batch_size=6, + num_workers=2 + ) + predictions_data_module.setup(stage='verification') + predictions_loader = predictions_data_module.predictions_dataloader() + for predictions_batch in predictions_loader: + predictions = predictions_batch[0] # tensor + break + + # get test data + data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) + lon, lat = unrotate_latlon(data_latlon) + + # Get common scale for values + total = predictions[0,:,:,feature_channel] + total_array = np.array(total) + if vrange is None: + vmin = total_array.min() + vmax = total_array.max() + else: + vmin, vmax = float(vrange[0].cpu().item()), float(vrange[1].cpu().item()) + + # Plot + for i in range(23): + + feature_array = predictions[0,i,:,feature_channel].reshape(*constants.GRID_SHAPE[::-1]).cpu().numpy() + data_array = np.array(feature_array) + + fig, axes = plt.subplots( + 1, + 1, + figsize=constants.FIG_SIZE, + subplot_kw={"projection": constants.SELECTED_PROJ}, + ) + + contour_set = axes.contourf( + lon, + lat, + data_array, + transform=constants.SELECTED_PROJ, + cmap="plasma", + levels=np.linspace(vmin, vmax, num=100), + ) + axes.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") + axes.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") + axes.gridlines( + crs=constants.SELECTED_PROJ, + draw_labels=False, + linewidth=0.5, + alpha=0.5, + ) + + # Ticks and labels + axes.set_title("Predictions from model inference", size 
= 15) + axes.text(0.5, 1.05, f"Feature channel {feature_channel}, time step {i}", + ha='center', va='bottom', transform=axes.transAxes, fontsize=12 + ) + cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20) + cbar.ax.tick_params(labelsize=10) + + # Save the plot! + if save_path: + directory = os.path.dirname(save_path) + if not os.path.exists(directory): + os.makedirs(directory) + plt.savefig(save_path + f"_feature_channel_{feature_channel}_"+ f"{i}.png", bbox_inches='tight') + + return fig \ No newline at end of file diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 17a00f64..b136b0ed 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -2,6 +2,7 @@ import glob import os from datetime import datetime, timedelta +import numpy as np # Third-party import pytorch_lightning as pl @@ -30,6 +31,7 @@ class WeatherDataset(torch.utils.data.Dataset): def __init__( self, dataset_name, + path_verif_file=None, split="train", standardize=True, subset=False, @@ -38,124 +40,130 @@ def __init__( ): super().__init__() - assert split in ("train", "val", "test"), "Unknown dataset split" - self.sample_dir_path = os.path.join( - "data", dataset_name, "samples", split - ) + assert split in ("train", "val", "test", "verification"), "Unknown dataset split" - self.batch_size = batch_size - self.batch_index = 0 - self.index_within_batch = 0 + if split == "verification": + self.sample_dir_path = path_verif_file + if os.path.exists(self.sample_dir_path): + self.np_files = np.load(self.sample_dir_path) + else: + self.sample_dir_path = os.path.join( + "data", dataset_name, "samples", split + ) - self.zarr_files = sorted( - glob.glob(os.path.join(self.sample_dir_path, "data*.zarr")) - ) - if len(self.zarr_files) == 0: - raise ValueError("No .zarr files found in directory") + self.batch_size = batch_size + self.batch_index = 0 + self.index_within_batch = 0 - if subset: - if constants.EVAL_DATETIME is not None and split == "test": - eval_datetime_obj = datetime.strptime( - constants.EVAL_DATETIME, "%Y%m%d%H" - ) - for i, file in enumerate(self.zarr_files): - file_datetime_str = file.split("/")[-1].split("_")[1][:-5] - file_datetime_obj = datetime.strptime( - file_datetime_str, "%Y%m%d%H" - ) - if ( - file_datetime_obj - <= eval_datetime_obj - < file_datetime_obj - + timedelta(hours=constants.CHUNK_SIZE) - ): - # Retrieve the current file and the next file if it - # exists - next_file_index = i + 1 - if next_file_index < len(self.zarr_files): - self.zarr_files = [ - file, - self.zarr_files[next_file_index], - ] - else: - self.zarr_files = [file] - position_within_file = int( - ( - eval_datetime_obj - file_datetime_obj - ).total_seconds() - // 3600 - ) - self.batch_index = ( - position_within_file // self.batch_size - ) - self.index_within_batch = ( - position_within_file % self.batch_size - ) - break - else: - self.zarr_files = self.zarr_files[0:2] - - start_datetime = ( - self.zarr_files[0] - .split("/")[-1] - .split("_")[1] - .replace(".zarr", "") + self.zarr_files = sorted( + glob.glob(os.path.join(self.sample_dir_path, "data*.zarr")) ) + if len(self.zarr_files) == 0: + raise ValueError("No .zarr files found in directory") - print("Data subset of 200 samples starts on the", start_datetime) - - # Separate 3D and 2D variables - variables_3d = [ - var for var in constants.PARAM_NAMES_SHORT if constants.IS_3D[var] - ] - variables_2d = [ - var - for var in constants.PARAM_NAMES_SHORT - if not constants.IS_3D[var] - ] - - # Stack 3D variables - 
datasets_3d = [ - xr.open_zarr(file, consolidated=True)[variables_3d] - .sel(z_1=constants.VERTICAL_LEVELS) - .to_array() - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Stack 2D variables without selecting along z_1 - datasets_2d = [ - xr.open_zarr(file, consolidated=True)[variables_2d] - .to_array() - .expand_dims(z_1=[0]) - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Combine 3D and 2D datasets - self.zarr_datasets = [ - xr.concat([ds_3d, ds_2d], dim="var").sortby("var") - for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) - ] - - self.standardize = standardize - if standardize: - ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - if constants.GRID_FORCING_DIM > 0: - self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ds_stats["flux_mean"], - ds_stats["flux_std"], - ) - else: - self.data_mean, self.data_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], + if subset: + if constants.EVAL_DATETIME is not None and split == "test": + eval_datetime_obj = datetime.strptime( + constants.EVAL_DATETIME, "%Y%m%d%H" + ) + for i, file in enumerate(self.zarr_files): + file_datetime_str = file.split("/")[-1].split("_")[1][:-5] + file_datetime_obj = datetime.strptime( + file_datetime_str, "%Y%m%d%H" + ) + if ( + file_datetime_obj + <= eval_datetime_obj + < file_datetime_obj + + timedelta(hours=constants.CHUNK_SIZE) + ): + # Retrieve the current file and the next file if it + # exists + next_file_index = i + 1 + if next_file_index < len(self.zarr_files): + self.zarr_files = [ + file, + self.zarr_files[next_file_index], + ] + else: + self.zarr_files = [file] + position_within_file = int( + ( + eval_datetime_obj - file_datetime_obj + ).total_seconds() + // 3600 + ) + self.batch_index = ( + position_within_file // self.batch_size + ) + self.index_within_batch = ( + position_within_file % self.batch_size + ) + break + else: + self.zarr_files = self.zarr_files[0:2] + + start_datetime = ( + self.zarr_files[0] + .split("/")[-1] + .split("_")[1] + .replace(".zarr", "") ) + print("Data subset of 200 samples starts on the", start_datetime) + + # Separate 3D and 2D variables + variables_3d = [ + var for var in constants.PARAM_NAMES_SHORT if constants.IS_3D[var] + ] + variables_2d = [ + var + for var in constants.PARAM_NAMES_SHORT + if not constants.IS_3D[var] + ] + + # Stack 3D variables + datasets_3d = [ + xr.open_zarr(file, consolidated=True)[variables_3d] + .sel(z_1=constants.VERTICAL_LEVELS) + .to_array() + .stack(var=("variable", "z_1")) + .transpose("time", "x_1", "y_1", "var") + for file in self.zarr_files + ] + + # Stack 2D variables without selecting along z_1 + datasets_2d = [ + xr.open_zarr(file, consolidated=True)[variables_2d] + .to_array() + .pipe(lambda ds: ds if 'z_1' in ds.dims else ds.expand_dims(z_1=[0])) + .stack(var=("variable", "z_1")) + .transpose("time", "x_1", "y_1", "var") + for file in self.zarr_files + ] + + # Combine 3D and 2D datasets + self.zarr_datasets = [ + xr.concat([ds_3d, ds_2d], dim="var").sortby("var") + for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) + ] + + self.standardize = standardize + if standardize: + ds_stats = utils.load_dataset_stats(dataset_name, "cpu") + if constants.GRID_FORCING_DIM > 0: + self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ds_stats["flux_mean"], + ds_stats["flux_std"], 
+ ) + else: + self.data_mean, self.data_std = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ) + self.random_subsample = split == "train" self.split = split @@ -165,44 +173,51 @@ def __len__(self): if self.split == "train" else constants.EVAL_HORIZON ) - total_time = len(self.zarr_files) * constants.CHUNK_SIZE - num_steps + total_time = 1 + if hasattr(self, "zarr_files"): + total_time = len(self.zarr_files) * constants.CHUNK_SIZE - num_steps return total_time def __getitem__(self, idx): - num_steps = ( - constants.TRAIN_HORIZON - if self.split == "train" - else constants.EVAL_HORIZON - ) - # Calculate which zarr files need to be loaded - start_file_idx = idx // constants.CHUNK_SIZE - end_file_idx = (idx + num_steps) // constants.CHUNK_SIZE - # Index of current slice - idx_sample = idx % constants.CHUNK_SIZE + if self.split == "verification": + return self.np_files - sample_archive = xr.concat( - self.zarr_datasets[start_file_idx : end_file_idx + 1], dim="time" - ) + else: + num_steps = ( + constants.TRAIN_HORIZON + if self.split == "train" + else constants.EVAL_HORIZON + ) - sample_xr = sample_archive.isel( - time=slice(idx_sample, idx_sample + num_steps) - ) + # Calculate which zarr files need to be loaded + start_file_idx = idx // constants.CHUNK_SIZE + end_file_idx = (idx + num_steps) // constants.CHUNK_SIZE + # Index of current slice + idx_sample = idx % constants.CHUNK_SIZE + + sample_archive = xr.concat( + self.zarr_datasets[start_file_idx : end_file_idx + 1], dim="time" + ) + + sample_xr = sample_archive.isel( + time=slice(idx_sample, idx_sample + num_steps) + ) - # (N_t', N_x, N_y, d_features') - sample = torch.tensor(sample_xr.values, dtype=torch.float32) + # (N_t', N_x, N_y, d_features') + sample = torch.tensor(sample_xr.values, dtype=torch.float32) - sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) + sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) - if self.standardize: - # Standardize sample - sample = (sample - self.data_mean) / self.data_std + if self.standardize: + # Standardize sample + sample = (sample - self.data_mean) / self.data_std - # Split up sample in init. states and target states - init_states = sample[:2] # (2, N_grid, d_features) - target_states = sample[2:] # (sample_length-2, N_grid, d_features) + # Split up sample in init. 
states and target states + init_states = sample[:2] # (2, N_grid, d_features) + target_states = sample[2:] # (sample_length-2, N_grid, d_features) - return init_states, target_states + return init_states, target_states class WeatherDataModule(pl.LightningDataModule): @@ -212,6 +227,7 @@ def __init__( self, dataset_name, split="train", + path_verif_file=None, standardize=True, subset=False, batch_size=4, @@ -219,6 +235,7 @@ def __init__( ): super().__init__() self.dataset_name = dataset_name + self.path_verif_file = path_verif_file self.batch_size = batch_size self.num_workers = num_workers self.standardize = standardize @@ -256,6 +273,16 @@ def setup(self, stage=None): batch_size=self.batch_size, ) + if stage == "verification": + self.predictions_dataset = WeatherDataset( + self.dataset_name, + self.path_verif_file, + split="verification", + standardize=False, + subset=False, + batch_size=self.batch_size, + ) + def train_dataloader(self): return torch.utils.data.DataLoader( self.train_dataset, @@ -282,3 +309,13 @@ def test_dataloader(self): shuffle=False, pin_memory=False, ) + + def predictions_dataloader(self): + return torch.utils.data.DataLoader( + self.predictions_dataset, + path_verif_file=None, + batch_size=1, + num_workers=self.num_workers, + shuffle=False, + pin_memory=False, + ) From 34bbfecbb5ef0dc133fe973d5963e206d0a9c4ce Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 12 Apr 2024 09:34:31 +0200 Subject: [PATCH 02/13] removing path to file from DataLoader --- neural_lam/weather_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 415b50a3..5ca34d6c 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -375,6 +375,7 @@ def predict_dataloader(self): def verification_dataloader(self): return torch.utils.data.DataLoader( self.verification_dataset, - path_verif_file=None, batch_size=1, + shuffle=False, + pin_memory=False, ) From b135f1a812d297f38957082712136d25a654152e Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 10:26:33 +0200 Subject: [PATCH 03/13] removed code duplication --- neural_lam/weather_dataset.py | 37 +---------------------------------- 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5ca34d6c..4dc0d856 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -119,42 +119,7 @@ def __init__( .split("_")[1] .replace(".zarr", "") ) - - # Stack 2D variables without selecting along z_1 - datasets_2d = [ - xr.open_zarr(file, consolidated=True)[variables_2d] - .to_array() - .pipe( - lambda ds: ds if "z_1" in ds.dims else ds.expand_dims(z_1=[0]) - ) - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Combine 3D and 2D datasets - self.zarr_datasets = [ - xr.concat([ds_3d, ds_2d], dim="var").sortby("var") - for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) - ] - - self.standardize = standardize - if standardize: - ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - if constants.GRID_FORCING_DIM > 0: - self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ds_stats["flux_mean"], - ds_stats["flux_std"], - ) - else: - self.data_mean, self.data_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ) - - print("Data subset of 200 samples starts on the", start_datetime) + print("Data subset of 
200 samples starts on the", start_datetime) # Separate 3D and 2D variables variables_3d = [ From 7715d2be8001de4a2853cf448c8699efb29dee23 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 10:27:35 +0200 Subject: [PATCH 04/13] unified naming with other stages/splits --- neural_lam/weather_dataset.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4dc0d856..6b19034e 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -45,14 +45,14 @@ def __init__( "val", "test", "pred", - "verification" + "verif" ), "Unknown dataset split" self.sample_dir_path = os.path.join( "data", dataset_name, "samples", split ) print(self.sample_dir_path) - if split == "verification": + if split == "verif": self.sample_dir_path = path_verif_file if os.path.exists(self.sample_dir_path): self.np_files = np.load(self.sample_dir_path) @@ -189,7 +189,7 @@ def __len__(self): def __getitem__(self, idx): - if self.split == "verification": + if self.split == "verif": return self.np_files else: @@ -255,7 +255,7 @@ def prepare_data(self): pass def setup(self, stage=None): - # make assignments here (val/train/test/predict split) + # make assignments here (val/train/test/pred split) # called on every process in DDP if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( @@ -282,18 +282,18 @@ def setup(self, stage=None): batch_size=self.batch_size, ) - if stage == "verification": - self.verification_dataset = WeatherDataset( + if stage == "verif": + self.verif_dataset = WeatherDataset( self.dataset_name, self.path_verif_file, - split="verification", + split="verif", standardize=False, subset=False, batch_size=self.batch_size, ) - if stage == "predict" or stage is None: - self.predict_dataset = WeatherDataset( + if stage == "pred" or stage is None: + self.pred_dataset = WeatherDataset( self.dataset_name, split="pred", standardize=self.standardize, @@ -328,18 +328,18 @@ def test_dataloader(self): pin_memory=False, ) - def predict_dataloader(self): + def pred_dataloader(self): return torch.utils.data.DataLoader( - self.predict_dataset, + self.pred_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, pin_memory=False, ) - def verification_dataloader(self): + def verif_dataloader(self): return torch.utils.data.DataLoader( - self.verification_dataset, + self.verif_dataset, batch_size=1, shuffle=False, pin_memory=False, From c6d1d7fbad65afcae434c22a43eaba4f4c698d57 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 12 Apr 2024 10:41:30 +0200 Subject: [PATCH 05/13] Linters --- neural_lam/vis.py | 69 +++++--- neural_lam/weather_dataset.py | 285 +++++++++++++++++----------------- verify.py | 4 + 3 files changed, 193 insertions(+), 165 deletions(-) create mode 100644 verify.py diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 49ef6647..8b150bf4 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,5 +1,7 @@ -# Third-party +# Standard library import os + +# Third-party import cartopy.feature as cf import matplotlib import matplotlib.pyplot as plt @@ -182,8 +184,11 @@ def plot_spatial_error(error, title=None, vrange=None): return fig + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def verify_inference(file_path:str, feature_channel:int, vrange=None, save_path=None): +def verify_inference( + file_path: str, feature_channel: int, vrange=None, save_path=None +): """ Plot example prediction, forecast, and ground 
truth. Each has shape (N_grid,) @@ -195,37 +200,44 @@ def verify_inference(file_path:str, feature_channel:int, vrange=None, save_path= # Load the inference dataset for plotting predictions_data_module = WeatherDataModule( - "cosmo", - path_verif_file=file_path, - split="verification", - standardize=False, - subset=False, - batch_size=6, - num_workers=2 - ) - predictions_data_module.setup(stage='verification') - predictions_loader = predictions_data_module.verification_dataloader() + "cosmo", + path_verif_file=file_path, + split="verif", + standardize=False, + subset=False, + batch_size=6, + num_workers=2, + ) + predictions_data_module.setup(stage="verif") + predictions_loader = predictions_data_module.verif_dataloader() for predictions_batch in predictions_loader: - predictions = predictions_batch[0] # tensor - break + predictions = predictions_batch[0] # tensor + break # get test data data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) lon, lat = unrotate_latlon(data_latlon) # Get common scale for values - total = predictions[0,:,:,feature_channel] + total = predictions[0, :, :, feature_channel] total_array = np.array(total) if vrange is None: vmin = total_array.min() vmax = total_array.max() else: - vmin, vmax = float(vrange[0].cpu().item()), float(vrange[1].cpu().item()) + vmin, vmax = float(vrange[0].cpu().item()), float( + vrange[1].cpu().item() + ) # Plot for i in range(23): - feature_array = predictions[0,i,:,feature_channel].reshape(*constants.GRID_SHAPE[::-1]).cpu().numpy() + feature_array = ( + predictions[0, i, :, feature_channel] + .reshape(*constants.GRID_SHAPE[::-1]) + .cpu() + .numpy() + ) data_array = np.array(feature_array) fig, axes = plt.subplots( @@ -253,18 +265,27 @@ def verify_inference(file_path:str, feature_channel:int, vrange=None, save_path= ) # Ticks and labels - axes.set_title("Predictions from model inference", size = 15) - axes.text(0.5, 1.05, f"Feature channel {feature_channel}, time step {i}", - ha='center', va='bottom', transform=axes.transAxes, fontsize=12 - ) + axes.set_title("Predictions from model inference", size=15) + axes.text( + 0.5, + 1.05, + f"Feature channel {feature_channel}, time step {i}", + ha="center", + va="bottom", + transform=axes.transAxes, + fontsize=12, + ) cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20) cbar.ax.tick_params(labelsize=10) - # Save the plot! + # Save the plot! 
if save_path: directory = os.path.dirname(save_path) if not os.path.exists(directory): os.makedirs(directory) - plt.savefig(save_path + f"_feature_channel_{feature_channel}_"+ f"{i}.png", bbox_inches='tight') + plt.savefig( + save_path + f"_feature_channel_{feature_channel}_" + f"{i}.png", + bbox_inches="tight", + ) - return fig \ No newline at end of file + return fig diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 6b19034e..c3de2efd 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -2,9 +2,9 @@ import glob import os from datetime import datetime, timedelta -import numpy as np # Third-party +import numpy as np import pytorch_lightning as pl import torch import xarray as xr @@ -45,73 +45,70 @@ def __init__( "val", "test", "pred", - "verif" + "verif", ), "Unknown dataset split" self.sample_dir_path = os.path.join( "data", dataset_name, "samples", split ) print(self.sample_dir_path) - if split == "verif": - self.sample_dir_path = path_verif_file - if os.path.exists(self.sample_dir_path): - self.np_files = np.load(self.sample_dir_path) - else: - self.sample_dir_path = os.path.join( - "data", dataset_name, "samples", split - ) - - self.batch_size = batch_size - self.batch_index = 0 - self.index_within_batch = 0 + self.batch_size = batch_size + self.batch_index = 0 + self.index_within_batch = 0 + self.sample_dir_path = path_verif_file - self.zarr_files = sorted( - glob.glob(os.path.join(self.sample_dir_path, "data*.zarr")) - ) - if len(self.zarr_files) == 0: - raise ValueError("No .zarr files found in directory") + if split == "verif" and os.path.exists(self.sample_dir_path): + self.np_files = np.load(self.sample_dir_path) + self.split = split + return - if subset: - if constants.EVAL_DATETIME is not None and split == "test": - eval_datetime_obj = datetime.strptime( - constants.EVAL_DATETIME, "%Y%m%d%H" + self.zarr_files = sorted( + glob.glob(os.path.join(self.sample_dir_path, "data*.zarr")) + ) + if len(self.zarr_files) == 0: + raise ValueError("No .zarr files found in directory") + + if subset: + if constants.EVAL_DATETIME is not None and split == "test": + eval_datetime_obj = datetime.strptime( + constants.EVAL_DATETIME, "%Y%m%d%H" + ) + for i, file in enumerate(self.zarr_files): + file_datetime_str = file.split("/")[-1].split("_")[1][:-5] + file_datetime_obj = datetime.strptime( + file_datetime_str, "%Y%m%d%H" ) - for i, file in enumerate(self.zarr_files): - file_datetime_str = file.split("/")[-1].split("_")[1][:-5] - file_datetime_obj = datetime.strptime( - file_datetime_str, "%Y%m%d%H" + if ( + file_datetime_obj + <= eval_datetime_obj + < file_datetime_obj + + timedelta(hours=constants.CHUNK_SIZE) + ): + # Retrieve the current file and the next file if it + # exists + next_file_index = i + 1 + if next_file_index < len(self.zarr_files): + self.zarr_files = [ + file, + self.zarr_files[next_file_index], + ] + else: + self.zarr_files = [file] + position_within_file = int( + ( + eval_datetime_obj - file_datetime_obj + ).total_seconds() + // 3600 ) - if ( - file_datetime_obj - <= eval_datetime_obj - < file_datetime_obj - + timedelta(hours=constants.CHUNK_SIZE) - ): - # Retrieve the current file and the next file if it - # exists - next_file_index = i + 1 - if next_file_index < len(self.zarr_files): - self.zarr_files = [ - file, - self.zarr_files[next_file_index], - ] - else: - self.zarr_files = [file] - position_within_file = int( - ( - eval_datetime_obj - file_datetime_obj - ).total_seconds() - // 3600 - ) - self.batch_index = ( 
- position_within_file // self.batch_size - ) - self.index_within_batch = ( - position_within_file % self.batch_size - ) - break - else: - self.zarr_files = self.zarr_files[0:2] + self.batch_index = ( + position_within_file // self.batch_size + ) + self.index_within_batch = ( + position_within_file % self.batch_size + ) + break + else: + self.zarr_files = self.zarr_files[0:2] start_datetime = ( self.zarr_files[0] @@ -119,60 +116,67 @@ def __init__( .split("_")[1] .replace(".zarr", "") ) + print("Data subset of 200 samples starts on the", start_datetime) - # Separate 3D and 2D variables - variables_3d = [ - var for var in constants.PARAM_NAMES_SHORT if constants.IS_3D[var] - ] - variables_2d = [ - var - for var in constants.PARAM_NAMES_SHORT - if not constants.IS_3D[var] - ] - - # Stack 3D variables - datasets_3d = [ - xr.open_zarr(file, consolidated=True)[variables_3d] - .sel(z_1=constants.VERTICAL_LEVELS) - .to_array() - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Stack 2D variables without selecting along z_1 - datasets_2d = [ - xr.open_zarr(file, consolidated=True)[variables_2d] - .to_array() - .pipe(lambda ds: ds if 'z_1' in ds.dims else ds.expand_dims(z_1=[0])) - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Combine 3D and 2D datasets - self.zarr_datasets = [ - xr.concat([ds_3d, ds_2d], dim="var").sortby("var") - for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) - ] - - self.standardize = standardize - if standardize: - ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - if constants.GRID_FORCING_DIM > 0: - self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ds_stats["flux_mean"], - ds_stats["flux_std"], - ) - else: - self.data_mean, self.data_std = ( - ds_stats["data_mean"], - ds_stats["data_std"], - ) + # Separate 3D and 2D variables + variables_3d = [ + var for var in constants.PARAM_NAMES_SHORT if constants.IS_3D[var] + ] + variables_2d = [ + var + for var in constants.PARAM_NAMES_SHORT + if not constants.IS_3D[var] + ] + + # Stack 3D variables + datasets_3d = [ + xr.open_zarr(file, consolidated=True)[variables_3d] + .sel(z_1=constants.VERTICAL_LEVELS) + .to_array() + .stack(var=("variable", "z_1")) + .transpose("time", "x_1", "y_1", "var") + for file in self.zarr_files + ] + + # Stack 2D variables without selecting along z_1 + datasets_2d = [ + xr.open_zarr(file, consolidated=True)[variables_2d] + .to_array() + .pipe( + lambda ds: (ds if "z_1" in ds.dims else ds.expand_dims(z_1=[0])) + ) + .stack(var=("variable", "z_1")) + .transpose("time", "x_1", "y_1", "var") + for file in self.zarr_files + ] + + # Combine 3D and 2D datasets + self.zarr_datasets = [ + xr.concat([ds_3d, ds_2d], dim="var").sortby("var") + for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) + ] + self.standardize = standardize + if standardize: + ds_stats = utils.load_dataset_stats(dataset_name, "cpu") + if constants.GRID_FORCING_DIM > 0: + ( + self.data_mean, + self.data_std, + self.flux_mean, + self.flux_std, + ) = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ds_stats["flux_mean"], + ds_stats["flux_std"], + ) + else: + self.data_mean, self.data_std = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ) self.random_subsample = split == "train" self.split = split @@ -188,45 +192,44 @@ def __len__(self): return total_time def __getitem__(self, idx): - - if self.split == "verif": + if self.split == 
"verif": return self.np_files - else: - num_steps = ( - constants.TRAIN_HORIZON - if self.split == "train" - else constants.EVAL_HORIZON - ) + num_steps = ( + constants.TRAIN_HORIZON + if self.split == "train" + else constants.EVAL_HORIZON + ) - # Calculate which zarr files need to be loaded - start_file_idx = idx // constants.CHUNK_SIZE - end_file_idx = (idx + num_steps) // constants.CHUNK_SIZE - # Index of current slice - idx_sample = idx % constants.CHUNK_SIZE + # Calculate which zarr files need to be loaded + start_file_idx = idx // constants.CHUNK_SIZE + end_file_idx = (idx + num_steps) // constants.CHUNK_SIZE + # Index of current slice + idx_sample = idx % constants.CHUNK_SIZE - sample_archive = xr.concat( - self.zarr_datasets[start_file_idx : end_file_idx + 1], dim="time" - ) + sample_archive = xr.concat( + self.zarr_datasets[start_file_idx : end_file_idx + 1], + dim="time", + ) - sample_xr = sample_archive.isel( - time=slice(idx_sample, idx_sample + num_steps) - ) + sample_xr = sample_archive.isel( + time=slice(idx_sample, idx_sample + num_steps) + ) - # (N_t', N_x, N_y, d_features') - sample = torch.tensor(sample_xr.values, dtype=torch.float32) + # (N_t', N_x, N_y, d_features') + sample = torch.tensor(sample_xr.values, dtype=torch.float32) - sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) + sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) - if self.standardize: - # Standardize sample - sample = (sample - self.data_mean) / self.data_std + if self.standardize: + # Standardize sample + sample = (sample - self.data_mean) / self.data_std - # Split up sample in init. states and target states - init_states = sample[:2] # (2, N_grid, d_features) - target_states = sample[2:] # (sample_length-2, N_grid, d_features) + # Split up sample in init. states and target states + init_states = sample[:2] # (2, N_grid, d_features) + target_states = sample[2:] # (sample_length-2, N_grid, d_features) - return init_states, target_states + return init_states, target_states class WeatherDataModule(pl.LightningDataModule): @@ -336,11 +339,11 @@ def pred_dataloader(self): shuffle=False, pin_memory=False, ) - - def verif_dataloader(self): + + def verif_dataloader(self): return torch.utils.data.DataLoader( self.verif_dataset, batch_size=1, - shuffle=False, + shuffle=False, pin_memory=False, ) diff --git a/verify.py b/verify.py new file mode 100644 index 00000000..c0e408e6 --- /dev/null +++ b/verify.py @@ -0,0 +1,4 @@ +from neural_lam.vis import verify_inference + + +verify_inference(file_path = "/users/clechart/clechart/neural-lam/wandb/run-20240411_140635-cux0r96n/files/results/inference/prediction_0.npy", feature_channel=24, save_path= "/users/clechart/clechart/neural-lam/neural_lam") \ No newline at end of file From 5ff7cb0178ba255e34f865480704574073e99a6a Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 12 Apr 2024 14:06:46 +0200 Subject: [PATCH 06/13] cli entry point --- README.md | 16 ++++++++++++++++ cli_plotting.py | 31 +++++++++++++++++++++++++++++++ neural_lam/vis.py | 10 +++++----- neural_lam/weather_dataset.py | 5 +++++ 4 files changed, 57 insertions(+), 5 deletions(-) create mode 100644 cli_plotting.py diff --git a/README.md b/README.md index 86630683..029614c4 100644 --- a/README.md +++ b/README.md @@ -258,6 +258,22 @@ Some options specifically important for evaluation are: **Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. 
If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example [this draft PR](https://github.com/Lightning-AI/torchmetrics/pull/1886) for more discussion and ongoing work to remedy this. +## Plot Model output +One can use the command-line tool `cli_plotting.py` to generate the plotting and verifying of inference results stored in `.npy` files. + +Arguments and options: + + * `--file_path`: The path to the .npy file that contains the inferred values. This argument is required. + * `--save_path`: The path where the output files will be saved. This argument is required. + * `--feature_channel` (Optional): Specifies the feature channel to use during the verification. Default is 0. + +Example Usage to run the plotting on a the 10th feature channel of a .npy file located at /data/results/inference_output.npy and save the output in /outputs:: + +``` +python cli_plotting.py --file_path /data/results/inference_output.npy --save_path /outputs --feature_channel 10 +``` + + # Repository Structure Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory. Model classes, including abstract base classes, are located in `neural_lam/models`. diff --git a/cli_plotting.py b/cli_plotting.py new file mode 100644 index 00000000..ff3d495f --- /dev/null +++ b/cli_plotting.py @@ -0,0 +1,31 @@ +# Third-party +import click + +# First-party +from neural_lam.vis import verify_inference + + +@click.command() +@click.argument("file_path", type=click.Path(exists=True), required=True) +@click.argument("save_path", type=click.Path(), required=True) +@click.option( + "--feature_channel", + "-f", + default=0, + help="Feature channel to use. Default is 0.", + type=int, + show_default=True, +) +def main(file_path: str, save_path: str, feature_channel: int) -> None: + """ + Command line tool for verifying neural_lam inference results. 
+ """ + verify_inference(file_path, feature_channel, save_path) + + +if __name__ == "__main__": + main( + file_path="/", + save_path="./", + feature_channel=0, + ) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 8b150bf4..54bc924e 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -194,10 +194,6 @@ def verify_inference( Each has shape (N_grid,) """ - # Verify that feature channel is between 0 and 42 - if not 0 <= feature_channel <= 42: - raise ValueError("feature_channel must be between 0 and 42, inclusive.") - # Load the inference dataset for plotting predictions_data_module = WeatherDataModule( "cosmo", @@ -214,6 +210,10 @@ def verify_inference( predictions = predictions_batch[0] # tensor break + # Verify that feature channel is between 0 and 42 + if not 0 <= feature_channel < predictions.shape[-1]: + raise ValueError("feature_channel must be between 0 and 42, inclusive.") + # get test data data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) lon, lat = unrotate_latlon(data_latlon) @@ -230,7 +230,7 @@ def verify_inference( ) # Plot - for i in range(23): + for i in range(constants.EVAL_HORIZON - 2): feature_array = ( predictions[0, i, :, feature_channel] diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index c3de2efd..4a7eb0e9 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -305,6 +305,7 @@ def setup(self, stage=None): ) def train_dataloader(self): + """Load train dataset.""" return torch.utils.data.DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -314,6 +315,7 @@ def train_dataloader(self): ) def val_dataloader(self): + """Load validation dataset.""" return torch.utils.data.DataLoader( self.val_dataset, batch_size=self.batch_size // self.batch_size, @@ -323,6 +325,7 @@ def val_dataloader(self): ) def test_dataloader(self): + """Load test dataset.""" return torch.utils.data.DataLoader( self.test_dataset, batch_size=self.batch_size, @@ -332,6 +335,7 @@ def test_dataloader(self): ) def pred_dataloader(self): + """Load prediction dataset.""" return torch.utils.data.DataLoader( self.pred_dataset, batch_size=self.batch_size, @@ -341,6 +345,7 @@ def pred_dataloader(self): ) def verif_dataloader(self): + """Load inference output dataset.""" return torch.utils.data.DataLoader( self.verif_dataset, batch_size=1, From eff6c5612cd62f61fdc5c2e7d1cad8c4f81d6bd1 Mon Sep 17 00:00:00 2001 From: Capucine Lechartre Date: Fri, 12 Apr 2024 14:12:57 +0200 Subject: [PATCH 07/13] remove useless file --- verify.py | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 verify.py diff --git a/verify.py b/verify.py deleted file mode 100644 index c0e408e6..00000000 --- a/verify.py +++ /dev/null @@ -1,4 +0,0 @@ -from neural_lam.vis import verify_inference - - -verify_inference(file_path = "/users/clechart/clechart/neural-lam/wandb/run-20240411_140635-cux0r96n/files/results/inference/prediction_0.npy", feature_channel=24, save_path= "/users/clechart/clechart/neural-lam/neural_lam") \ No newline at end of file From 9141dbecdbf9fcd1dec17dc63c2331491ca64194 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 15:20:46 +0200 Subject: [PATCH 08/13] Fixed issue with CLI paths --- README.md | 2 +- cli_plotting.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 029614c4..c1df8e3d 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,7 @@ Arguments and options: Example Usage to run the plotting on a the 10th feature channel of a .npy file located at 
/data/results/inference_output.npy and save the output in /outputs:: ``` -python cli_plotting.py --file_path /data/results/inference_output.npy --save_path /outputs --feature_channel 10 +python cli_plotting.py data/results/inference_output.npy outputs/ --feature_channel 10 ``` diff --git a/cli_plotting.py b/cli_plotting.py index ff3d495f..d6c7b0ef 100644 --- a/cli_plotting.py +++ b/cli_plotting.py @@ -24,8 +24,4 @@ def main(file_path: str, save_path: str, feature_channel: int) -> None: if __name__ == "__main__": - main( - file_path="/", - save_path="./", - feature_channel=0, - ) + main() From 5d51fb6a788b66bac0b3bf64e09659b329637b4d Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 15:39:10 +0200 Subject: [PATCH 09/13] Font "serif" not available on Balfrin? --- neural_lam/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 56a225da..28c151ae 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -256,7 +256,7 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=False, family="serif") + bundle = bundles.neurips2023(usetex=False, family="DejaVu Sans") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( From ecddf12893439c4c186d7e6855e0add711a4da7b Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 15:40:07 +0200 Subject: [PATCH 10/13] Some minor fixes - Arguments of function were in wrong order - No need to return the last plot but free up memory --- neural_lam/vis.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 54bc924e..6d5cc2a3 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -187,7 +187,7 @@ def plot_spatial_error(error, title=None, vrange=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def verify_inference( - file_path: str, feature_channel: int, vrange=None, save_path=None + file_path: str, feature_channel: int, save_path=None, vrange=None ): """ Plot example prediction, forecast, and ground truth. 
@@ -212,7 +212,8 @@ def verify_inference( # Verify that feature channel is between 0 and 42 if not 0 <= feature_channel < predictions.shape[-1]: - raise ValueError("feature_channel must be between 0 and 42, inclusive.") + raise ValueError( + f"feature_channel must be between 0 and {predictions.shape[-1]}, inclusive.") # get test data data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) @@ -287,5 +288,4 @@ def verify_inference( save_path + f"_feature_channel_{feature_channel}_" + f"{i}.png", bbox_inches="tight", ) - - return fig + plt.close() From 6bfb825dad7b56ae7999dbbc215671d33ffeab11 Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 16:10:26 +0200 Subject: [PATCH 11/13] fix linter --- cli_plotting.py | 2 +- neural_lam/vis.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/cli_plotting.py b/cli_plotting.py index d6c7b0ef..b916e919 100644 --- a/cli_plotting.py +++ b/cli_plotting.py @@ -24,4 +24,4 @@ def main(file_path: str, save_path: str, feature_channel: int) -> None: if __name__ == "__main__": - main() + main() # pylint: disable=no-value-for-parameter diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 6d5cc2a3..f10dde53 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -210,10 +210,12 @@ def verify_inference( predictions = predictions_batch[0] # tensor break - # Verify that feature channel is between 0 and 42 + # Verify that feature channel is within bounds if not 0 <= feature_channel < predictions.shape[-1]: raise ValueError( - f"feature_channel must be between 0 and {predictions.shape[-1]}, inclusive.") + f"feature_channel must be between 0 and " + f"{predictions.shape[-1]-1}, inclusive." + ) # get test data data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) From 237f0a0157238e9f9e52f97c1ddee01d17c5840c Mon Sep 17 00:00:00 2001 From: Simon Adamov Date: Fri, 12 Apr 2024 16:26:17 +0200 Subject: [PATCH 12/13] - added progress indication - slight changes to the save_path --- cli_plotting.py | 2 +- neural_lam/vis.py | 23 ++++++++++++----------- neural_lam/weather_dataset.py | 1 - 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cli_plotting.py b/cli_plotting.py index b916e919..61239cea 100644 --- a/cli_plotting.py +++ b/cli_plotting.py @@ -20,7 +20,7 @@ def main(file_path: str, save_path: str, feature_channel: int) -> None: """ Command line tool for verifying neural_lam inference results. """ - verify_inference(file_path, feature_channel, save_path) + verify_inference(file_path, save_path, feature_channel) if __name__ == "__main__": diff --git a/neural_lam/vis.py b/neural_lam/vis.py index f10dde53..04873cf6 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np import xarray as xr +from tqdm import tqdm # First-party from neural_lam import constants, utils @@ -187,7 +188,7 @@ def plot_spatial_error(error, title=None, vrange=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def verify_inference( - file_path: str, feature_channel: int, save_path=None, vrange=None + file_path: str, save_path: str, feature_channel: int, vrange=None ): """ Plot example prediction, forecast, and ground truth. 
@@ -233,8 +234,9 @@ def verify_inference( ) # Plot - for i in range(constants.EVAL_HORIZON - 2): - + for i in tqdm( + range(constants.EVAL_HORIZON - 2), desc="Plotting predictions" + ): feature_array = ( predictions[0, i, :, feature_channel] .reshape(*constants.GRID_SHAPE[::-1]) @@ -282,12 +284,11 @@ def verify_inference( cbar.ax.tick_params(labelsize=10) # Save the plot! - if save_path: - directory = os.path.dirname(save_path) - if not os.path.exists(directory): - os.makedirs(directory) - plt.savefig( - save_path + f"_feature_channel_{feature_channel}_" + f"{i}.png", - bbox_inches="tight", - ) + directory = os.path.dirname(save_path) + if not os.path.exists(directory): + os.makedirs(directory) + plt.savefig( + f"{save_path}feature_channel_{feature_channel}_{i}.png", + bbox_inches="tight", + ) plt.close() diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4a7eb0e9..adfd81f3 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -50,7 +50,6 @@ def __init__( self.sample_dir_path = os.path.join( "data", dataset_name, "samples", split ) - print(self.sample_dir_path) self.batch_size = batch_size self.batch_index = 0 From 15b62a0f9a7eb322400dd851c2e509fbf328b6c6 Mon Sep 17 00:00:00 2001 From: sadamov <45732287+sadamov@users.noreply.github.com> Date: Fri, 12 Apr 2024 16:27:51 +0200 Subject: [PATCH 13/13] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c1df8e3d..fe4f10f4 100644 --- a/README.md +++ b/README.md @@ -267,7 +267,7 @@ Arguments and options: * `--save_path`: The path where the output files will be saved. This argument is required. * `--feature_channel` (Optional): Specifies the feature channel to use during the verification. Default is 0. -Example Usage to run the plotting on a the 10th feature channel of a .npy file located at /data/results/inference_output.npy and save the output in /outputs:: +Example Usage to run the plotting on the 10th feature channel of a .npy file located at /data/results/inference_output.npy and save the output in /outputs:: ``` python cli_plotting.py data/results/inference_output.npy outputs/ --feature_channel 10
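
For completeness, the same verification plotting can also be driven from Python rather than through `cli_plotting.py`, using the `verify_inference` signature as it stands after PATCH 12. A minimal sketch, assuming a prediction file written during inference (the `.npy` path and output prefix below are placeholders, not paths from this PR):

```python
# Minimal usage sketch of the verification plotting added in this patch series.
# The input path and output prefix are placeholders; adjust them to your run.
from neural_lam.vis import verify_inference

verify_inference(
    file_path="results/inference/prediction_0.npy",  # hypothetical inference output
    save_path="outputs/",        # used as a filename prefix for the saved .png plots
    feature_channel=10,          # must be smaller than the number of feature channels
)
```

This is essentially what the removed `verify.py` script did, and what `cli_plotting.py` now wraps in a Click command-line interface.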