Replaced random sampling with DistributedSampler
jsschreck committed Dec 17, 2024
1 parent 72ccbe1 commit 6f13594
Showing 1 changed file with 124 additions and 18 deletions.
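
The gist of the change: batch start indices are no longer drawn independently on each rank with np.random.choice; the dataset now owns a torch.utils.data.DistributedSampler, so every rank receives a disjoint, epoch-consistent slice of the index space. A minimal standalone sketch of that behavior, not code from this commit (the toy dataset, world size of 2, and seed below are placeholders):

    import torch
    from torch.utils.data import Dataset, DistributedSampler

    class ToyDataset(Dataset):
        # Only __len__ matters to the sampler; __getitem__ is a stub.
        def __len__(self):
            return 10
        def __getitem__(self, idx):
            return idx

    dataset = ToyDataset()

    for rank in range(2):
        # One sampler per process; together the ranks cover every index once per epoch.
        sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True, seed=42)
        sampler.set_epoch(0)  # same epoch on every rank -> same global permutation
        print(rank, list(sampler))

With shuffle=True the permutation is identical on every rank for a given epoch and seed, so the per-rank slices never overlap; calling set_epoch is what rotates that permutation between epochs.
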
credit/datasets/era5_multistep_batcher.py — 142 changes: 124 additions & 18 deletions
@@ -8,6 +8,7 @@
import numpy as np
import torch
from functools import partial
from torch.utils.data import DistributedSampler

from credit.data import drop_var_from_dataset, get_forward_data
from credit.datasets.era5_multistep import worker
@@ -49,7 +50,8 @@ def __init__(
world_size=1,
skip_periods=None,
max_forecast_len=None,
batch_size=1
batch_size=1,
shuffle=True
):
"""
Initialize the ERA5_and_Forcing_Dataset
@@ -87,6 +89,7 @@ def __init__(
self.seed = seed
self.rank = rank
self.world_size = world_size
self.shuffle = shuffle

# skip periods
self.skip_periods = skip_periods
@@ -257,6 +260,15 @@ def __init__(
self.current_index = None
self.initial_index = None

# Use DistributedSampler for index management
self.sampler = DistributedSampler(
self,
num_replicas=world_size,
rank=rank,
shuffle=shuffle,
seed=seed
)

# Initialize state variables for batch management
self.batch_size = batch_size
self.batch_indices = None # To track initial indices for each batch item
@@ -267,20 +279,89 @@ def __init__(
# Initialize batch once when the dataset is created
self.initialize_batch()

# def initialize_batch(self):
# """
# Initializes batch indices using DistributedSampler's indices.
# Resets the time steps and forecast step counts.
# Ensures proper cycling when shuffle=False.
# """
# # Initialize the call count if not already present
# if not hasattr(self, "batch_call_count"):
# self.batch_call_count = 0

# # Set epoch for DistributedSampler to ensure consistent shuffling across devices
# if self.current_epoch is not None:
# self.sampler.set_epoch(self.current_epoch)

# # Retrieve indices for this GPU
# indices = list(self.sampler)
# total_indices = len(indices)

# # Select batch indices based on call count (deterministic cycling)
# start = self.batch_call_count * self.batch_size
# end = start + self.batch_size

# if end > total_indices:
# # Wrap-around to ensure no index is skipped
# indices = indices[start:] + indices[:(end % total_indices)]
# else:
# indices = indices[start:end]

# # Increment batch_call_count, reset when all indices are cycled
# self.batch_call_count += 1
# if start + self.batch_size >= total_indices:
# self.batch_call_count = 0 # Reset for next cycle

# # Assign batch indices
# self.batch_indices = indices
# self.time_steps = [0 for _ in self.batch_indices]
# self.forecast_step_counts = [0 for _ in self.batch_indices]
# self.initial_indices = list(self.batch_indices)

def initialize_batch(self):
"""
Initializes random starting indices for each batch item and resets their time steps and forecast counts.
This must be called before accessing the dataset or when resetting the batch.
Initializes batch indices using DistributedSampler's indices.
Ensures proper cycling when shuffle=False.
"""
# Randomly sample indices for the batch
self.batch_indices = np.random.choice(
range(self.__len__() - self.forecast_len),
size=self.batch_size,
replace=False,
)
self.time_steps = [0 for idx in self.batch_indices] # Initialize time to 0 for each item
self.forecast_step_counts = [0 for idx in self.batch_indices] # Initialize forecast step counts
self.initial_indices = list(self.batch_indices) # Track initial indices for each batch item
# Initialize the call count if not already present
if not hasattr(self, "batch_call_count"):
self.batch_call_count = 0

# Set epoch for DistributedSampler to ensure consistent shuffling across devices
if self.current_epoch is not None:
self.sampler.set_epoch(self.current_epoch)

# Retrieve indices for this GPU
indices = list(self.sampler)
total_indices = len(indices)

# Select batch indices based on call count (deterministic cycling)
start = self.batch_call_count * self.batch_size
end = start + self.batch_size

if not self.shuffle:
if end > total_indices:
# Simple wraparound by incrementing start index
start = start % total_indices
end = min(start + self.batch_size, total_indices)
indices = indices[start:end]
else:
if end > total_indices:
# Wrap-around to ensure no index is skipped
indices = indices[start:] + indices[:(end % total_indices)]
else:
indices = indices[start:end]

# Increment batch_call_count, reset when all indices are cycled
self.batch_call_count += 1
if start + self.batch_size >= total_indices:
self.batch_call_count = 0 # Reset for next cycle

# Assign batch indices
self.batch_indices = indices
self.time_steps = [0 for _ in self.batch_indices]
self.forecast_step_counts = [0 for _ in self.batch_indices]
self.initial_indices = list(self.batch_indices)

def __post_init__(self):
# Total sequence length of each sample.
Expand All @@ -293,10 +374,14 @@ def __len__(self):
total_len += len(ERA5_xarray["time"]) - self.total_seq_len + 1
return total_len

# def set_epoch(self, epoch):
# self.current_epoch = epoch
# self.current_index = None
# self.initial_index = None

def set_epoch(self, epoch):
self.current_epoch = epoch
self.current_index = None
self.initial_index = None
self.sampler.set_epoch(epoch)

def __getitem__(self, _):
"""
@@ -321,7 +406,7 @@ def __getitem__(self, _):
sample = self.worker(index_pair)

# Add index to the sample
sample["index"] = idx
sample["index"] = idx + current_t

# Concatenate data by common keys in sample
for key, value in sample.items():
@@ -748,6 +833,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):

batch_size = 2
data_config["forecast_len"] = 6
rank = 0
world_size = 2
shuffle = True

set_globals(data_config, namespace=globals())

@@ -783,7 +871,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
skip_periods=data_config['skip_periods'],
max_forecast_len=data_config['max_forecast_len'],
transform=load_transforms(conf),
batch_size=batch_size
batch_size=batch_size,
shuffle=shuffle,
rank=rank,
world_size=world_size
)
dataloader = DataLoader(
@@ -813,7 +904,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
skip_periods=data_config['skip_periods'],
max_forecast_len=data_config['max_forecast_len'],
transform=load_transforms(conf),
batch_size=batch_size
batch_size=batch_size,
shuffle=shuffle,
rank=rank,
world_size=world_size
)

dataloader = DataLoader(
@@ -866,6 +960,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
max_forecast_len=data_config['max_forecast_len'],
transform=load_transforms(conf),
batch_size=batch_size,
shuffle=shuffle,
rank=rank,
world_size=world_size,
num_workers=4
)
@@ -897,6 +994,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
max_forecast_len=data_config['max_forecast_len'],
transform=load_transforms(conf),
batch_size=batch_size,
shuffle=shuffle,
rank=rank,
world_size=world_size,
num_workers=4
)

@@ -949,8 +1049,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
max_forecast_len=data_config['max_forecast_len'],
transform=load_transforms(conf),
batch_size=batch_size,
shuffle=shuffle,
rank=rank,
world_size=world_size,
num_workers=6,
prefetch_factor=6
prefetch_factor=6,
)
dataloader = DataLoader(
@@ -980,6 +1083,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
max_forecast_len=data_config['max_forecast_len'],
transform=load_transforms(conf),
batch_size=batch_size,
shuffle=shuffle,
rank=rank,
world_size=world_size,
num_workers=6,
prefetch_factor=6
)
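
The demo blocks above construct the dataset with rank, world_size, and shuffle but do not show advancing the sampler between epochs. Because DistributedSampler only reshuffles when told the epoch number, a training loop would be expected to call the dataset's set_epoch once per epoch, roughly as in this sketch (the loop structure and num_epochs are assumptions, not code from this commit; dataset and dataloader are constructed as in the demos above):

    for epoch in range(num_epochs):
        # Re-seed the DistributedSampler so every rank draws the same permutation
        # for this epoch and then takes its own disjoint slice of it.
        dataset.set_epoch(epoch)
        for batch in dataloader:
            ...  # training step (placeholder)
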
