Skip to content

Commit

Permalink
Bugfixes + alignment with upstream
Browse files Browse the repository at this point in the history
Indexing of forcings/fluxes was faulty, needed update after proper windowing in weather dataset return
  • Loading branch information
Simon Adamov committed Apr 24, 2024
1 parent e8b484e commit e3101fb
Showing 1 changed file with 25 additions and 41 deletions.
66 changes: 25 additions & 41 deletions create_parameter_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,45 +81,36 @@ def main():
# Compute mean and std.-dev. of each parameter (+ flux forcing)
# across full dataset
print("Computing mean and std.-dev. for parameters...")

means = []
squares = []
flux_means = []
flux_squares = []
for batch_data in tqdm(loader):
if constants.GRID_FORCING_DIM > 0:
init_batch, target_batch, _, forcing_batch = batch_data
flux_batch = forcing_batch[
:, :, :, :3
] # fluxes are first 3 features
flux_means.append(torch.mean(flux_batch, dim=(1, 2, 3))) # (,)
flux_squares.append(torch.mean(flux_batch**2, dim=(1, 2, 3))) # (,)
else:
init_batch, target_batch, _ = batch_data

for init_batch, target_batch, forcing_batch in tqdm(loader):
batch = torch.cat(
(init_batch, target_batch), dim=1
) # (N_batch, N_t, N_grid, d_features)
means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,)
squares.append(
torch.mean(batch**2, dim=(1, 2))
) # (N_batch, d_features,)
mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features)
second_moment = torch.mean(torch.cat(squares, dim=0), dim=0)
std = torch.sqrt(second_moment - mean**2) # (d_features)

mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features)
second_moment = torch.mean(torch.cat(squares, dim=0), dim=0)
std = torch.sqrt(second_moment - mean**2) # (d_features)

if constants.GRID_FORCING_DIM > 0:
flux_mean = torch.mean(torch.cat(flux_means, dim=0), dim=0) # (,)
flux_second_moment = torch.mean(
torch.cat(flux_squares, dim=0), dim=0
) # (,)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,)
flux_stats = torch.stack((flux_mean, flux_std))

print("Saving mean flux_stats...")
torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt"))
print("Saving mean, std.-dev...")
if constants.GRID_FORCING_DIM > 0:
# Flux at 1st windowed position is index 1 in forcing
flux_means = []
flux_squares = []
flux_batch = forcing_batch[:, :, :, 1]
flux_means.append(torch.mean(flux_batch)) # (,)
flux_squares.append(torch.mean(flux_batch**2)) # (,)
flux_mean = torch.mean(torch.stack(flux_means)) # (,)
flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,)
flux_stats = torch.stack((flux_mean, flux_std))
torch.save(
flux_stats, os.path.join(static_dir_path, "flux_stats.pt")
)

print("Saving mean, std.-dev, flux_stats...")
torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt"))
torch.save(std, os.path.join(static_dir_path, "parameter_std.pt"))

Expand All @@ -137,18 +128,11 @@ def main():
diff_means = []
diff_squares = []
for batch_data in tqdm(loader_standard):
if constants.GRID_FORCING_DIM > 0:
init_batch, target_batch, _, forcing_batch = batch_data
flux_batch = forcing_batch[
:, :, :, :3
] # fluxes are first 3 features
flux_means.append(torch.mean(flux_batch, dim=(1, 2, 3))) # (,)
flux_squares.append(torch.mean(flux_batch**2, dim=(1, 2, 3))) # (,)
else:
init_batch, target_batch, _ = batch_data
batch_diffs = init_batch[:, 1:] - target_batch
# (N_batch', N_t-1, N_grid, d_features)

init_batch, target_batch, _ = batch_data
batch = torch.cat(init_batch, target_batch, dim=1)
# (N_batch, N_t', N_grid, d_features)
batch_diffs = batch[:, 1:] - batch[:, :-1]
# (N_batch, N_t-1, N_grid, d_features)
diff_means.append(
torch.mean(batch_diffs, dim=(1, 2))
) # (N_batch', d_features,)
Expand Down

0 comments on commit e3101fb

Please sign in to comment.