diff --git a/create_parameter_weights.py b/create_parameter_weights.py index cba107a7..4066d14b 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -83,7 +83,7 @@ def main(): print("Computing mean and std.-dev. for parameters...") means = [] squares = [] - for init_batch, target_batch, forcing_batch in tqdm(loader): + for init_batch, target_batch, _, forcing_batch in tqdm(loader): batch = torch.cat( (init_batch, target_batch), dim=1 ) # (N_batch, N_t, N_grid, d_features) @@ -128,7 +128,7 @@ def main(): diff_means = [] diff_squares = [] for batch_data in tqdm(loader_standard): - init_batch, target_batch, _ = batch_data + init_batch, target_batch, _ , _= batch_data batch = torch.cat(init_batch, target_batch, dim=1) # (N_batch, N_t', N_grid, d_features) batch_diffs = batch[:, 1:] - batch[:, :-1]