Skip to content

Commit

Permalink
ensemble capability, KCRPS for single-step training and batch-size=1
Browse files Browse the repository at this point in the history
  • Loading branch information
dkimpara committed Dec 19, 2024
1 parent a326a6c commit 36394b6
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 1 deletion.
292 changes: 292 additions & 0 deletions config/test_cesm_ensemble.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# --------------------------------------------------------------------------------------------------------------------- #
# This yaml file implements hourly state-in-state-out crossformer
# on NSF NCAR HPCs (casper.ucar.edu and derecho.hpc.ucar.edu)
# The model is trained on hourly model-level ERA5 data with top solar irradiance, geopotential, and land-sea mask
# Output variables: model level [U, V, T, Q], single level [SP, t2m], and 500 hPa [U, V, T, Z, Q]
# --------------------------------------------------------------------------------------------------------------------- #
save_loc: '/glade/work/$USER/CREDIT_runs/test_cesm_ensemble/'
seed: 1000

#net short top of model: FSNT
#net short bot of model: FSNS
#net long top of model: FLNT
#net long bot of model: FLNS
#note Q is vapor + condensed (total liquid)
#Fnet_sfc = FSNS - FLNS - LHFLX - SHFLX <<<<<<-------
#Fnet_TOA =

#The net radiative flux at "top of atmosphere" is sometimes called RESTOA ("RES" for residual), and at "top of model" it is called RESTOM. These are typically diagnosed from model output as

#RESTOM = FSNT - FLNT <<<<<<-------
#RESTOA = FSNTOA - FLUT

#The RHS of these are the net shortwave and longwave fluxes. They are computed as differences:

#FSNT = SOLIN - FSUT
#FSNTOA = SOLIN - FSUTOA
#FLNT = FLUT - [FLDT]

#FLDT is not usually output from the model, but FLNT and FLUT are. SOLIN is the insolation.

data:
# upper-air variables
variables: ['U','V','T','Q']
save_loc: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/f.e21.CREDIT_climate_????.zarr'

# surface variables
surface_variables: ['PS','PSL','TREFHT']
save_loc_surface: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/f.e21.CREDIT_climate_????.zarr'

# dynamic forcing variables
dynamic_forcing_variables: ['SOLIN']
save_loc_dynamic_forcing: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/f.e21.CREDIT_climate_????.zarr'

# static variables
static_variables: ['z_norm','LANDM_COSLAT']
save_loc_static: '/glade/derecho/scratch/dkimpara/f.e21.CREDIT_climate.statics_1.0deg_32levs_latlon_v3.nc'

# diagnostic variables
diagnostic_variables: ['PRECT','CLDTOT','CLDHGH','CLDLOW','CLDMED','TAUX','TAUY','U10', 'QFLX','FSNS','FLNS','FSNT','FLNT','SHFLX','LHFLX']
save_loc_diagnostic: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/f.e21.CREDIT_climate_????.zarr'

save_loc_physics: '/glade/derecho/scratch/dkimpara/f.e21.CREDIT_climate.statics_1.0deg_32levs_latlon_v3.nc'

# mean / std path
mean_path: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/mean_6h_1979_2010_32lev_1.0deg_new.nc'
# regular z-score version
# std_path: '/glade/derecho/scratch/ksha/CREDIT_data/std_1h_1979_2018_16lev_0.25deg.nc'
# residual norm version
std_path: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/std_6h_1979_2010_32lev_1.0deg_new.nc'

# train / validation split
train_years: [1979, 2009]
valid_years: [2010, 2014]

scaler_type: 'std_new'

# state-in-state-out
history_len: 1
valid_history_len: 1

# single step
forecast_len: 0
valid_forecast_len: 0

one_shot: True
lead_time_periods: 6
skip_periods: null

# compatible with the old 'std'
static_first: True

level_info_file: '/glade/derecho/scratch/dkimpara/f.e21.CREDIT_climate.statics_1.0deg_32levs_latlon_v3.nc'
#level_list: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137] not used for cesm
timestep: 21600 # 1h timestep

trainer:
type: standard # <---------- change to your type

mode: none
cpu_offload: False
activation_checkpoint: True

load_weights: True
load_optimizer: False
load_scaler: False
load_sheduler: False

skip_validation: False
update_learning_rate: False

save_backup_weights: True
save_best_weights: True

learning_rate: 1.0e-04 # <-- change to your lr
weight_decay: 0

train_batch_size: 1
valid_batch_size: 1
ensemble_size: 2
long_rollout: True

batches_per_epoch: 1000
valid_batches_per_epoch: 4
stopping_patience: 999

start_epoch: 0
num_epoch: 50
reload_epoch: False
epochs: &epochs 100

use_scheduler: True
scheduler: {'scheduler_type': 'cosine-annealing', 'T_max': *epochs, 'last_epoch': -1}

amp: False
grad_accum_every: 1
grad_max_norm: 1.0

# number of workers
thread_workers: 4
valid_thread_workers: 0

model:
# crossformer example
type: "crossformer"
frames: 1 # number of input states (default: 1)
image_height: 192 # number of latitude grids (default: 640)
image_width: 288 # number of longitude grids (default: 1280)
levels: 32 # number of upper-air variable levels (default: 15)
channels: 4 # upper-air variable channels
surface_channels: 3 # surface variable channels
input_only_channels: 3 # dynamic forcing, forcing, static channels
output_only_channels: 15 # diagnostic variable channels

patch_width: 1 # number of latitude grids in each 3D patch (default: 1)
patch_height: 1 # number of longitude grids in each 3D patch (default: 1)
frame_patch_size: 1 # number of input states in each 3D patch (default: 1)

dim: [64, 128, 256, 512] # Dimensionality of each layer
depth: [2, 2, 8, 2] # Depth of each layer
global_window_size: [4, 4, 2, 1] # Global window size for each layer
local_window_size: 3 # Local window size
cross_embed_kernel_sizes: # kernel sizes for cross-embedding
- [4, 8, 16, 32]
- [2, 4]
- [2, 4]
- [2, 4]
cross_embed_strides: [2, 2, 2, 2] # Strides for cross-embedding (default: [4, 2, 2, 2])
attn_dropout: 0. # Dropout probability for attention layers (default: 0.0)
ff_dropout: 0. # Dropout probability for feed-forward layers (default: 0.0)

# map boundary padding
pad_lon: 48 # number of grids to pad on 0 and 360 deg lon
pad_lat: 48 # number of grids to pad on -90 and 90 deg lat

post_conf:
activate: True
#this scaling maps your variables to the ERA5 units
#make sure to adjust if your timestep is not 6 hours
requires_scaling: True
scaling_coefs:
tot_precip: 86400 #model timestep in seconds m/day -> m/s
surf_net_solar: 21600 #model timestep in seconds (W/s) -> J/s
surf_net_therm: -21600 #model timestep in seconds (W/s) -> J/s
surf_shflx: -21600 #model timestep in seconds (W/s) -> J/s
surf_lhflx: -21600 #model timestep in seconds (W/s) -> J/s
top_net_solar: 21600 #model timestep in seconds (W/s) -> J/s
top_net_therm: -21600 #model timestep in seconds (W/s) -> J/s
evap: -1
SP: 1
U: 1
V: 1
gph_surf: 1
temp: 1
Q: 1
T: 1

skebs:
activate: True
lmax: None
mmax: None
freeze_base_model_weights: True
multistep: False
uniform_dissipation: True # tunable fixed dissipation rate array else, use a backscatter FCNN
grid: 'equiangular'

tracer_fixer:
activate: False
denorm: True
tracer_name: ['Q','PRECT','U10','CLDTOT','CLDHGH','CLDLOW','CLDMED']
tracer_thres: [0, 0, 0, 0, 0, 0, 0]

global_mass_fixer:
activate: False
activate_outside_model: False
simple_demo: False
denorm: True
grid_type: 'sigma'
midpoint: True
fix_level_num: 14
lon_lat_level_name: ['lon2d', 'lat2d', 'hyai', 'hybi']
surface_pressure_name: ['PS']
specific_total_water_name: ['Q']

global_water_fixer:
activate: False
activate_outside_model: False
simple_demo: False
denorm: True
grid_type: 'sigma'
midpoint: True
lon_lat_level_name: ['lon2d', 'lat2d', 'hyai', 'hybi']
surface_pressure_name: ['PS']
specific_total_water_name: ['Q']
precipitation_name: ['PRECT']
evaporation_name: ['QFLX']

global_energy_fixer:
activate: False
activate_outside_model: False
simple_demo: False
denorm: True
grid_type: 'sigma'
midpoint: True
lon_lat_level_name: ['lon2d', 'lat2d', 'hyai', 'hybi']
surface_pressure_name: ['PS']
air_temperature_name: ['T']
specific_total_water_name: ['Q']
u_wind_name: ['U']
v_wind_name: ['V']
surface_geopotential_name: ['PHIS']
TOA_net_radiation_flux_name: ['FSNT', 'FLNT']
surface_net_radiation_flux_name: ['FSNS', 'FLNS']
surface_energy_flux_name: ['SHFLX', 'LHFLX',]

loss:
# the main training loss
training_loss: "KCRPS"

# power loss (x), spectral_loss (x)
use_power_loss: False
use_spectral_loss: False

# use latitude weighting
use_latitude_weights: True
latitude_weights: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/LSM_static_variables_ERA5_zhght_onedeg.nc'

# turn-off variable weighting
use_variable_weights: False

predict:
forecasts:
type: "custom" # keep it as "custom"
start_year: 2010 # year of the first initialization (where rollout will start)
start_month: 1 # month of the first initialization
start_day: 1 # day of the first initialization
start_hours: [0] # hour-of-day for each initialization, 0 for 00Z, 12 for 12Z
duration: 1 # number of days to initialize, starting from the (year, mon, day) above
# duration should be divisible by the number of GPUs
# (e.g., duration: 384 for 365-day rollout using 32 GPUs)
days: 10 # forecast lead time as days (1 means 24-hour forecast)

save_forecast: '/glade/derecho/scratch/wchapman/CREDIT/cesm_wxformer/'
save_vars: ['PRECT','CLDTOT','CLDHGH','CLDLOW','CLDMED','TAUX','TAUY','U10','QFLX','FSNS','FLNS','FSNT','FLNT','SHFLX','LHFLX']
forcing_file: '/glade/campaign/cisl/aiml/wchapman/MLWPS/STAGING/f.e21.CREDIT_climate.climaterun.statics_1.0deg.nc'
climate_timesteps: 2190

# turn-off low-pass filter
use_laplace_filter: False

# deprecated
# save_format: "nc"

pbs: #derecho
conda: "/glade/u/home/dkimpara/credit-derecho"
project: "NAML0001"
job_name: "wxformer_1h"
walltime: "12:00:00"
nodes: 2
ncpus: 64
ngpus: 4
mem: '480GB'
queue: 'main'
61 changes: 60 additions & 1 deletion credit/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

def load_loss(loss_type, reduction="mean"):
"""Load a specified loss function by its type.
Helper function of VariableTotalLoss2D
This function returns a loss function based on the specified
`loss_type`. It supports several common loss functions, including
Expand Down Expand Up @@ -44,6 +45,7 @@ def load_loss(loss_type, reduction="mean"):
"logcosh": LogCoshLoss,
"xtanh": XTanhLoss,
"xsigmoid": XSigmoidLoss,
"KCRPS": KCRPSLoss,
}

if loss_type in losses:
Expand Down Expand Up @@ -188,6 +190,61 @@ def forward(self, prediction, target):
return loss


class KCRPSLoss(nn.Module):
"""Adapted from Nvidia Modulus
Estimate the CRPS from a finite ensemble
Computes the local Continuous Ranked Probability Score (CRPS) by using
the kernel version of CRPS. The cost is O(m log m).
Creates a map of CRPS and does not accumulate over lat/lon regions.
Approximates:
.. math::
CRPS(X, y) = E[X - y] - 0.5 E[X-X']
with
.. math::
sum_i=1^m |X_i - y| / m - 1/(2m^2) sum_i,j=1^m |x_i - x_j|
"""

def __init__(self, reduction, biased: bool = False):
super().__init__()
self.biased = biased

def forward(self, target, pred):
"""Forward pass for KCRPS loss
Args:
prediction (torch.Tensor): Predicted tensor.
target (torch.Tensor): Target tensor.
Returns:
torch.Tensor: CRPS loss values at each lat/lon
"""
pred = torch.movedim(pred, 0, -1)
return self._kernel_crps_implementation(pred, target, self.biased)

@torch.jit.script
def _kernel_crps_implementation(pred: torch.Tensor, obs: torch.Tensor, biased: bool) -> torch.Tensor:
"""An O(m log m) implementation of the kernel CRPS formulas"""
skill = torch.abs(pred - obs[..., None]).mean(-1)
pred, _ = torch.sort(pred)

# derivation of fast implementation of spread-portion of CRPS formula when x is sorted
# sum_(i,j=1)^m |x_i - x_j| = sum_(i<j) |x_i -x_j| + sum_(i > j) |x_i - x_j|
# = 2 sum_(i <= j) |x_i -x_j|
# = 2 sum_(i <= j) (x_j - x_i)
# = 2 sum_(i <= j) x_j - 2 sum_(i <= j) x_i
# = 2 sum_(j=1)^m j x_j - 2 sum (m - i + 1) x_i
# = 2 sum_(i=1)^m (2i - m - 1) x_i
m = pred.size(-1)
i = torch.arange(1, m + 1, device=pred.device, dtype=pred.dtype)
denom = m * m if biased else m * (m - 1)
factor = (2 * i - m - 1) / denom
spread = torch.sum(factor * pred, dim=-1)
return skill - spread

class SpectralLoss2D(torch.nn.Module):
"""Spectral Loss in 2D.
Expand Down Expand Up @@ -494,9 +551,11 @@ def __init__(self, conf, validation=False):
self.validation = validation
if self.validation:
self.loss_fn = nn.L1Loss(reduction="none")
elif conf["loss"]["training_loss"] == "KCRPS": # for ensembles, load same loss for train and valid
self.loss_fn = load_loss(self.training_loss, reduction="none")
else:
self.loss_fn = load_loss(self.training_loss, reduction="none")

def forward(self, target, pred):
"""Calculate the total loss for the given target and prediction.
Expand Down
3 changes: 3 additions & 0 deletions credit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,9 @@ def credit_main_parser(
"train_batch_size" in conf["trainer"]
), "Training set batch size ('train_batch_size') is missing from onf['trainer']"

if "ensemble_size" not in conf["trainer"]:
conf["trainer"]["ensemble_size"] = 1

if "load_scaler" not in conf["trainer"]:
conf["trainer"]["load_scaler"] = False

Expand Down
Loading

0 comments on commit 36394b6

Please sign in to comment.