A few more documentation pieces added to the if __name__ == "__main__" test
jsschreck committed Dec 10, 2024
1 parent 1168d35 commit ced1678
Showing 1 changed file with 94 additions and 2 deletions: credit/datasets/era5_multistep_batcher.py
@@ -762,7 +762,37 @@ def __exit__(self, exc_type, exc_val, exc_tb):
    if option == "1":

        logger.info("Option 1: ERA5_MultiStep_Batcher")

        logger.info(
            """
            dataset_multi = ERA5_MultiStep_Batcher(
                varname_upper_air=data_config['varname_upper_air'],
                varname_surface=data_config['varname_surface'],
                varname_dyn_forcing=data_config['varname_dyn_forcing'],
                varname_forcing=data_config['varname_forcing'],
                varname_static=data_config['varname_static'],
                varname_diagnostic=data_config['varname_diagnostic'],
                filenames=data_config['all_ERA_files'],
                filename_surface=data_config['surface_files'],
                filename_dyn_forcing=data_config['dyn_forcing_files'],
                filename_forcing=data_config['forcing_files'],
                filename_static=data_config['static_files'],
                filename_diagnostic=data_config['diagnostic_files'],
                history_len=data_config['history_len'],
                forecast_len=data_config['forecast_len'],
                skip_periods=data_config['skip_periods'],
                max_forecast_len=data_config['max_forecast_len'],
                transform=load_transforms(conf),
                batch_size=batch_size
            )
            dataloader = DataLoader(
                dataset_multi,
                num_workers=1, # Must be 1 to use prefetching
                drop_last=True, # Drop the last incomplete batch if not divisible by batch_size,
                prefetch_factor=4
            )
            """
        )
        start_time = time.time()
        dataset_multi = ERA5_MultiStep_Batcher(
            varname_upper_air=data_config['varname_upper_air'],
@@ -814,6 +844,37 @@ def __exit__(self, exc_type, exc_val, exc_tb):
    elif option == "2":

        logger.info("Testing 2: MultiprocessingBatcher")
        logger.info(
            """
            dataset_multi = MultiprocessingBatcher(
                varname_upper_air=data_config['varname_upper_air'],
                varname_surface=data_config['varname_surface'],
                varname_dyn_forcing=data_config['varname_dyn_forcing'],
                varname_forcing=data_config['varname_forcing'],
                varname_static=data_config['varname_static'],
                varname_diagnostic=data_config['varname_diagnostic'],
                filenames=data_config['all_ERA_files'],
                filename_surface=data_config['surface_files'],
                filename_dyn_forcing=data_config['dyn_forcing_files'],
                filename_forcing=data_config['forcing_files'],
                filename_static=data_config['static_files'],
                filename_diagnostic=data_config['diagnostic_files'],
                history_len=data_config['history_len'],
                forecast_len=data_config['forecast_len'],
                skip_periods=data_config['skip_periods'],
                max_forecast_len=data_config['max_forecast_len'],
                transform=load_transforms(conf),
                batch_size=batch_size,
                num_workers=4
            )
            dataloader = DataLoader(
                dataset_multi,
                num_workers=0, # Cannot use multiprocessing in both
                drop_last=True # Drop the last incomplete batch if not divisible by batch_size
            )
            """
        )

        start_time = time.time()
        dataset_multi = MultiprocessingBatcher(
@@ -840,7 +901,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

        dataloader = DataLoader(
            dataset_multi,
-           num_workers=0, # Must be 1 to use pre-fetching
+           num_workers=0, # Cannot use multiprocessing in both
            drop_last=True # Drop the last incomplete batch if not divisible by batch_size
        )

@@ -866,6 +927,37 @@ def __exit__(self, exc_type, exc_val, exc_tb):
    elif option == "3":

        logger.info("Testing 3: MultiprocessingBatcherPrefetch")
        logger.info(
            """
            dataset_multi = MultiprocessingBatcherPrefetch(
                varname_upper_air=data_config['varname_upper_air'],
                varname_surface=data_config['varname_surface'],
                varname_dyn_forcing=data_config['varname_dyn_forcing'],
                varname_forcing=data_config['varname_forcing'],
                varname_static=data_config['varname_static'],
                varname_diagnostic=data_config['varname_diagnostic'],
                filenames=data_config['all_ERA_files'],
                filename_surface=data_config['surface_files'],
                filename_dyn_forcing=data_config['dyn_forcing_files'],
                filename_forcing=data_config['forcing_files'],
                filename_static=data_config['static_files'],
                filename_diagnostic=data_config['diagnostic_files'],
                history_len=data_config['history_len'],
                forecast_len=data_config['forecast_len'],
                skip_periods=data_config['skip_periods'],
                max_forecast_len=data_config['max_forecast_len'],
                transform=load_transforms(conf),
                batch_size=batch_size,
                num_workers=6,
                prefetch_factor=6
            )
            dataloader = DataLoader(
                dataset_multi,
                drop_last=True, # Drop the last incomplete batch if not divisible by batch_size,
            )
            """
        )

        start_time = time.time()
        dataset_multi = MultiprocessingBatcherPrefetch(
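
For context, the __main__ test documented above presumably goes on to time each batcher by iterating the dataloader it builds (each branch starts a timer with start_time = time.time()). A minimal sketch of such a timing loop, assuming the dataloader and logger objects constructed in the snippets above; the loop body and batch count here are illustrative, not taken from this commit:

    import time

    start_time = time.time()
    for k, batch in enumerate(dataloader):
        # Each batch is assembled by the batcher from the configured ERA5 files
        logger.info(f"Batch {k} ready after {time.time() - start_time:.2f} s")
        start_time = time.time()
        if k >= 2:
            # A few batches are enough for a quick smoke test
            break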