Skip to content

Commit

Permalink
fixed return elements (batch_size was missing)
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Adamov committed Apr 24, 2024
1 parent e3101fb commit 5fc32bb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions create_parameter_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def main():
print("Computing mean and std.-dev. for parameters...")
means = []
squares = []
for init_batch, target_batch, forcing_batch in tqdm(loader):
for init_batch, target_batch, _, forcing_batch in tqdm(loader):
batch = torch.cat(
(init_batch, target_batch), dim=1
) # (N_batch, N_t, N_grid, d_features)
Expand Down Expand Up @@ -128,7 +128,7 @@ def main():
diff_means = []
diff_squares = []
for batch_data in tqdm(loader_standard):
init_batch, target_batch, _ = batch_data
init_batch, target_batch, _ , _= batch_data
batch = torch.cat(init_batch, target_batch, dim=1)
# (N_batch, N_t', N_grid, d_features)
batch_diffs = batch[:, 1:] - batch[:, :-1]
Expand Down

0 comments on commit 5fc32bb

Please sign in to comment.