
Commit

ensemble capability for batch_size>1
dkimpara committed Dec 19, 2024
1 parent 36394b6 commit 1a6c1b2
Showing 4 changed files with 55 additions and 13 deletions.
18 changes: 10 additions & 8 deletions config/test_cesm_ensemble.yml
@@ -281,12 +281,14 @@ predict:
     # save_format: "nc"
 
 pbs: #derecho
-    conda: "/glade/u/home/dkimpara/credit-derecho"
-    project: "NAML0001"
-    job_name: "wxformer_1h"
-    walltime: "12:00:00"
-    nodes: 2
-    ncpus: 64
-    ngpus: 4
-    mem: '480GB'
-    queue: 'main'
+    conda: "/glade/u/home/dkimpara/credit"
+    job_name: "test_cesm_ensemble"
+    walltime: "00:15:00"
+    nodes: 1
+    ncpus: 8
+    ngpus: 1
+    mem: '32GB'
+    gpu_type: 'v100'
+    project: 'NAML0001'
+    queue: 'casper'
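
For context, this pbs block describes the resources a test job requests from the PBS scheduler on Casper. A minimal sketch of how such a block maps onto a PBS select statement, assuming PyYAML and the key layout in the diff above (illustrative only, not CREDIT's actual launcher code):

import yaml

with open("config/test_cesm_ensemble.yml") as f:
    pbs = yaml.safe_load(f)["pbs"]

# e.g. "select=1:ncpus=8:ngpus=1:mem=32GB" submitted to the casper queue
select = f"select={pbs['nodes']}:ncpus={pbs['ncpus']}:ngpus={pbs['ngpus']}:mem={pbs['mem']}"
print(select, pbs["queue"], pbs["walltime"])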
16 changes: 12 additions & 4 deletions credit/loss.py
@@ -211,9 +211,18 @@ class KCRPSLoss(nn.Module):
     def __init__(self, reduction, biased: bool = False):
         super().__init__()
         self.biased = biased
+        self.batched_forward = torch.vmap(self.single_sample_forward)
 
     def forward(self, target, pred):
-        """Forward pass for KCRPS loss
+        # integer division; the .view() below errors out if there is a remainder
+        ensemble_size = pred.shape[0] // target.shape[0] + pred.shape[0] % target.shape[0]
+        pred = pred.view(target.shape[0], ensemble_size, *target.shape[1:])  # b, ensemble, c, t, lat, lon
+        # apply single_sample_forward to each batch element via vmap
+        target = target.unsqueeze(1)
+        return self.batched_forward(target, pred).squeeze(1)
+
+    def single_sample_forward(self, target, pred):
+        """Forward pass for KCRPS loss for a single sample
         Args:
             prediction (torch.Tensor): Predicted tensor.
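
The new forward assumes ensemble members are stacked along dim 0 in batch-major order ([b0e0, b0e1, ..., b1e0, b1e1, ...]), so a single .view() regroups them and torch.vmap maps the per-sample loss over the batch. A self-contained toy sketch of that pattern (the shapes and the per-sample function are stand-ins, not the repo's code):

import torch

b, e = 2, 3
target = torch.randn(b, 4, 1, 8, 16)                # b, c, t, lat, lon
pred = torch.randn(b * e, 4, 1, 8, 16)              # members stacked on dim 0

def per_sample(tgt, prd):                           # stand-in for single_sample_forward
    return (prd - tgt).abs().mean(0, keepdim=True)  # reduce over the ensemble dim

batched = torch.vmap(per_sample)                    # maps over dim 0 of both args

pred = pred.view(b, e, *target.shape[1:])           # b, ensemble, c, t, lat, lon
out = batched(target.unsqueeze(1), pred).squeeze(1)
assert out.shape == target.shape                    # one loss field per batch element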
@@ -225,8 +234,7 @@ def forward(self, target, pred):
         pred = torch.movedim(pred, 0, -1)
         return self._kernel_crps_implementation(pred, target, self.biased)
 
-    @torch.jit.script
-    def _kernel_crps_implementation(pred: torch.Tensor, obs: torch.Tensor, biased: bool) -> torch.Tensor:
+    def _kernel_crps_implementation(self, pred: torch.Tensor, obs: torch.Tensor, biased: bool) -> torch.Tensor:
         """An O(m log m) implementation of the kernel CRPS formulas"""
         skill = torch.abs(pred - obs[..., None]).mean(-1)
         pred, _ = torch.sort(pred)
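
For reference, the kernel form of CRPS being implemented is CRPS(F, y) = E|X - y| - 0.5 E|X - X'|, and for a sorted m-member ensemble the O(m^2) pairwise term sum_{i<j} |x_i - x_j| collapses to sum_i (2i - m - 1) x_(i), which is what makes the O(m log m) version possible. A self-contained sketch of both forms (my own illustration, not the repo's code; "biased" here means dividing the spread term by m^2 rather than m(m-1)):

import torch

def kernel_crps_naive(pred, obs, biased=True):
    # pred: (..., m) ensemble members; obs: (...) observations
    m = pred.shape[-1]
    skill = (pred - obs[..., None]).abs().mean(-1)  # E|X - y|
    pair = (pred[..., :, None] - pred[..., None, :]).abs().sum((-2, -1))  # O(m^2)
    denom = m * m if biased else m * (m - 1)
    return skill - 0.5 * pair / denom

def kernel_crps_sorted(pred, obs, biased=True):
    m = pred.shape[-1]
    skill = (pred - obs[..., None]).abs().mean(-1)
    x, _ = torch.sort(pred, dim=-1)                 # O(m log m)
    i = torch.arange(1, m + 1, dtype=x.dtype, device=x.device)
    pair = ((2 * i - m - 1) * x).sum(-1)            # = sum_{i<j} |x_i - x_j|
    denom = m * m if biased else m * (m - 1)
    return skill - pair / denom

pred, obs = torch.randn(7, 5), torch.randn(7)
assert torch.allclose(kernel_crps_naive(pred, obs), kernel_crps_sorted(pred, obs), atol=1e-6)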
8 changes: 7 additions & 1 deletion credit/metrics.py
@@ -10,7 +10,7 @@ def __init__(self, conf, predict_mode=False):
         atmos_vars = conf["data"]["variables"]
         surface_vars = conf["data"]["surface_variables"]
         diag_vars = conf["data"]["diagnostic_variables"]
-
+
         levels = (
             conf["model"]["levels"]
             if "levels" in conf["model"]
@@ -28,11 +28,17 @@ def __init__(self, conf, predict_mode=False):
         # DO NOT apply these weights during metrics computations, only on the loss during training
         self.w_var = None
 
+        self.ensemble_size = conf["trainer"]["ensemble_size"]
+
     def __call__(self, pred, y, clim=None, transform=None, forecast_datetime=0):
         if transform is not None:
             pred = transform(pred)
             y = transform(y)
 
+        # collapse to the ensemble mean; if ensemble_size == 1 this is a no-op
+        pred = pred.view(y.shape[0], self.ensemble_size, *y.shape[1:])  # b, ensemble, c, t, lat, lon
+        pred = pred.mean(dim=1)
+
         # Get latitude and variable weights
         w_lat = (
             self.w_lat.to(dtype=pred.dtype, device=pred.device)
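
Unlike the loss, the metrics path collapses the stacked ensemble to its mean before scoring, so the weighted metrics downstream see a plain (b, c, t, lat, lon) tensor again; it relies on the same batch-major member ordering as KCRPSLoss.forward. A quick shape check of the idea (toy shapes, my own illustration):

import torch

batch, ens = 2, 4
y = torch.randn(batch, 3, 1, 8, 16)
pred = torch.randn(batch * ens, 3, 1, 8, 16)                 # members stacked on dim 0

pred = pred.view(y.shape[0], ens, *y.shape[1:]).mean(dim=1)  # b, ens, ... -> b, ...
assert pred.shape == y.shape  # with ens == 1 the view/mean round-trip changes nothing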
26 changes: 26 additions & 0 deletions tests/test_loss.py
@@ -0,0 +1,26 @@
+import os
+import yaml
+
+import torch
+
+from credit.loss import KCRPSLoss
+from credit.parser import credit_main_parser
+
+TEST_FILE_DIR = "/".join(os.path.abspath(__file__).split("/")[:-1])
+CONFIG_FILE_DIR = os.path.join(
+    "/".join(os.path.abspath(__file__).split("/")[:-2]), "config"
+)
+
+
+def test_KCRPS():
+    loss_fn = KCRPSLoss("none")
+    batch_size = 2
+    ensemble_size = 5
+
+    target = torch.randn(batch_size, 10, 1, 40, 50)
+    pred = torch.randn(batch_size * ensemble_size, 10, 1, 40, 50)
+
+    loss = loss_fn(target, pred)
+    assert not torch.isnan(loss).any()

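The new test only checks that the loss contains no NaNs. A natural follow-up assertion, assuming (as the squeeze in forward suggests) that reduction="none" returns one CRPS value per grid point, would also pin down the output shape and exercise the batch_size=1 path (a hypothetical extension, not part of this commit):

def test_KCRPS_shape():
    loss_fn = KCRPSLoss("none")
    target = torch.randn(1, 10, 1, 40, 50)
    pred = torch.randn(5, 10, 1, 40, 50)  # five members, batch of one
    loss = loss_fn(target, pred)
    assert loss.shape == target.shape     # per-grid-point CRPS, batch preserved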
