From 610a5b662b55afb705648cd4e3d38c5a04956cb7 Mon Sep 17 00:00:00 2001 From: Hongyang Cheng Date: Wed, 17 Jan 2024 05:59:54 +0100 Subject: [PATCH] Fix selftest errors (#63) * fix selftests * pass path_save_data as a string to the checkpoint --- grainlearning/rnn/train.py | 2 +- tests/unit/test_iterative_bayesian_filter.py | 4 +-- tests/unit/test_sampling.py | 26 +++++++++++--------- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/grainlearning/rnn/train.py b/grainlearning/rnn/train.py index ffe047e..a25f8de 100644 --- a/grainlearning/rnn/train.py +++ b/grainlearning/rnn/train.py @@ -142,7 +142,7 @@ def train_without_wandb(preprocessor: Preprocessor, config = None, model: tf.ker restore_best_weights=True, ) checkpoint = tf.keras.callbacks.ModelCheckpoint( - path_save_data, + str(path_save_data), monitor='val_loss', save_best_only=True, save_weights_only=config['save_weights_only'] diff --git a/tests/unit/test_iterative_bayesian_filter.py b/tests/unit/test_iterative_bayesian_filter.py index de6fa82..80cec29 100644 --- a/tests/unit/test_iterative_bayesian_filter.py +++ b/tests/unit/test_iterative_bayesian_filter.py @@ -130,7 +130,7 @@ def test_run_inference(): #: Create the iterative bayesian filter from a dictionary ibf_cls = IterativeBayesianFilter.from_dict( { - "inference": {"ess_target": 0.5}, + "inference": {"ess_target": 0.5, "scale_cov_with_max": True}, "sampling": {"max_num_components": 5}, "initial_sampling": 'halton', } @@ -151,7 +151,7 @@ def test_run_inference(): #: Assert that the inference runs correctly if a proposal density is provided ibf_cls = IterativeBayesianFilter.from_dict( { - "inference": {"ess_target": 0.5}, + "inference": {"ess_target": 0.5, "scale_cov_with_max": True}, "sampling": {"max_num_components": 5}, "initial_sampling": 'halton', "proposal": np.array([0.5, 0.2, 0.3]) diff --git a/tests/unit/test_sampling.py b/tests/unit/test_sampling.py index b0368c0..909c98e 100644 --- a/tests/unit/test_sampling.py +++ b/tests/unit/test_sampling.py @@ -108,7 +108,8 @@ def test_regenerate_params(): ) #: Initialize a Gaussian Mixture Model object - gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="full", random_state=100, expand_factor=2) + gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="tied", init_params="k-means++", + random_state=100, expand_factor=2) #: Generate the initial parameter samples system_cls.param_data = generate_params_qmc(system_cls, system_cls.num_samples) @@ -127,10 +128,10 @@ def test_regenerate_params(): new_params, np.array( [ - [2.50061801e+06, 1.92539376e-01], - [5.40525882e+06, 3.10537276e-01], - [3.46458943e+06, 3.76945456e-01], - [5.66254261e+06, 2.67135646e-01] + [3.50816173e+06, 2.49928167e-01], + [4.87521971e+06, 3.09866017e-01], + [2.57761751e+06, 3.26589170e-01], + [5.39169338e+06, 2.52627014e-01] ] ), rtol=0.001, @@ -140,7 +141,8 @@ def test_regenerate_params(): proposal = np.array([0.0, 0.5, 0.4, 0.1]) #: Initialize again a Gaussian Mixture Model object with slice sampling activated - gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="full", expand_factor=10, slice_sampling=True) + gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="tied", init_params="k-means++", + expand_factor=10, slice_sampling=True) #: Generate the initial parameter samples system_cls.param_data = generate_params_qmc(system_cls, system_cls.num_samples) @@ -155,8 +157,8 @@ def test_regenerate_params(): [ [5.50000000e+06, 2.93333333e-01], [3.25000000e+06, 3.96666667e-01], - [6.90625000e+06, 2.47407407e-01], - [4.23437500e+06, 3.35432099e-01] + [5.07812500e+06, 3.20123457e-01], + [4.55078125e+06, 3.30329218e-01] ] ), rtol=0.001, @@ -198,9 +200,11 @@ def test_draw_samples_within_bounds(): new_params, np.array( [ - [6.442085e+06, 3.272525e-01], - [5.109979e+06, 2.736598e-01], - [4.379674e+06, 3.722799e-01] + [3.72038687e+06, 2.50311001e-01], + [5.31282617e+06, 3.30634332e-01], + [6.04092450e+06, 2.61269286e-01], + [5.21447138e+06, 2.84730873e-01], + [3.60796701e+06, 2.81671695e-01] ] ), rtol=0.001,