Force Beta distribution to a concave-downward shape
dnerini committed Mar 22, 2023
1 parent 80dc29a · commit a1d2ea9
Showing 2 changed files with 15 additions and 5 deletions.
mlpp_lib/probabilistic_layers.py: 3 additions & 2 deletions
@@ -84,8 +84,9 @@ def new(params, event_shape=(), validate_args=False, name=None):
             axis=0,
         )
         alpha, beta, shift, scale = tf.split(params, 4, axis=-1)
-        alpha = tf.math.softplus(tf.reshape(alpha, output_shape))
-        beta = tf.math.softplus(tf.reshape(beta, output_shape))
+        # alpha > 2 and beta > 2 produce a concave downward Beta
+        alpha = 2.0 + tf.math.softplus(tf.reshape(alpha, output_shape))
+        beta = 2.0 + tf.math.softplus(tf.reshape(beta, output_shape))
         shift = tf.math.softplus(tf.reshape(shift, output_shape))
         scale = tf.math.softplus(tf.reshape(scale, output_shape))
         betad = tfd.Beta(alpha, beta, validate_args=validate_args)
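
For context, a minimal standalone sketch of what the shifted softplus buys (assuming TensorFlow and TensorFlow Probability are available; the raw values below are invented for illustration). softplus maps any real number to a positive one, so 2.0 + softplus(x) is always strictly greater than 2, and a Beta distribution with both concentrations above 2 has an interior mode at (alpha - 1) / (alpha + beta - 2) with vanishing density at both boundaries:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    # Unconstrained network outputs can be any real number.
    raw = tf.constant([-5.0, 0.0, 5.0])

    # The commit's constraint: softplus makes the value positive,
    # and the +2.0 shift keeps alpha and beta strictly above 2.
    alpha = 2.0 + tf.math.softplus(raw)
    beta = 2.0 + tf.math.softplus(raw)
    print(alpha.numpy())  # approx. [2.007 2.693 7.007], all > 2

    # With alpha, beta > 2 the mode lies strictly inside (0, 1),
    # so density no longer piles up at the interval endpoints.
    dist = tfd.Beta(alpha, beta)
    print(dist.mode().numpy())  # [0.5 0.5 0.5] here, since alpha == beta

Without the shift, softplus alone only guarantees alpha, beta > 0, which still admits U-shaped densities (alpha, beta < 1) whose mass concentrates at 0 and 1.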
tests/test_train.py: 12 additions & 3 deletions
@@ -1,6 +1,7 @@
 import json
 
 import cloudpickle
+import numpy as np
 import pytest
 from keras.engine.functional import Functional
 import xarray as xr
@@ -36,7 +37,7 @@
         "model": {
             "fully_connected_network": {
                 "hidden_layers": [10],
-                "probabilistic_layer": "IndependentNormal",
+                "probabilistic_layer": "IndependentBeta",
             }
         },
         "loss": {"WeightedCRPSEnergy": {"threshold": 0, "n_samples": 5}},
@@ -80,7 +81,8 @@ def write_datasets_zarr(tmp_path, features_dataset, targets_dataset):
 @pytest.mark.usefixtures("write_datasets_zarr")
 @pytest.mark.parametrize("cfg", RUNS)
 def test_train_fromfile(tmp_path, cfg):
-    cfg.update({"epochs": 3})
+    num_epochs = 3
+    cfg.update({"epochs": num_epochs})
 
     splitter_options = ValidDataSplitterOptions(time="lists", station="lists")
     splitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
@@ -96,6 +98,9 @@ def test_train_fromfile(tmp_path, cfg):
     assert isinstance(results[2], Standardizer)  # standardizer
     assert isinstance(results[3], dict)  # history
 
+    assert all([np.isfinite(v).all() for v in results[3].values()])
+    assert all([len(v) == num_epochs for v in results[3].values()])
+
     # try to pickle the custom objects
     cloudpickle.dumps(results[1])
 
@@ -105,7 +110,8 @@
 
 @pytest.mark.parametrize("cfg", RUNS)
 def test_train_fromds(features_dataset, targets_dataset, cfg):
-    cfg.update({"epochs": 3})
+    num_epochs = 3
+    cfg.update({"epochs": num_epochs})
 
     splitter_options = ValidDataSplitterOptions(time="lists", station="lists")
     splitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
@@ -124,6 +130,9 @@ def test_train_fromds(features_dataset, targets_dataset, cfg):
     assert isinstance(results[2], Standardizer)  # standardizer
     assert isinstance(results[3], dict)  # history
 
+    assert all([np.isfinite(v).all() for v in results[3].values()])
+    assert all([len(v) == num_epochs for v in results[3].values()])
+
     # try to pickle the custom objects
     cloudpickle.dumps(results[1])
 
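
For reference, a minimal sketch of what the two new assertions guard against, using an invented history dict shaped like the one Keras returns (one metric value per epoch; the numbers are illustrative, not from the repository):

    import numpy as np

    num_epochs = 3

    # Invented stand-in for results[3], the training history dict.
    history = {
        "loss": [0.9, 0.5, 0.3],
        "val_loss": [1.0, 0.6, 0.4],
    }

    # Every recorded value must be finite; a NaN or inf here typically
    # means the loss diverged during training.
    assert all(np.isfinite(v).all() for v in history.values())

    # Every metric must record exactly one value per epoch, so a
    # silently truncated run fails the test.
    assert all(len(v) == num_epochs for v in history.values())

With the new IndependentBeta run in RUNS, the finiteness check is the one most likely to fire: invalid Beta parameters would surface as NaN losses in the recorded history.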
