Bugfixes before large training
Fluxes should have size num_steps - 2
batch_times now covers every timestep of the sample, not just one
TQV is not part of the current variable selection
Simon Adamov committed Apr 23, 2024
1 parent 78c1eff commit f060f4d
Showing 6 changed files with 118 additions and 79 deletions.
11 changes: 6 additions & 5 deletions neural_lam/constants.py
@@ -122,7 +122,7 @@
PARAM_CONSTRAINTS = {
"RELHUM": (0, 100),
"CLCT": (0, 100),
"TQV": (0, None),
# "TQV": (0, None),
"TOT_PREC": (0, None),
}
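For context, a minimal sketch of how these (min, max) bounds might be enforced on predictions. The real implementation is apply_constraints in ar_model.py; the channel mapping below is a hypothetical stand-in:

import torch

# Hypothetical channel indices; the model derives these from its variable
# ordering in constants.py.
var_index = {"RELHUM": 0, "CLCT": 1, "TOT_PREC": 2}
param_constraints = {"RELHUM": (0, 100), "CLCT": (0, 100), "TOT_PREC": (0, None)}

def apply_constraints_sketch(prediction):
    # Clamp each constrained variable's channel to its (min, max) bounds;
    # a bound of None leaves that side unconstrained.
    for name, (lo, hi) in param_constraints.items():
        i = var_index[name]
        prediction[..., i] = prediction[..., i].clamp(min=lo, max=hi)
    return prediction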

@@ -179,14 +179,15 @@
# Plotting
FIG_SIZE = (15, 10)
EXAMPLE_FILE = "data/cosmo/samples/train/data.zarr"
EVAL_DATETIMES = ["2015112800"]
EVAL_DATETIMES = ["2019010100"] # prev_prev timestep (t-2)
EVAL_PLOT_VARS = ["T_2M"]
STORE_EXAMPLE_DATA = False
STORE_EXAMPLE_DATA = True
SELECTED_PROJ = ccrs.PlateCarree()
SMOOTH_BOUNDARIES = False

# Some constants useful for sub-classes
GRID_FORCING_DIM = 7 # 3 fluxes variables + 4 time-related features
# Some constants useful for sub-classes: 3 flux variables + 4 time-related
# features, in windows of three (prev_prev, prev, current)
GRID_FORCING_DIM = (3 + 4) * 3
GRID_STATE_DIM = sum(
len(VERTICAL_LEVELS) if IS_3D[param] else 1 for param in PARAM_NAMES_SHORT
)
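As a quick sanity check on the dimension bookkeeping above (a sketch; the 3D/2D field counts are hypothetical):

# 3 flux variables + 4 datetime features = 7 forcing features per step,
# stacked over a window of three steps (prev_prev, prev, current):
assert (3 + 4) * 3 == 21  # GRID_FORCING_DIM
# GRID_STATE_DIM counts one channel per 2D field plus one per vertical level
# of each 3D field, e.g. 2 hypothetical 3D fields on 7 levels + 4 2D fields:
assert 2 * 7 + 4 == 18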
30 changes: 13 additions & 17 deletions neural_lam/models/ar_model.py
@@ -40,11 +40,7 @@ def __init__(self, args):

# Some constants useful for sub-classes
self.grid_forcing_dim = constants.GRID_FORCING_DIM
count_3d_fields = sum(value == 1 for value in constants.IS_3D.values())
count_2d_fields = sum(value != 1 for value in constants.IS_3D.values())
self.grid_state_dim = (
len(constants.VERTICAL_LEVELS) * count_3d_fields + count_2d_fields
)
self.grid_state_dim = constants.GRID_STATE_DIM

# Load static features for grid/data
static_data_dict = utils.load_static_data(args.dataset)
@@ -334,7 +330,7 @@ def common_step(self, batch):
forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
"""
init_states, target_states, batch_time = batch[:3]
init_states, target_states, batch_times = batch[:3]
forcing_features = batch[3] if len(batch) > 3 else None

prediction, pred_std = self.unroll_prediction(
@@ -343,7 +339,7 @@
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)

return prediction, target_states, pred_std, batch_time
return prediction, target_states, pred_std, batch_times

def training_step(self, batch):
"""
@@ -436,7 +432,7 @@ def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
prediction, target, pred_std, batch_time = self.common_step(batch)
prediction, target, pred_std, batch_times = self.common_step(batch)
# prediction: (B, pred_steps, num_grid_nodes, d_f)
# pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)

@@ -498,12 +494,12 @@ def test_step(self, batch, batch_idx):
batch,
prediction=prediction,
target=target,
batch_time=batch_time,
batch_times=batch_times,
)

@rank_zero_only
def plot_examples(
self, batch, prediction=None, target=None, batch_time=None
self, batch, prediction=None, target=None, batch_times=None
):
"""
Plot the first n_examples forecasts from batch
@@ -519,13 +515,15 @@
handles indexing within the batch for targeted analysis,
performs prediction rescaling, and plots results.
"""
if prediction is None or target is None or batch_time is None:
prediction, target, _, batch_time = self.common_step(batch)
if prediction is None or target is None or batch_times is None:
prediction, target, _, batch_times = self.common_step(batch)

if self.global_rank == 0 and any(
eval_datetime in batch_time
eval_datetime in batch_times
for eval_datetime in constants.EVAL_DATETIMES
):
print("Plotting examples...")
print("batch_times", batch_times)
# Rescale to original data scale
prediction_rescaled = prediction * self.data_std + self.data_mean
prediction_rescaled = self.apply_constraints(prediction_rescaled)
@@ -536,7 +534,7 @@
prediction_rescaled
)

for i, eval_datetime in enumerate(batch_time):
for i, eval_datetime in enumerate(batch_times):
if eval_datetime not in constants.EVAL_DATETIMES:
continue
pred_rescaled = prediction_rescaled[i]
@@ -786,9 +784,7 @@ def on_test_epoch_end(self):

# Get all the images for the current variable and index
images = sorted(
glob.glob(
f"{dir_path}/{var_name}_test_lvl_{lvl:02}_t_*.png"
)
glob.glob(f"{dir_path}/{var_name}_lvl_{lvl:02}_t_*.png")
)
# Generate the GIF
with imageio.get_writer(
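For reference, a self-contained sketch of the GIF assembly this hunk feeds, assuming frames were saved under the new <var>_lvl_<lvl>_t_<step>.png pattern (paths and frame duration are illustrative):

import glob
import imageio.v2 as imageio

frames = sorted(glob.glob("plots/T_2M_lvl_00_t_*.png"))
with imageio.get_writer(
    "plots/T_2M_lvl_00.gif", mode="I", duration=0.5
) as writer:
    for frame in frames:
        writer.append_data(imageio.imread(frame))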
122 changes: 79 additions & 43 deletions neural_lam/weather_dataset.py
@@ -37,6 +37,11 @@ def __init__(
"verif",
), "Unknown dataset split"

self.random_subsample = split == "train"
self.split = split
self.batch_size = batch_size
self.control_only = control_only

if split == "verif":
self.np_files = np.load(path_verif_file)
self.split = split
@@ -49,7 +54,7 @@
if split == "train":
self.ds = self.ds.sel(time=slice("2015", "2019"))
else:
self.ds = self.ds.sel(time="2020")
self.ds = self.ds.sel(time=slice("2020"))

new_vars = {}
forcings = {}
@@ -75,28 +80,48 @@
.transpose("time", "x", "y", "variable")
)

self.num_steps = (
constants.TRAIN_HORIZON
if self.split == "train"
else constants.EVAL_HORIZON
)

if subset:
if constants.EVAL_DATETIMES is not None and split == "test":
utils.rank_zero_print(
f"Subsetting test dataset, using only first "
f"{self.num_steps} hours after "
f"{constants.EVAL_DATETIMES[0]}"
)
eval_datetime_obj = dt.datetime.strptime(
constants.EVAL_DATETIMES[0], "%Y%m%d%H"
)
init_datetime = np.datetime64(eval_datetime_obj, "ns")
end_datetime = np.datetime64(
eval_datetime_obj + dt.timedelta(hours=self.num_steps), "ns"
)
assert (
init_datetime in self.ds.time.values
), f"Eval datetime {init_datetime} not in dataset. "
self.ds = self.ds.sel(
time=slice(
eval_datetime_obj,
eval_datetime_obj + dt.timedelta(hours=50),
init_datetime,
end_datetime,
)
)
self.forcings = self.forcings.sel(
time=slice(
eval_datetime_obj,
eval_datetime_obj + dt.timedelta(hours=50),
init_datetime,
end_datetime,
)
)
else:
start_idx = randint(0, self.ds.time.size - 50)
self.ds = self.ds.isel(time=slice(start_idx, start_idx + 50))
start_idx = randint(0, self.ds.time.size - self.num_steps)
self.ds = self.ds.isel(
time=slice(start_idx, start_idx + self.num_steps)
)
self.forcings = self.forcings.isel(
time=slice(start_idx, start_idx + 50)
time=slice(start_idx, start_idx + self.num_steps)
)
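A sketch of the eval-window arithmetic used above (the num_steps value is hypothetical); note that xarray label slicing includes both endpoints:

import datetime as dt
import numpy as np

eval_dt = dt.datetime.strptime("2019010100", "%Y%m%d%H")
num_steps = 40  # hypothetical EVAL_HORIZON
init_datetime = np.datetime64(eval_dt, "ns")
end_datetime = np.datetime64(eval_dt + dt.timedelta(hours=num_steps), "ns")
# ds.sel(time=slice(init_datetime, end_datetime)) keeps num_steps + 1 hourly steps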

self.standardize = standardize
Expand All @@ -119,15 +144,6 @@ def __init__(
ds_stats["data_mean"],
ds_stats["data_std"],
)
self.random_subsample = split == "train"
self.split = split
self.num_steps = (
constants.TRAIN_HORIZON
if self.split == "train"
else constants.EVAL_HORIZON
)
self.batch_size = batch_size
self.control_only = control_only

def __len__(self):
if self.split == "verif":
@@ -138,7 +154,7 @@ def __getitem__(self, idx):
if self.split == "verif":
return self.np_files
sample_xr = self.ds.isel(time=slice(idx, idx + self.num_steps))
forcings = self.forcings.isel(time=slice(idx, idx + 1))
forcings = self.forcings.isel(time=slice(idx, idx + self.num_steps))

# (N_t', N_x, N_y, d_features')
sample = torch.tensor(sample_xr.values, dtype=torch.float32)
@@ -153,46 +169,66 @@ def __getitem__(self, idx):
init_states = sample[:2] # (2, N_grid, d_features)
target_states = sample[2:] # (sample_length-2, N_grid, d_features)

batch_time = self.ds.isel(time=idx).time.values
batch_time = np.datetime_as_string(batch_time, unit="h")
batch_time = str(batch_time).replace("-", "").replace("T", "")
batch_times = self.ds.isel(
time=slice(idx, idx + self.num_steps)
).time.values
batch_times = np.datetime_as_string(batch_times, unit="h")
batch_times = [
str(t).replace("-", "").replace("T", "") for t in batch_times
]

# Time of day and year
dt_obj = dt.datetime.strptime(batch_time, "%Y%m%d%H")
hour_of_day = dt_obj.hour
second_into_year = (
dt_obj - dt.datetime(dt_obj.year, 1, 1)
).total_seconds()

hour_angle = torch.tensor(
(hour_of_day / 12) * torch.pi
dt_objs = [dt.datetime.strptime(t, "%Y%m%d%H") for t in batch_times]
hours_of_day = [dt_obj.hour for dt_obj in dt_objs]
seconds_into_year = [
(dt_obj - dt.datetime(dt_obj.year, 1, 1)).total_seconds()
for dt_obj in dt_objs
]

hour_angles = torch.tensor(
[(hour_of_day / 12) * torch.pi for hour_of_day in hours_of_day]
) # (sample_len,)
year_angle = torch.tensor(
(second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
year_angles = torch.tensor(
[
(second_into_year / constants.SECONDS_IN_YEAR) * 2 * torch.pi
for second_into_year in seconds_into_year
]
) # (sample_len,)
datetime_forcing = torch.stack(
(
torch.sin(hour_angle),
torch.cos(hour_angle),
torch.sin(year_angle),
torch.cos(year_angle),
torch.sin(hour_angles),
torch.cos(hour_angles),
torch.sin(year_angles),
torch.cos(year_angles),
),
dim=0,
) # (N_t, 4)
dim=1,
) # (sample_len, 4)
datetime_forcing = (datetime_forcing + 1) / 2 # Rescale to [0,1]
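# (Sketch) e.g. at hour 6, hour_angle = pi/2, so after rescaling the pair
# (sin, cos) becomes (1.0, 0.5); the model sees a smooth periodic clock.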

datetime_forcing = (
datetime_forcing.unsqueeze(0)
.unsqueeze(0)
.expand(-1, forcings.shape[1], -1)
) # (N_t, N_grid, 4)
datetime_forcing = datetime_forcing.unsqueeze(1).expand(
-1, forcings.shape[1], -1
) # (sample_len, N_grid, 4)

# Put forcing features together
forcings = torch.cat(
(forcings, datetime_forcing), dim=-1
) # (sample_len, N_grid, d_forcing)

return init_states, target_states, batch_time, forcings
# Combine forcing over each window of 3 time steps (prev_prev, prev,
# current)
forcing = torch.cat(
(
forcings[:-2],
forcings[1:-1],
forcings[2:],
),
dim=2,
) # (sample_len-2, N_grid, 3*d_forcing)
# Now index 0 of ^ corresponds to forcing at index 0-2 of sample

# Start the plotting at the first time step
batch_time = batch_times[0]
return init_states, target_states, batch_time, forcing
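To make the shape fix concrete, a minimal runnable sketch of the three-step windowing (sizes hypothetical):

import torch

num_steps, n_grid, d_forcing = 5, 10, 7  # hypothetical sizes
forcings = torch.randn(num_steps, n_grid, d_forcing)
forcing = torch.cat((forcings[:-2], forcings[1:-1], forcings[2:]), dim=2)
assert forcing.shape == (num_steps - 2, n_grid, 3 * d_forcing)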


class WeatherDataModule(pl.LightningDataModule):
26 changes: 17 additions & 9 deletions slurm_eval.sh
@@ -2,25 +2,32 @@
#SBATCH --job-name=NeurWPe
#SBATCH --account=s83
#SBATCH --nodes=1
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8
#SBATCH --partition=normal
#SBATCH --mem=300G
#SBATCH --mem=375G
#SBATCH --no-requeue
#SBATCH --output=lightning_logs/neurwp_eval_out.log
#SBATCH --error=lightning_logs/neurwp_eval_err.log

export PREPROCESS=true
export NORMALIZE=false
export DATASET="cosmo"
export MODEL="hi_lam"

# Load necessary modules
conda activate neural-lam

if [ "$PREPROCESS" = true ]; then
echo "Create static features"
python create_static_features.py --boundaries 60
echo "Creating mesh"
python create_mesh.py --dataset $DATASET --plot 1
python create_static_features.py --boundaries 60 --dataset $DATASET
if [ "$MODEL" = "hi_lam" ]; then
echo "Creating hierarchical mesh"
python create_mesh.py --dataset $DATASET --plot 1 --graph hierarchical --levels 4 --hierarchical 1
else
echo "Creating flat mesh"
python create_mesh.py --dataset $DATASET --plot 1
fi
echo "Creating grid features"
python create_grid_features.py --dataset $DATASET
if [ "$NORMALIZE" = true ]; then
@@ -30,8 +37,9 @@ if [ "$PREPROCESS" = true ]; then
fi
fi

ulimit -c 0
export OMP_NUM_THREADS=16

srun -ul python train_model.py --load "wandb/example.ckpt" --dataset $DATASET \
--eval="test" --subset_ds 1 --n_workers 2 --batch_size 6
echo "Evaluating model"
if [ "$MODEL" = "hi_lam" ]; then
srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 4 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="test"
else
srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 4 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="test"
fi
8 changes: 3 additions & 5 deletions slurm_train.sh
@@ -2,7 +2,7 @@
#SBATCH --job-name=NeurWP
#SBATCH --account=s83
#SBATCH --time=24:00:00
#SBATCH --nodes=1
#SBATCH --nodes=5
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8
#SBATCH --partition=normal
@@ -39,11 +39,9 @@ if [ "$PREPROCESS" = true ]; then
fi
fi

ulimit -c 0

echo "Training model"
if [ "$MODEL" = "hi_lam" ]; then
srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 4 --batch_size 1 --subset_ds 0 --model hi_lam --graph hierarchical
srun -ul python train_model.py --dataset $DATASET --val_interval 20 --epochs 40 --n_workers 4 --batch_size 1 --subset_ds 0 --model hi_lam --graph hierarchical
else
srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 4 --batch_size 1 --subset_ds 0
srun -ul python train_model.py --dataset $DATASET --val_interval 20 --epochs 40 --n_workers 4 --batch_size 1 --subset_ds 0
fi
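Note: with 5 nodes × 8 ranks per node and --batch_size 1, the effective global batch size under standard DDP data parallelism is 40.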
Binary file modified wandb/example.ckpt
Binary file not shown.
