Migration to torch distributions and scoringrules integration #70

Draft: wants to merge 13 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion mlpp_lib/__init__.py
@@ -2,4 +2,4 @@

import os

os.environ["TF_USE_LEGACY_KERAS"] = "1"
os.environ["KERAS_BACKEND"] = "torch"
4 changes: 2 additions & 2 deletions mlpp_lib/callbacks.py
@@ -2,7 +2,7 @@

import numpy as np
import properscoring as ps
from tensorflow.keras import callbacks
from keras import callbacks


class EnsembleMetrics(callbacks.Callback):
@@ -16,7 +16,7 @@ def add_validation_data(self, validation_data) -> None:

def on_epoch_end(self, epoch, logs):
"""Compute a range of probabilistic scores at the end of each epoch."""
y_pred = self.model(self.X_val).sample(self.n_samples)
y_pred = self.model(self.X_val).sample((self.n_samples,))

y_pred = y_pred.numpy()[:, :, 0].T
y_val = np.squeeze(self.y_val)
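For context on the `.sample(...)` change above: `torch.distributions` expects the sample shape as a tuple (a `torch.Size`), whereas the previous TensorFlow Probability API accepted a bare integer. A small, self-contained illustration:

```python
import torch
from torch.distributions import Normal

dist = Normal(torch.tensor(0.0), torch.tensor(1.0))

# torch.distributions takes a sample *shape*, not a sample count:
samples = dist.sample((5,))
print(samples.shape)  # torch.Size([5])
```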
178 changes: 178 additions & 0 deletions mlpp_lib/custom_distributions.py
@@ -0,0 +1,178 @@
import torch
from torch.distributions import Distribution, constraints, Normal

class TruncatedNormalDistribution(Distribution):
"""
Implementation of a truncated normal distribution in [a, b] with
differentiable sampling.

Source: The Truncated Normal Distribution, John Burkardt 2023
"""

def __init__(self, mu_bar: torch.Tensor, sigma_bar: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
"""Initialize a truncated normal distribution on the interval [a, b].

Args:
mu_bar (torch.Tensor): The mean of the underlying Normal. It is not the true mean.
sigma_bar (torch.Tensor): The std of the underlying Normal. It is not the true std.
a (torch.Tensor): The left boundary
b (torch.Tensor): The right boundary
"""
self._n = Normal(mu_bar, sigma_bar)
self.mu_bar = mu_bar
self.sigma_bar = sigma_bar
super().__init__()

self.a = a
self.b = b


def icdf(self, p):
# Inverse CDF: map p from [0, 1] onto the untruncated normal's CDF range over [a, b], then invert.
p_ = self._n.cdf(self.a) + p * (self._n.cdf(self.b) - self._n.cdf(self.a))
return self._n.icdf(p_)

def mean(self) -> torch.Tensor:
"""
Returns:
torch.Tensor: Returns the true mean of the distribution.
"""
alpha = (self.a - self.mu_bar) / self.sigma_bar
beta = (self.b - self.mu_bar) / self.sigma_bar

sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))

scale = (torch.exp(sn.log_prob(beta)) - torch.exp(sn.log_prob(alpha)))/(sn.cdf(beta) - sn.cdf(alpha))

return self.mu_bar - self.sigma_bar * scale

def variance(self) -> torch.Tensor:
"""
Returns:
torch.Tensor: Returns the true variance of the distribution.
"""
alpha = (self.a - self.mu_bar) / self.sigma_bar
beta = (self.b - self.mu_bar) / self.sigma_bar

sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))

pdf_a = torch.exp(sn.log_prob(alpha))
pdf_b = torch.exp(sn.log_prob(beta))
CDF_a = sn.cdf(alpha)
CDF_b = sn.cdf(beta)

return self.sigma_bar**2 * (1.0 - (beta*pdf_b - alpha*pdf_a)/(CDF_b - CDF_a) - ((pdf_b - pdf_a)/(CDF_b - CDF_a))**2)


def moment(self, k):
# Source: A Recursive Formula for the Moments of a Truncated Univariate Normal Distribution (Eric Orjebin)
if k == -1:
return torch.zeros_like(self.mu_bar)
if k == 0:
return torch.ones_like(self.mu_bar)

alpha = (self.a - self.mu_bar) / self.sigma_bar
beta = (self.b - self.mu_bar) / self.sigma_bar
sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))

scale = ((self.b**(k-1) * torch.exp(sn.log_prob(beta)) - self.a**(k-1) * torch.exp(sn.log_prob(alpha))) / (sn.cdf(beta) - sn.cdf(alpha)))

return (k-1)* self.sigma_bar ** 2 * self.moment(k-2) + self.mu_bar * self.moment(k-1) - self.sigma_bar * scale

def sample(self, shape):
return self.rsample(shape)

def rsample(self, shape):
# get some random probability [0,1]
p = torch.distributions.Uniform(0,1).sample(shape)
# apply the inverse cdf on p
return self.icdf(p)

@property
def arg_constraints(self):
return {
'mu_bar': constraints.real,
'sigma_bar': constraints.positive,
}

@property
def has_rsample(self):
return True
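
A minimal usage sketch of TruncatedNormalDistribution (illustration only, not part of the diff): because sampling goes through the inverse CDF, gradients flow back to the parameters. The tensor values below are arbitrary.

```python
import torch

mu = torch.tensor([0.5], requires_grad=True)
sigma = torch.tensor([1.0], requires_grad=True)
a, b = torch.tensor([0.0]), torch.tensor([2.0])

tn = TruncatedNormalDistribution(mu_bar=mu, sigma_bar=sigma, a=a, b=b)

# Reparameterized draw: differentiable w.r.t. mu and sigma.
x = tn.rsample((1000,))
x.mean().backward()
print(mu.grad, sigma.grad)  # both populated, not None
```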

class CensoredNormalDistribution(Distribution):
r"""Implements a censored Normal distribution.
Values of the underlying normal that fall outside the interval [a, b]
are mapped to the boundaries a and b, respectively.

.. math::
Y =
\begin{cases}
a & \text{if } X \leq a \\
X & \text{if } a < X < b \\
b & \text{if } X \geq b
\end{cases}
\qquad X \sim \mathcal{N}(\bar{\mu}, \bar{\sigma})


"""

def __init__(self, mu_bar: torch.Tensor, sigma_bar: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
"""
Args:
mu_bar (torch.Tensor): The mean of the latent normal distribution.
sigma_bar (torch.Tensor): The std of the latent normal distribution.
a (torch.Tensor): The lower bound of the distribution.
b (torch.Tensor): The upper bound of the distribution.
"""


self._n = Normal(mu_bar, sigma_bar)
self.mu_bar = mu_bar
self.sigma_bar = sigma_bar
super().__init__()

self.a = a
self.b = b


def mean(self):
alpha = (self.a - self.mu_bar) / self.sigma_bar
beta = (self.b - self.mu_bar) / self.sigma_bar

sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))
E_z = TruncatedNormalDistribution(self.mu_bar, self.sigma_bar, self.a, self.b).mean()
return (
self.b * (1-sn.cdf(beta))
+ self.a * sn.cdf(alpha)
+ E_z * (sn.cdf(beta) - sn.cdf(alpha))
)


def variance(self):
# Variance := Var(Y) = E(Y^2) - E(Y)^2
alpha = (self.a - self.mu_bar) / self.sigma_bar
beta = (self.b - self.mu_bar) / self.sigma_bar
sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.sigma_bar))
tn = TruncatedNormalDistribution(mu_bar=self.mu_bar, sigma_bar=self.sigma_bar, a=self.a, b=self.b)

# Law of total expectation:
# E(Y^2) = E(Y^2|X>b)*P(X>b) + E(Y^2|X<a)*P(X<a) + E(Y^2 | a<X<b)*P(a<X<b)
# = b^2 * P(X>b) + a^2 * P(X<a) + E(Z^2~TruncNormal(mu,sigma,a,b)) * P(a<X<b)

E_z2 = tn.moment(2) # E(Z^2)
E_y2 = self.b**2 * (1-sn.cdf(beta)) + self.a**2 * sn.cdf(alpha) + E_z2 * (sn.cdf(beta) - sn.cdf(alpha)) # E(Y^2)

return E_y2 - self.mean()**2 # Var(Y)=E(Y^2)-E(Y)^2


def sample(self, shape):
# Note: clipping collapses gradients to zero outside [a, b] and this path is not reparameterized.
# Do not use it for Monte Carlo gradient-based optimization.
s = self._n.sample(shape)
return torch.clip(s, min=self.a, max=self.b)

@property
def arg_constraints(self):
return {
'mu_bar': constraints.real,
'sigma_bar': constraints.positive, # Enforce positive scale
}
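
A quick sanity-check sketch for CensoredNormalDistribution (again illustration only): the closed-form mean and variance should agree with Monte Carlo estimates from the clipped sampler. Tensor values and the sample count are arbitrary.

```python
import torch

mu, sigma = torch.tensor([0.0]), torch.tensor([1.0])
a, b = torch.tensor([-1.0]), torch.tensor([1.0])

cn = CensoredNormalDistribution(mu_bar=mu, sigma_bar=sigma, a=a, b=b)

samples = cn.sample((200_000,))           # clipped draws, shape (200000, 1)
print(cn.mean(), samples.mean(dim=0))     # analytic vs. empirical mean
print(cn.variance(), samples.var(dim=0))  # analytic vs. empirical variance
```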
40 changes: 20 additions & 20 deletions mlpp_lib/datasets.py
@@ -6,9 +6,10 @@
import dask.array as da
import numpy as np
import pandas as pd
import tensorflow as tf
import xarray as xr
from typing_extensions import Self
from keras import KerasTensor
import keras

from .model_selection import DataSplitter
from .normalizers import DataTransformer
@@ -556,7 +557,7 @@ def __repr__(self) -> str:
return out


class DataLoader(tf.keras.utils.Sequence):
class DataLoader(keras.utils.Sequence):
"""A dataloader for mlpp.

Parameters
@@ -598,16 +599,16 @@ def __init__(
self.shuffle = shuffle
self.block_size = block_size
self.num_samples = len(self.dataset.x)
self.num_batches = int(np.ceil(self.num_samples / batch_size))
self._indices = tf.range(self.num_samples)
self.num_batches_ = int(np.ceil(self.num_samples / batch_size))
self._indices = keras.ops.arange(self.num_samples)
self._seed = 0
self._reset()

def __len__(self) -> int:
return self.num_batches
return self.num_batches_

def __getitem__(self, index) -> tuple[tf.Tensor, ...]:
if index >= self.num_batches:
def __getitem__(self, index) -> tuple[KerasTensor, ...]:
if index >= self.num_batches_:
self._reset()
raise IndexError
start = index * self.batch_size
@@ -626,38 +627,37 @@ def _shuffle_indices(self) -> None:
each block stay in their original order, but the blocks themselves are shuffled.
"""
if self.block_size == 1:
self._indices = tf.random.shuffle(self._indices, seed=self._seed)
self._indices = keras.random.shuffle(self._indices, seed=self._seed)
return
num_blocks = self._indices.shape[0] // self.block_size
reshaped_indices = tf.reshape(
reshaped_indices = keras.ops.reshape(
self._indices[: num_blocks * self.block_size], (num_blocks, self.block_size)
)
shuffled_blocks = tf.random.shuffle(reshaped_indices, seed=self._seed)
shuffled_indices = tf.reshape(shuffled_blocks, [-1])
shuffled_blocks = keras.random.shuffle(reshaped_indices, seed=self._seed)
shuffled_indices = keras.ops.reshape(shuffled_blocks, [-1])
# Append any remaining elements if the number of indices isn't a multiple of the block size
if shuffled_indices.shape[0] % self.block_size:
remainder = self._indices[num_blocks * self.block_size :]
shuffled_indices = tf.concat([shuffled_indices, remainder], axis=0)
shuffled_indices = keras.ops.concatenate([shuffled_indices, remainder], axis=0)
self._indices = shuffled_indices

def _reset(self) -> None:
"""Reset iterator and shuffles data if needed"""
self.index = 0
if self.shuffle:
self._shuffle_indices()
self.dataset.x = tf.gather(self.dataset.x, self._indices)
self.dataset.y = tf.gather(self.dataset.y, self._indices)
self.dataset.x = keras.ops.take(self.dataset.x, self._indices, axis=0)
self.dataset.y = keras.ops.take(self.dataset.y, self._indices, axis=0)
if self.dataset.w is not None:
self.dataset.w = tf.gather(self.dataset.w, self._indices)
self.dataset.w = keras.ops.take(self.dataset.w, self._indices, axis=0)
self._seed += 1

def _to_device(self, device) -> None:
"""Transfer data to a device"""
with tf.device(device):
self.dataset.x = tf.constant(self.dataset.x)
self.dataset.y = tf.constant(self.dataset.y)
with keras.device(device):
self.dataset.x = keras.ops.array(self.dataset.x)
self.dataset.y = keras.ops.array(self.dataset.y)
if self.dataset.w is not None:
self.dataset.w = tf.constant(self.dataset.w)
self.dataset.w = keras.ops.array(self.dataset.w)


class DataFilter:
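The block-wise shuffle in `_shuffle_indices` keeps samples within a block in their original order and only permutes the blocks, appending any leftover indices unshuffled. A standalone sketch of the same idea with `keras.ops` (index count and block size are arbitrary):

```python
import keras

indices = keras.ops.arange(10)
block_size = 3
num_blocks = 10 // block_size  # 3 full blocks, one leftover index

blocks = keras.ops.reshape(indices[: num_blocks * block_size], (num_blocks, block_size))
shuffled_blocks = keras.random.shuffle(blocks, seed=0)  # permute whole blocks
shuffled = keras.ops.reshape(shuffled_blocks, [-1])
# Re-attach the remainder that did not fit into a full block.
shuffled = keras.ops.concatenate([shuffled, indices[num_blocks * block_size :]], axis=0)
print(shuffled)  # e.g. [3 4 5 0 1 2 6 7 8 9]
```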
3 changes: 3 additions & 0 deletions mlpp_lib/exceptions.py
@@ -0,0 +1,3 @@
class MissingReparameterizationError(Exception):
"""Raised when a sampling function without 'rsample' is used in a context requiring reparameterization."""
pass
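
One plausible use of this exception (hypothetical helper, not taken from this diff): guard code that requires reparameterized sampling by checking `has_rsample` before calling `rsample`.

```python
import torch
from mlpp_lib.exceptions import MissingReparameterizationError

def sample_with_gradients(dist: torch.distributions.Distribution, shape: tuple) -> torch.Tensor:
    # Hypothetical guard: only distributions exposing rsample support
    # reparameterized (gradient-carrying) sampling.
    if not getattr(dist, "has_rsample", False):
        raise MissingReparameterizationError(
            f"{type(dist).__name__} does not implement rsample()"
        )
    return dist.rsample(shape)
```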