diff --git a/.gitignore b/.gitignore index 5ca89369..474e5649 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ sweeps test_*.sh lightning_logs .vscode +outputs ### Python ### # Byte-compiled / optimized / DLL files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f48eca67..a32ddc51 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,7 @@ repos: description: Check for spelling errors language: system entry: codespell + args: ['--ignore-words-list=laf'] - repo: local hooks: - id: black diff --git a/create_parameter_weights.py b/create_parameter_weights.py index f9cab328..5042c576 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -88,12 +88,12 @@ def main(): flux_squares = [] for batch_data in tqdm(loader): if constants.GRID_FORCING_DIM > 0: - init_batch, target_batch, forcing_batch = batch_data + init_batch, target_batch, _, forcing_batch = batch_data flux_batch = forcing_batch[:, :, :, 0] # Flux is first index flux_means.append(torch.mean(flux_batch)) # (,) flux_squares.append(torch.mean(flux_batch**2)) # (,) else: - init_batch, target_batch = batch_data + init_batch, target_batch, _ = batch_data batch = torch.cat( (init_batch, target_batch), dim=1 @@ -134,12 +134,12 @@ def main(): diff_squares = [] for batch_data in tqdm(loader_standard): if constants.GRID_FORCING_DIM > 0: - init_batch, target_batch, forcing_batch = batch_data + init_batch, target_batch, _, forcing_batch = batch_data flux_batch = forcing_batch[:, :, :, 0] # Flux is first index flux_means.append(torch.mean(flux_batch)) # (,) flux_squares.append(torch.mean(flux_batch**2)) # (,) else: - init_batch, target_batch = batch_data + init_batch, target_batch, _ = batch_data batch_diffs = init_batch[:, 1:] - target_batch # (N_batch', N_t-1, N_grid, d_features) diff --git a/create_single_zarr.py b/create_single_zarr.py deleted file mode 100644 index 929239c9..00000000 --- a/create_single_zarr.py +++ /dev/null @@ -1,109 +0,0 @@ -# Standard library -import argparse -import os -import re - -# Third-party -import numcodecs -import xarray as xr -from tqdm import tqdm - - -def create_single_zarr_archive(config: dict, is_test: bool) -> None: - """ - Create a single large Zarr archive for either test or train data. 
- """ - # Determine the path based on whether it's test or train data - zarr_path = os.path.join( - config["zarr_path"], "test" if is_test else "train" - ) - zarr_name = "test_data.zarr" if is_test else "train_data.zarr" - full_zarr_path = os.path.join(zarr_path, zarr_name) - - # Ensure the directory exists - os.makedirs(zarr_path, exist_ok=True) - - # Initialize an empty list to store datasets - datasets = [] - - # Loop through all files and process - for root, _, files in os.walk(config["data_path"]): - for file in tqdm(files, desc="Processing files"): - full_path = os.path.join(root, file) - match = config["filename_pattern"].match(file) - if not match: - continue - - # Open the dataset - data = xr.open_dataset( - full_path, - engine="netcdf4", - chunks={"time": 1}, # Chunk only along the time dimension - autoclose=True, - ).drop_vars("grid_mapping_1", errors="ignore") - - # Check if the data belongs to the test year - data_is_test = config["test_year"] in data.time.dt.year.values - - # If the current data matches the desired type (test/train) - if data_is_test == is_test: - datasets.append(data) - - # Combine all datasets along the time dimension - combined_data = xr.concat(datasets, dim="time") - - # Set optimal compression - for var in combined_data.variables: - combined_data[var].encoding = {"compressor": config["compressor"]} - - # Save the combined dataset to a Zarr archive - combined_data.to_zarr( - store=full_zarr_path, - mode="w", - consolidated=True, - ) - print(f"Created Zarr archive at {full_zarr_path}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Create Zarr archives for weather data." - ) - parser.add_argument( - "--data_path", - type=str, - default="/scratch/mch/sadamov/ml_v1/", - help="Path to the raw data", - ) - parser.add_argument( - "--zarr_path", - type=str, - default="data/cosmo/samples/", - help="Path to the zarr output", - ) - parser.add_argument("--test_year", type=int, default=2020) - parser.add_argument( - "--filename_regex", - type=str, - help="Filename regex", - default="(.*)_extr.nc", - ) - - args = parser.parse_args() - - data_config = { - "data_path": args.data_path, - "filename_regex": args.filename_regex, - "zarr_path": args.zarr_path, - "compressor": numcodecs.Blosc( - cname="lz4", clevel=7, shuffle=numcodecs.Blosc.SHUFFLE - ), - "test_year": args.test_year, - "filename_pattern": re.compile(args.filename_regex), - } - - # Create Zarr archive for test data - create_single_zarr_archive(data_config, is_test=True) - - # Create Zarr archive for train data - create_single_zarr_archive(data_config, is_test=False) diff --git a/create_static_features.py b/create_static_features.py index cbab9259..dd629ccd 100644 --- a/create_static_features.py +++ b/create_static_features.py @@ -1,4 +1,5 @@ # Standard library +import os from argparse import ArgumentParser # Third-party @@ -15,28 +16,28 @@ def main(): parser.add_argument( "--xdim", type=str, - default="x_1", - help="Name of the x-dimension in the dataset (default: x_1)", + default="x", + help="Name of the x-dimension in the dataset (default: x)", ) parser.add_argument( "--ydim", type=str, - default="y_1", - help="Name of the x-dimension in the dataset (default: y_1)", + default="y", + help="Name of the x-dimension in the dataset (default: y)", ) parser.add_argument( "--zdim", type=str, - default="z_1", - help="Name of the x-dimension in the dataset (default: z_1)", + default="z", + help="Name of the x-dimension in the dataset (default: z)", ) parser.add_argument( 
"--field_names", nargs="+", - default=["hsurf", "FI", "P0FL"], + default=["HSURF", "FI", "HFL"], help=( "Names of the fields to extract from the .nc file " - '(default: ["hsurf", "FI", "P0FL"])' + '(default: ["HSURF", "FI", "HFL"])' ), ) parser.add_argument( @@ -49,14 +50,12 @@ def main(): ), ) parser.add_argument( - "--outdir", + "--dataset", type=str, - default="data/cosmo/static/", - help=( - "Output directory for the static features " - "(default: data/cosmo/static/)" - ), + default="cosmo", + help=("Name of the dataset (default: cosmo)"), ) + args = parser.parse_args() ds = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) @@ -82,8 +81,10 @@ def main(): ) np_fields = np.concatenate(np_fields, axis=-1) # (N_x, N_y, N_fields) + outdir = os.path.join("data", args.dataset, "static/") + # Save the numpy array to a .npy file - np.save(args.outdir + "reference_geopotential_pressure.npy", np_fields) + np.save(outdir + "reference_geopotential_pressure.npy", np_fields) # Get the dimensions of the dataset dims = ds.sizes @@ -95,7 +96,7 @@ def main(): # Stack the 2D arrays into a 3D array with x and y as the first dimension grid_xy = np.stack((y_grid, x_grid)) - np.save(args.outdir + "nwp_xy.npy", grid_xy) # (2, N_x, N_y) + np.save(outdir + "nwp_xy.npy", grid_xy) # (2, N_x, N_y) # Create a mask with the same dimensions, initially set to False mask = np.full((dims[args.xdim], dims[args.ydim]), False) @@ -107,7 +108,7 @@ def main(): mask[:, -args.boundaries :] = True # right boundary # Save the numpy array to a .npy file - np.save(args.outdir + "border_mask", mask) # (N_x, N_y) + np.save(outdir + "border_mask", mask) # (N_x, N_y) if __name__ == "__main__": diff --git a/create_zarr_archive.py b/create_zarr_archive.py deleted file mode 100644 index 2a81f5b3..00000000 --- a/create_zarr_archive.py +++ /dev/null @@ -1,174 +0,0 @@ -# Standard library -import argparse -import glob -import os -import re -import shutil - -# Third-party -import numcodecs -import xarray as xr -from tqdm import tqdm - -# First-party -from neural_lam import constants - - -def append_or_create_zarr( - data_out: xr.Dataset, config: dict, zarr_name: str -) -> None: - """Append data to an existing Zarr archive or create a new one.""" - - if config["test_year"] in data_out.time.dt.year.values: - zarr_path = os.path.join(config["zarr_path"], "test", zarr_name) - else: - zarr_path = os.path.join(config["zarr_path"], "train", zarr_name) - - if os.path.exists(zarr_path): - data_out.to_zarr( - store=zarr_path, - mode="a", - consolidated=True, - append_dim="time", - ) - else: - data_out.to_zarr( - zarr_path, - mode="w", - consolidated=True, - ) - - -def load_data(config: dict) -> None: - """Load weather data from NetCDF files and store it in a Zarr archive.""" - - file_paths = [] - for root, _, files in os.walk(config["data_path"]): - for file in files: - full_path = os.path.join(root, file) - file_paths.append(full_path) - file_paths.sort() - - # Group file paths into chunks - file_groups = [ - file_paths[i : i + config["chunk_size"]] - for i in range(0, len(file_paths), config["chunk_size"]) - ] - - for group in tqdm(file_groups, desc="Processing file groups"): - # Create a new Zarr archive for each group Extract the date from the - # first file in the group - date = os.path.basename(group[0]).split("_")[0][3:] - zarr_name = f"data_{date}.zarr" - if not os.path.exists( - os.path.join(config["zarr_path"], "train", zarr_name) - ) and not os.path.exists( - os.path.join(config["zarr_path"], "test", zarr_name) - ): - for full_path in group: 
- process_file(full_path, config, zarr_name) - - -def process_file(full_path, config, zarr_name): - """Process a single NetCDF file and store it in a Zarr archive.""" - try: - # if zarr_name directory exists, skip - match = config["filename_pattern"].match(full_path) - if not match: - return None - data: xr.Dataset = xr.open_dataset( - full_path, - engine="netcdf4", - chunks={ - "time": 1, - "x_1": -1, - "y_1": -1, - "z_1": -1, - "zbound": -1, - }, - autoclose=True, - ).drop_vars("grid_mapping_1") - for var in data.variables: - data[var].encoding = {"compressor": config["compressor"]} - data.time.encoding = {"dtype": "float64"} - append_or_create_zarr(data, config, zarr_name) - # Display the progress - print(f"Processed: {full_path}") - except (FileNotFoundError, OSError) as e: - print(f"Error: {e}") - return None - - -def combine_zarr_archives(config) -> None: - """Combine the last Zarr archive from the train folder with the first from - the test folder.""" - - # Get the last Zarr archive from the train folder - train_archives = sorted( - glob.glob(os.path.join(config["zarr_path"], "train", "*.zarr")) - ) - - # Get the first Zarr archive from the test folder - test_archives = sorted( - glob.glob(os.path.join(config["zarr_path"], "test", "*.zarr")) - ) - first_test_archive = xr.open_zarr(test_archives[0], consolidated=True) - - val_archives_path = os.path.join(config["zarr_path"], "val") - - for t in range(first_test_archive.time.size): - first_test_archive.isel(time=slice(t, t + 1)).to_zarr( - train_archives[-1], mode="a", append_dim="time", consolidated=True - ) - - shutil.rmtree(test_archives[0]) - shutil.rmtree(test_archives[-1]) - - for file in test_archives[1:]: - filename = os.path.basename(file) - os.symlink(file, os.path.join(val_archives_path, filename)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Create a zarr archive.") - parser.add_argument( - "--data_path", - type=str, - required=True, - help="Path to the raw data", - default="/scratch/mch/sadamov/ml_v1/", - ) - parser.add_argument( - "--test_year", type=int, required=True, help="Test year", default=2020 - ) - parser.add_argument( - "--filename_regex", - type=str, - required=True, - help="Filename regex", - default="(.*)_extr.nc", - ) - - args = parser.parse_args() - - data_config = { - "data_path": args.data_path, - "filename_regex": args.filename_regex, - "zarr_path": ( - "/users/sadamov/pyprojects/" "neural-cosmo/data/cosmo/samples" - ), - "compressor": numcodecs.Blosc( - cname="lz4", clevel=7, shuffle=numcodecs.Blosc.SHUFFLE - ), - "chunk_size": constants.CHUNK_SIZE, - "test_year": args.test_year, - } - data_config.update( - { - "folders": os.listdir(data_config["data_path"]), - "filename_pattern": re.compile(data_config["filename_regex"]), - } - ) - - load_data(data_config) - combine_zarr_archives(data_config) diff --git a/environment.yml b/environment.yml index 912998d0..0866970c 100644 --- a/environment.yml +++ b/environment.yml @@ -6,6 +6,7 @@ channels: dependencies: - Cartopy - dask + - dask-jobqueue - imageio - ipython - matplotlib @@ -17,7 +18,7 @@ dependencies: - pyg - pyproj - pyprojroot - - pytorch + - pytorch=2.2.2=py3.12_cuda11.8_cudnn8.7.0_0 - pytorch-cuda=11.8 - pytorch-lightning - scikit-learn diff --git a/helper.py b/helper.py index 166212b0..159dfb4e 100644 --- a/helper.py +++ b/helper.py @@ -14,7 +14,7 @@ ds = xr.open_zarr(os.path.join(PATH, file)) ds_rechunked = ds.chunk({"time": -1}) - mean_tot_prec = ds_rechunked["TOT_PREC"].mean(dim=["y_1", "x_1"]).compute() + 
mean_tot_prec = ds_rechunked["TOT_PREC"].mean(dim=["y", "x"]).compute() # Find the maximum precipitation value and its corresponding time max_precip_value = mean_tot_prec.max().item() diff --git a/neural_lam/constants.py b/neural_lam/constants.py index 8da779a2..bb21971d 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -31,7 +31,7 @@ "Meridional wind component", "Relative humidity", "Pressure at Mean Sea Level", - "Pressure Perturbation", + "Pressure", "Surface Pressure", "Total Precipitation", "Total Water Vapor content", @@ -47,7 +47,7 @@ "V", "RELHUM", "PMSL", - "PP", + "P", "PS", "TOT_PREC", "TQV", @@ -63,7 +63,7 @@ "m/s", "Perc.", "Pa", - "hPa", + "Pa", "Pa", "$kg/m^2$", "$kg/m^2$", @@ -79,7 +79,7 @@ "V": 1, "RELHUM": 1, "PMSL": 1, - "PP": 1, + "P": 1, "PS": 1, "TOT_PREC": 1, "TQV": 1, @@ -89,6 +89,7 @@ } # Vertical levels +# BUG: This will change after sponges VERTICAL_LEVELS = [1, 5, 13, 22, 38, 41, 60] PARAM_CONSTRAINTS = { @@ -103,7 +104,7 @@ "V": 1, "RELHUM": 1, "PMSL": 0, - "PP": 1, + "P": 1, "PS": 0, "TOT_PREC": 0, "TQV": 0, @@ -139,13 +140,11 @@ # Plotting FIG_SIZE = (15, 10) -EXAMPLE_FILE = "data/cosmo/samples/train/data_2015112800.zarr" -CHUNK_SIZE = 100 -EVAL_DATETIME = "2020100215" +EXAMPLE_FILE = "data/cosmo/samples/train/data.zarr" +EVAL_DATETIMES = ["2015112800"] EVAL_PLOT_VARS = ["TQV"] STORE_EXAMPLE_DATA = False -COSMO_PROJ = ccrs.PlateCarree() -SELECTED_PROJ = COSMO_PROJ +SELECTED_PROJ = ccrs.PlateCarree() POLLON = -170.0 POLLAT = 43.0 SMOOTH_BOUNDARIES = False diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 7d49184c..8df5b5bf 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -122,8 +122,9 @@ def __init__(self, args): ) if var_name in constants.EVAL_PLOT_VARS ] - print("variable_indices", self.variable_indices) - print("selected_vars_units", self.selected_vars_units) + + utils.rank_zero_print("variable_indices", self.variable_indices) + utils.rank_zero_print("selected_vars_units", self.selected_vars_units) @pl.utilities.rank_zero_only def log_image(self, name, img): @@ -224,11 +225,12 @@ def single_prediction( """ raise NotImplementedError("No prediction step implemented") + # pylint: disable-next=unused-argument def predict_step(self, batch, batch_idx): """ Run the inference on batch. 
""" - prediction, target, pred_std = self.common_step(batch) + prediction, target, pred_std, _ = self.common_step(batch) # Compute all evaluation metrics for error maps # Note: explicitly list metrics here, as test_metrics can contain @@ -261,7 +263,7 @@ def predict_step(self, batch, batch_idx): # (B, N_log, num_grid_nodes) if self.trainer.global_rank == 0: - self.plot_examples(batch, batch_idx, prediction=prediction) + self.plot_examples(batch, prediction=prediction) self.inference_output.append(prediction) def unroll_prediction(self, init_states, forcing_features, true_states): @@ -328,8 +330,8 @@ def common_step(self, batch): forcing_features: (B, pred_steps, num_grid_nodes, d_forcing), where index 0 corresponds to index 1 of init_states """ - init_states, target_states = batch[:2] - forcing_features = batch[3] if len(batch) > 3 else None + init_states, target_states, batch_time = batch[:3] + forcing_features = batch[4] if len(batch) > 3 else None prediction, pred_std = self.unroll_prediction( init_states, forcing_features, target_states @@ -337,13 +339,13 @@ def common_step(self, batch): # prediction: (B, pred_steps, num_grid_nodes, d_f) # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) - return prediction, target_states, pred_std + return prediction, target_states, pred_std, batch_time def training_step(self, batch): """ Train on single batch """ - prediction, target, pred_std = self.common_step(batch) + prediction, target, pred_std, _ = self.common_step(batch) # Compute loss batch_loss = torch.mean( @@ -354,7 +356,12 @@ def training_step(self, batch): log_dict = {"train_loss": batch_loss} self.log_dict( - log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True + log_dict, + prog_bar=True, + on_step=True, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], ) return batch_loss @@ -375,7 +382,7 @@ def validation_step(self, batch, batch_idx): """ Run validation on single batch """ - prediction, target, pred_std = self.common_step(batch) + prediction, target, pred_std, _ = self.common_step(batch) time_step_loss = torch.mean( self.loss( @@ -392,7 +399,11 @@ def validation_step(self, batch, batch_idx): } val_log_dict["val_mean_loss"] = mean_loss self.log_dict( - val_log_dict, on_step=False, on_epoch=True, sync_dist=True + val_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], ) # Store MSEs @@ -421,7 +432,7 @@ def test_step(self, batch, batch_idx): """ Run test on single batch """ - prediction, target, pred_std = self.common_step(batch) + prediction, target, pred_std, batch_time = self.common_step(batch) # prediction: (B, pred_steps, num_grid_nodes, d_f) # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) @@ -441,7 +452,11 @@ def test_step(self, batch, batch_idx): test_log_dict["test_mean_loss"] = mean_loss self.log_dict( - test_log_dict, on_step=False, on_epoch=True, sync_dist=True + test_log_dict, + on_step=False, + on_epoch=True, + sync_dist=True, + batch_size=batch[0].shape[0], ) # Compute all evaluation metrics for error maps @@ -474,60 +489,39 @@ def test_step(self, batch, batch_idx): self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, num_grid_nodes) - # Plot example predictions (on rank 0 only) - self.plot_examples(batch, batch_idx, prediction=prediction) + if self.trainer.is_global_zero: + self.plot_examples( + batch, + prediction=prediction, + target=target, + batch_time=batch_time, + ) @rank_zero_only - def plot_examples(self, batch, batch_idx, prediction=None): + def plot_examples( 
+ self, batch, prediction=None, target=None, batch_time=None + ): """ Plot the first n_examples forecasts from batch Parameters: - - batch: Tuple containing data to plot corresponding forecasts for - - batch_idx: Index of the batch being processed - - prediction: Tensor of existing predictions. Generate if None. + - batch: batch with data to plot corresponding forecasts for + - target: (B, pred_steps, num_grid_nodes, d_f), ground truth; taken from the batch if None. + - batch_time: forecast init times (YYYYMMDDHH) of the batch samples + - prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. + Generate if None. The function checks for the presence of test_dataset or predict_dataset within the trainer's data module, handles indexing within the batch for targeted analysis, performs prediction rescaling, and plots results. """ - if prediction is None: - prediction, target = self.common_step(batch) - - target = batch[1] + if prediction is None or target is None or batch_time is None: + prediction, target, _, batch_time = self.common_step(batch) - # Determine the dataset to work with (test_dataset or predict_dataset) - dataset = None - if ( - hasattr(self.trainer.datamodule, "test_dataset") - and self.trainer.datamodule.test_dataset - ): - dataset = self.trainer.datamodule.test_dataset - plot_name = "test" - elif ( - hasattr(self.trainer.datamodule, "predict_dataset") - and self.trainer.datamodule.predict_dataset + if self.global_rank == 0 and any( + eval_datetime in batch_time + for eval_datetime in constants.EVAL_DATETIMES ): - dataset = self.trainer.datamodule.predict_dataset - plot_name = "prediction" - - if ( - dataset - and self.trainer.global_rank == 0 - and dataset.batch_index == batch_idx - ): - index_within_batch = dataset.index_within_batch - if not torch.is_tensor(index_within_batch): - index_within_batch = torch.tensor( - index_within_batch, - dtype=torch.int64, - device=prediction.device, - ) - - prediction = prediction[index_within_batch] - target = target[index_within_batch] - # Rescale to original data scale prediction_rescaled = prediction * self.data_std + self.data_mean prediction_rescaled = self.apply_constraints(prediction_rescaled) @@ -537,65 +531,68 @@ def plot_examples(self, batch, batch_idx, prediction=None): prediction_rescaled = self.smooth_prediction_borders( prediction_rescaled ) - # Each slice is (pred_steps, N_grid, d_f) Iterate over variables - for var_name, var_unit in self.selected_vars_units: - # Retrieve the indices for the current variable - var_indices = self.variable_indices[var_name] - for lvl_i, var_i in enumerate(var_indices): - # Calculate var_vrange for each index - lvl = constants.VERTICAL_LEVELS[lvl_i] - var_vmin = min( - prediction_rescaled[:, :, var_i].min(), - target_rescaled[:, :, var_i].min(), - ) - var_vmax = max( - prediction_rescaled[:, :, var_i].max(), - target_rescaled[:, :, var_i].max(), - ) - var_vrange = (var_vmin, var_vmax) - # Iterate over time steps - for t_i, (pred_t, target_t) in enumerate( - zip(prediction_rescaled, target_rescaled), start=1 - ): - eval_datetime_obj = datetime.strptime( - constants.EVAL_DATETIME, "%Y%m%d%H" - ) - current_datetime_obj = eval_datetime_obj + timedelta( - hours=t_i - ) - current_datetime_str = current_datetime_obj.strftime( - "%Y%m%d%H" - ) - title = ( - f"{var_name} ({var_unit}), t={current_datetime_str}" + for i, eval_datetime in enumerate(batch_time): + if eval_datetime not in constants.EVAL_DATETIMES: + continue + pred_rescaled = prediction_rescaled[i] + targ_rescaled = target_rescaled[i] + + for var_name, var_unit in self.selected_vars_units: + var_indices = self.variable_indices[var_name] + for 
lvl_i, var_i in enumerate(var_indices): + lvl = constants.VERTICAL_LEVELS[lvl_i] + var_vmin = min( + pred_rescaled[:, var_i].min(), + targ_rescaled[:, var_i].min(), ) - var_fig = vis.plot_prediction( - pred_t[:, var_i], - target_t[:, var_i], - title=title, - vrange=var_vrange, + var_vmax = max( + pred_rescaled[:, var_i].max(), + targ_rescaled[:, var_i].max(), ) - wandb.log( - { - f"{var_name}_{plot_name}_lvl_{lvl:02}" - f"_t_{current_datetime_str}": wandb.Image( - var_fig - ) - } - ) - plt.close("all") - - if constants.STORE_EXAMPLE_DATA: - # Save pred and target as .pt files - torch.save( - prediction_rescaled.cpu(), - os.path.join(wandb.run.dir, "example_pred.pt"), - ) - torch.save( - target_rescaled.cpu(), - os.path.join(wandb.run.dir, "example_target.pt"), - ) + var_vrange = (var_vmin, var_vmax) + + for t_i, (pred_t, target_t) in enumerate( + zip(pred_rescaled, targ_rescaled), start=1 + ): + print(f"Plotting {var_name} lvl {lvl_i} t {t_i}...") + current_datetime_str = ( + datetime.strptime(eval_datetime, "%Y%m%d%H") + + timedelta(hours=t_i) + ).strftime("%Y%m%d%H") + title = ( + f"{var_name} ({var_unit}), " + f"t={current_datetime_str}" + ) + var_fig = vis.plot_prediction( + pred_t[:, var_i], + target_t[:, var_i], + title=title, + vrange=var_vrange, + ) + wandb.log( + { + f"{var_name}_lvl_{lvl:02}_t_" + f"{current_datetime_str}": wandb.Image( + var_fig + ) + } + ) + plt.close("all") + + if constants.STORE_EXAMPLE_DATA: + torch.save( + pred_rescaled.cpu(), + os.path.join( + wandb.run.dir, f"example_pred_{eval_datetime}.pt" + ), + ) + torch.save( + targ_rescaled.cpu(), + os.path.join( + wandb.run.dir, f"example_target_{eval_datetime}.pt" + ), + ) @rank_zero_only def smooth_prediction_borders(self, prediction_rescaled): diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index a71b0c6b..5a280a1d 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -29,7 +29,7 @@ def __init__(self, args): # Specify dimensions of data self.num_mesh_nodes, _ = self.get_num_mesh() - print( + utils.rank_zero_print( f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} " f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" ) diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index f767fba0..2c2bb149 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -25,7 +25,7 @@ def __init__(self, args): # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] m2m_edges, m2m_dim = self.m2m_features.shape - print( + utils.rank_zero_print( f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " f"m2g={self.m2g_edges}" ) diff --git a/neural_lam/rotate_grid.py b/neural_lam/rotate_grid.py deleted file mode 100644 index 23294486..00000000 --- a/neural_lam/rotate_grid.py +++ /dev/null @@ -1,94 +0,0 @@ -"""unrotate rotated pole coordinates to geographical lat/lon""" - -# Third-party -import numpy as np - -# First-party -from neural_lam import constants - - -def unrot_lon(rotlon, rotlat, pollon, pollat): - """Transform rotated longitude to longitude. 
- - Parameters - ---------- - rotlon : np.ndarray(i,j) - rotated longitude (deg) - rotlat : np.ndarray(i,j) - rotated latitude (deg) - pollon : float - rotated pole longitude (deg) - pollat : float - rotated pole latitude (deg) - - Returns - ------- - lon : np.ndarray(i,j) - geographical longitude - - """ - - # to radians - rlo = np.radians(rotlon) - rla = np.radians(rotlat) - - # sin and cos of pole position - s1 = np.sin(np.radians(pollat)) - c1 = np.cos(np.radians(pollat)) - s2 = np.sin(np.radians(pollon)) - c2 = np.cos(np.radians(pollon)) - - # subresults - tmp1 = s2 * ( - -s1 * np.cos(rlo) * np.cos(rla) + c1 * np.sin(rla) - ) - c2 * np.sin(rlo) * np.cos(rla) - tmp2 = c2 * ( - -s1 * np.cos(rlo) * np.cos(rla) + c1 * np.sin(rla) - ) + s2 * np.sin(rlo) * np.cos(rla) - - return np.degrees(np.arctan(tmp1 / tmp2)) - - -def unrot_lat(rotlat, rotlon, pollat): - """Transform rotated latitude to latitude. - - Parameters - ---------- - rotlat : np.ndarray(i,j) - rotated latitude (deg) - rotlon : np.ndarray(i,j) - rotated longitude (deg) - pollon : float - rotated pole longitude (deg) - pollat : float - rotated pole latitude (deg) - - Returns - ------- - lat : np.ndarray(i,j) - geographical latitude - - """ - - # to radians - rlo = np.radians(rotlon) - rla = np.radians(rotlat) - - # sin and cos of pole position - s1 = np.sin(np.radians(pollat)) - c1 = np.cos(np.radians(pollat)) - - # subresults - tmp1 = s1 * np.sin(rla) + c1 * np.cos(rla) * np.cos(rlo) - - return np.degrees(np.arcsin(tmp1)) - - -def unrotate_latlon(data): - """Unrotate lat/lon coordinates from rotated pole grid.""" - xx, yy = np.meshgrid(data.x_1.values, data.y_1.values) - # unrotate lon/lat - lon = unrot_lon(xx, yy, constants.POLLON, constants.POLLAT) - lat = unrot_lat(yy, xx, constants.POLLAT) - - return lon.T, lat.T diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 28c151ae..5bfc35a8 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,9 +1,12 @@ # Standard library import os +import time # Third-party import numpy as np +import pytorch_lightning as pl import torch +import wandb from pytorch_lightning.utilities import rank_zero_only from torch import nn from tueplots import bundles, figsizes @@ -275,3 +278,51 @@ def init_wandb_metrics(wandb_logger): experiment.define_metric("val_mean_loss", summary="min") for step in constants.VAL_STEP_LOG_ERRORS: experiment.define_metric(f"val_loss_unroll{step}", summary="min") + + +@rank_zero_only +def rank_zero_print(*args, **kwargs): + """Print only from rank 0 process""" + print(*args, **kwargs) + + +@rank_zero_only +def init_wandb(args): + """Initialize wandb""" + if args.resume_run is None: + prefix = "subset-" if args.subset_ds else "" + if args.eval: + prefix = prefix + f"eval-{args.eval}-" + run_name = ( + f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" + f"{time.strftime('%m_%d_%H_%M_%S')}" + ) + wandb.init( + name=run_name, + project=constants.WANDB_PROJECT, + config=args, + ) + logger = pl.loggers.WandbLogger( + project=constants.WANDB_PROJECT, + name=run_name, + config=args, + log_model=True, + ) + wandb.save("slurm_train.sh") + wandb.save("slurm_predict.sh") + wandb.save("neural_lam/constants.py") + else: + wandb.init( + project=constants.WANDB_PROJECT, + config=args, + id=args.resume_run, + resume="must", + ) + logger = pl.loggers.WandbLogger( + project=constants.WANDB_PROJECT, + id=args.resume_run, + config=args, + log_model=True, + ) + + return logger diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 04873cf6..509a5f58 100644 --- 
a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -11,7 +11,6 @@ # First-party from neural_lam import constants, utils -from neural_lam.rotate_grid import unrotate_latlon from neural_lam.weather_dataset import WeatherDataModule @@ -62,14 +61,12 @@ def plot_error_map(errors, global_mean, step_length=1, title=None): y_ticklabels = [ ( f"{name if name != 'RELHUM' else 'RH'} ({unit}) " - f"{f'{level:02}' if constants.IS_3D[name] else ''}" + f"{f'{z:02}' if constants.IS_3D[name] else ''}" ) for name, unit in zip( constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS ) - for level in ( - constants.VERTICAL_LEVELS if constants.IS_3D[name] else [0] - ) + for z in (constants.VERTICAL_LEVELS if constants.IS_3D[name] else [0]) ] y_ticklabels = sorted(y_ticklabels) ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -94,8 +91,10 @@ def plot_prediction(pred, target, title=None, vrange=None): vmin, vmax = vrange[0].cpu().item(), vrange[1].cpu().item() # get test data - data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) - lon, lat = unrotate_latlon(data_latlon) + data_latlon = xr.open_zarr(constants.EXAMPLE_FILE, consolidated=True).isel( + time=0 + ) + lon, lat = data_latlon.lon.values.T, data_latlon.lat.values.T fig, axes = plt.subplots( 2, @@ -151,7 +150,7 @@ def plot_spatial_error(error, title=None, vrange=None): # get test data data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) - lon, lat = unrotate_latlon(data_latlon) + lon, lat = data_latlon.lon.values.T, data_latlon.lat.values.T fig, ax = plt.subplots( figsize=constants.FIG_SIZE, @@ -191,7 +190,7 @@ def verify_inference( file_path: str, save_path: str, feature_channel: int, vrange=None ): """ - Plot example prediction, forecast, and ground truth. + Plot example prediction, verification, and ground truth. 
Each has shape (N_grid,) """ @@ -199,7 +198,6 @@ def verify_inference( predictions_data_module = WeatherDataModule( "cosmo", path_verif_file=file_path, - split="verif", standardize=False, subset=False, batch_size=6, @@ -220,7 +218,7 @@ def verify_inference( # get test data data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) - lon, lat = unrotate_latlon(data_latlon) + lon, lat = data_latlon.lon.values.T, data_latlon.lat.values.T # Get common scale for values total = predictions[0, :, :, feature_channel] diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index adfd81f3..ed88db92 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,7 +1,7 @@ # Standard library -import glob import os from datetime import datetime, timedelta +from random import randint # Third-party import numpy as np @@ -10,24 +10,13 @@ import xarray as xr # First-party -# BUG: Import should work in interactive mode as well -> create pypi package from neural_lam import constants, utils -# pylint: disable=W0613:unused-argument -# pylint: disable=W0201:attribute-defined-outside-init - class WeatherDataset(torch.utils.data.Dataset): - """ - N_t = 1h - N_x = 582 - N_y = 390 - N_grid = 582*390 = 226980 - d_features = 4(features) * 21(vertical model levels) = 84 - d_forcing = 0 - #TODO: extract incoming radiation from KENDA - """ + """Weather dataset for PyTorch Lightning.""" + # pylint: disable=too-many-branches def __init__( self, dataset_name, @@ -44,117 +33,64 @@ def __init__( "train", "val", "test", - "pred", + "predict", "verif", ), "Unknown dataset split" - self.sample_dir_path = os.path.join( - "data", dataset_name, "samples", split - ) - self.batch_size = batch_size - self.batch_index = 0 - self.index_within_batch = 0 - self.sample_dir_path = path_verif_file - - if split == "verif" and os.path.exists(self.sample_dir_path): - self.np_files = np.load(self.sample_dir_path) + if split == "verif": + self.np_files = np.load(path_verif_file) self.split = split return - self.zarr_files = sorted( - glob.glob(os.path.join(self.sample_dir_path, "data*.zarr")) + self.zarr_path = os.path.join( + "data", dataset_name, "samples", split, "data.zarr" + ) + self.ds = xr.open_zarr(self.zarr_path, consolidated=True) + if split == "train": + self.ds = self.ds.sel(time=slice("2015", "2019")) + else: + # BUG: Clean this up after zarr archive is fixed + self.ds = self.ds.sel(time=slice("2015", "2020")) + + new_vars = {} + for var_name, data_array in self.ds.data_vars.items(): + if var_name in constants.PARAM_NAMES_SHORT: + if constants.IS_3D[var_name]: + for z in constants.VERTICAL_LEVELS: + new_key = f"{var_name}_{int(z)}" + new_vars[new_key] = data_array.sel(z=z).drop_vars("z") + # BUG: Clean this up after zarr archive is fixed + elif var_name == "T_2M": + new_vars[var_name] = data_array.sel(z=2).drop_vars("z") + elif var_name in ["U_10M", "V_10M"]: + new_vars[var_name] = data_array.sel(z=10).drop_vars("z") + elif var_name == "PMSL": + new_vars[var_name] = data_array.sel(z=0).drop_vars("z") + else: + new_vars[var_name] = data_array + + self.ds = ( + xr.Dataset(new_vars) + # BUG: This should not be necessary with clean data without nans + .drop_isel(time=848) + .to_array() + .transpose("time", "x", "y", "variable") ) - if len(self.zarr_files) == 0: - raise ValueError("No .zarr files found in directory") if subset: - if constants.EVAL_DATETIME is not None and split == "test": + if constants.EVAL_DATETIMES is not None and split == "test": eval_datetime_obj = datetime.strptime( - 
constants.EVAL_DATETIME, "%Y%m%d%H" + constants.EVAL_DATETIMES[0], "%Y%m%d%H" ) - for i, file in enumerate(self.zarr_files): - file_datetime_str = file.split("/")[-1].split("_")[1][:-5] - file_datetime_obj = datetime.strptime( - file_datetime_str, "%Y%m%d%H" + self.ds = self.ds.sel( + time=slice( + eval_datetime_obj, + eval_datetime_obj + timedelta(hours=50), ) - if ( - file_datetime_obj - <= eval_datetime_obj - < file_datetime_obj - + timedelta(hours=constants.CHUNK_SIZE) - ): - # Retrieve the current file and the next file if it - # exists - next_file_index = i + 1 - if next_file_index < len(self.zarr_files): - self.zarr_files = [ - file, - self.zarr_files[next_file_index], - ] - else: - self.zarr_files = [file] - position_within_file = int( - ( - eval_datetime_obj - file_datetime_obj - ).total_seconds() - // 3600 - ) - self.batch_index = ( - position_within_file // self.batch_size - ) - self.index_within_batch = ( - position_within_file % self.batch_size - ) - break + ) else: - self.zarr_files = self.zarr_files[0:2] - - start_datetime = ( - self.zarr_files[0] - .split("/")[-1] - .split("_")[1] - .replace(".zarr", "") - ) - - print("Data subset of 200 samples starts on the", start_datetime) - - # Separate 3D and 2D variables - variables_3d = [ - var for var in constants.PARAM_NAMES_SHORT if constants.IS_3D[var] - ] - variables_2d = [ - var - for var in constants.PARAM_NAMES_SHORT - if not constants.IS_3D[var] - ] - - # Stack 3D variables - datasets_3d = [ - xr.open_zarr(file, consolidated=True)[variables_3d] - .sel(z_1=constants.VERTICAL_LEVELS) - .to_array() - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Stack 2D variables without selecting along z_1 - datasets_2d = [ - xr.open_zarr(file, consolidated=True)[variables_2d] - .to_array() - .pipe( - lambda ds: (ds if "z_1" in ds.dims else ds.expand_dims(z_1=[0])) - ) - .stack(var=("variable", "z_1")) - .transpose("time", "x_1", "y_1", "var") - for file in self.zarr_files - ] - - # Combine 3D and 2D datasets - self.zarr_datasets = [ - xr.concat([ds_3d, ds_2d], dim="var").sortby("var") - for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) - ] + start_idx = randint(0, self.ds.time.size - 50) + self.ds = self.ds.isel(time=slice(start_idx, start_idx + 50)) self.standardize = standardize if standardize: @@ -178,57 +114,38 @@ def __init__( ) self.random_subsample = split == "train" self.split = split - - def __len__(self): - num_steps = ( + self.num_steps = ( constants.TRAIN_HORIZON if self.split == "train" else constants.EVAL_HORIZON ) - total_time = 1 - if hasattr(self, "zarr_files"): - total_time = len(self.zarr_files) * constants.CHUNK_SIZE - num_steps - return total_time + self.batch_size = batch_size + self.control_only = control_only + + def __len__(self): + if self.split == "verif": + return len(self.np_files) + return len(self.ds.time) - self.num_steps def __getitem__(self, idx): if self.split == "verif": return self.np_files - - num_steps = ( - constants.TRAIN_HORIZON - if self.split == "train" - else constants.EVAL_HORIZON - ) - - # Calculate which zarr files need to be loaded - start_file_idx = idx // constants.CHUNK_SIZE - end_file_idx = (idx + num_steps) // constants.CHUNK_SIZE - # Index of current slice - idx_sample = idx % constants.CHUNK_SIZE - - sample_archive = xr.concat( - self.zarr_datasets[start_file_idx : end_file_idx + 1], - dim="time", - ) - - sample_xr = sample_archive.isel( - time=slice(idx_sample, idx_sample + num_steps) - ) + sample_xr = 
self.ds.isel(time=slice(idx, idx + self.num_steps)) # (N_t', N_x, N_y, d_features') sample = torch.tensor(sample_xr.values, dtype=torch.float32) - sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) if self.standardize: - # Standardize sample sample = (sample - self.data_mean) / self.data_std - # Split up sample in init. states and target states init_states = sample[:2] # (2, N_grid, d_features) target_states = sample[2:] # (sample_length-2, N_grid, d_features) - return init_states, target_states + batch_time = self.ds.isel(time=idx).time.values + batch_time = np.datetime_as_string(batch_time, unit="h") + batch_time = str(batch_time).replace("-", "").replace("T", "") + return init_states, target_states, batch_time class WeatherDataModule(pl.LightningDataModule): @@ -237,7 +154,6 @@ class WeatherDataModule(pl.LightningDataModule): def __init__( self, dataset_name, - split="train", path_verif_file=None, standardize=True, subset=False, @@ -251,24 +167,23 @@ def __init__( self.num_workers = num_workers self.standardize = standardize self.subset = subset - - def prepare_data(self): - # download, split, etc... called only on 1 GPU/TPU in distributed - pass + self.train_dataset = None + self.val_dataset = None + self.test_dataset = None + self.verif_dataset = None + self.predict_dataset = None def setup(self, stage=None): - # make assignments here (val/train/test/pred split) - # called on every process in DDP if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( - self.dataset_name, + dataset_name=self.dataset_name, split="train", standardize=self.standardize, subset=self.subset, batch_size=self.batch_size, ) self.val_dataset = WeatherDataset( - self.dataset_name, + dataset_name=self.dataset_name, split="val", standardize=self.standardize, subset=self.subset, @@ -277,7 +192,7 @@ def setup(self, stage=None): if stage == "test" or stage is None: self.test_dataset = WeatherDataset( - self.dataset_name, + dataset_name=self.dataset_name, split="test", standardize=self.standardize, subset=self.subset, @@ -294,10 +209,10 @@ def setup(self, stage=None): batch_size=self.batch_size, ) - if stage == "pred" or stage is None: - self.pred_dataset = WeatherDataset( + if stage == "predict" or stage is None: + self.predict_dataset = WeatherDataset( self.dataset_name, - split="pred", + split="predict", standardize=self.standardize, subset=False, batch_size=1, @@ -317,7 +232,7 @@ def val_dataloader(self): """Load validation dataset.""" return torch.utils.data.DataLoader( self.val_dataset, - batch_size=self.batch_size // self.batch_size, + batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, pin_memory=False, @@ -333,10 +248,10 @@ def test_dataloader(self): pin_memory=False, ) - def pred_dataloader(self): + def predict_dataloader(self): """Load prediction dataset.""" return torch.utils.data.DataLoader( - self.pred_dataset, + self.predict_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, diff --git a/requirements.txt b/requirements.txt index f99002c2..a3af0d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,8 @@ numpy>=1.24.2 wandb>=0.13.10 matplotlib>=3.7.0 +dask +dask_jobqueue scipy>=1.10.0 pytorch-lightning>=2.0.3 shapely>=2.0.1 diff --git a/slurm_eval.sh b/slurm_eval.sh index 013464ef..1bcef8ed 100644 --- a/slurm_eval.sh +++ b/slurm_eval.sh @@ -1,13 +1,13 @@ #!/bin/bash -l #SBATCH --job-name=NeurWPe +#SBATCH --account=s83 +#SBATCH --partition=normal #SBATCH --nodes=1 #SBATCH --ntasks-per-node=4 -#SBATCH --partition=normal 
-#SBATCH --account=s83 +#SBATCH --time=00:59:00 +#SBATCH --no-requeue #SBATCH --output=lightning_logs/neurwp_eval_out.log #SBATCH --error=lightning_logs/neurwp_eval_err.log -#SBATCH --time=03:00:00 -#SBATCH --no-requeue export PREPROCESS=false export NORMALIZE=false diff --git a/slurm_predict.sh b/slurm_predict.sh index 98d999cf..e50df9ec 100644 --- a/slurm_predict.sh +++ b/slurm_predict.sh @@ -1,16 +1,15 @@ #!/bin/bash -l -#SBATCH --job-name=NeurWPredict +#SBATCH --job-name=NeurWPp +#SBATCH --account=s83 +#SBATCH --partition=normal #SBATCH --nodes=1 #SBATCH --ntasks-per-node=4 -#SBATCH --partition=normal -#SBATCH --account=s83 +#SBATCH --time=00:59:00 +#SBATCH --no-requeue #SBATCH --output=lightning_logs/neurwp_pred_out.log #SBATCH --error=lightning_logs/neurwp_pred_err.log -#SBATCH --time=03:00:00 -#SBATCH --no-requeue - -export PREPROCESS=true +export PREPROCESS=false export NORMALIZE=false if [ "$PREPROCESS" = true ]; then @@ -29,7 +28,6 @@ fi # Load necessary modules conda activate neural-lam - ulimit -c 0 export OMP_NUM_THREADS=16 diff --git a/slurm_train.sh b/slurm_train.sh index 2a846b27..1278d2c3 100644 --- a/slurm_train.sh +++ b/slurm_train.sh @@ -1,13 +1,13 @@ #!/bin/bash -l #SBATCH --job-name=NeurWP +#SBATCH --account=s83 +#SBATCH --partition=normal #SBATCH --nodes=1 #SBATCH --ntasks-per-node=4 -#SBATCH --partition=normal -#SBATCH --account=s83 -#SBATCH --output=lightning_logs/neurwp_out.log -#SBATCH --error=lightning_logs/neurwp_err.log #SBATCH --mem=400G #SBATCH --no-requeue +#SBATCH --output=lightning_logs/neurwp_out.log +#SBATCH --error=lightning_logs/neurwp_err.log export PREPROCESS=false export NORMALIZE=false @@ -17,7 +17,7 @@ conda activate neural-lam if [ "$PREPROCESS" = true ]; then echo "Create static features" - srun -ul -N1 -n1 python create_static_features.py --boundaries 60 + srun -ul -N1 -n1 python create_static_features.py --boundaries 60 --dataset "cosmo" echo "Creating mesh" srun -ul -N1 -n1 python create_mesh.py --dataset "cosmo" --plot 1 echo "Creating grid features" diff --git a/train_model.py b/train_model.py index eaf08288..9f597fec 100644 --- a/train_model.py +++ b/train_model.py @@ -1,17 +1,14 @@ # Standard library import os -import time from argparse import ArgumentParser # Third-party import pytorch_lightning as pl import torch -import wandb from lightning_fabric.utilities import seed -from pytorch_lightning.utilities import rank_zero_only # First-party -from neural_lam import constants, utils +from neural_lam import utils from neural_lam.models.base_graph_model import BaseGraphModel from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM @@ -26,62 +23,6 @@ } -@rank_zero_only -def print_args(args): - """Print arguments""" - print("Arguments:") - for arg in vars(args): - print(f"{arg}: {getattr(args, arg)}") - - -@rank_zero_only -def print_eval(args_eval): - """Print evaluation""" - print(f"Running evaluation on {args_eval}") - - -@rank_zero_only -def init_wandb(args): - """Initialize wandb""" - if args.resume_run is None: - prefix = "subset-" if args.subset_ds else "" - if args.eval: - prefix = prefix + f"eval-{args.eval}-" - run_name = ( - f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" - f"{time.strftime('%m_%d_%H_%M_%S')}" - ) - wandb.init( - name=run_name, - project=constants.WANDB_PROJECT, - config=args, - ) - logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, - name=run_name, - config=args, - log_model=True, - ) - wandb.save("slurm_train.sh") - 
wandb.save("slurm_predict.sh") - wandb.save("neural_lam/constants.py") - else: - wandb.init( - project=constants.WANDB_PROJECT, - config=args, - id=args.resume_run, - resume="must", - ) - logger = pl.loggers.WandbLogger( - project=constants.WANDB_PROJECT, - id=args.resume_run, - config=args, - log_model=True, - ) - - return logger - - def main(): # pylint: disable=too-many-branches """ @@ -263,7 +204,6 @@ def main(): # Set seed seed.seed_everything(args.seed) - # Create datamodule data_module = WeatherDataModule( args.dataset, @@ -282,7 +222,7 @@ def main(): model_class = MODELS[args.model] model = model_class(args) - result = init_wandb(args) + result = utils.init_wandb(args) if result is not None: logger = result @@ -305,10 +245,15 @@ def main(): use_distributed_sampler = False else: use_distributed_sampler = True + utils.rank_zero_print("Arguments:") + for arg in vars(args): + utils.rank_zero_print(f"{arg}: {getattr(args, arg)}") if torch.cuda.is_available(): accelerator = "cuda" - devices = torch.cuda.device_count() + devices = int( + os.environ.get("SLURM_GPUS_PER_NODE", torch.cuda.device_count()) + ) num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1)) else: accelerator = "cpu" @@ -347,7 +292,7 @@ def main(): # Check if the mode is prediction elif args.eval == "predict": - data_module.split = "pred" + data_module.split = "predict" trainer.predict( model=model, datamodule=data_module, diff --git a/wandb/example.ckpt b/wandb/example.ckpt index d649e45b..3247b089 100644 Binary files a/wandb/example.ckpt and b/wandb/example.ckpt differ