Skip to content

Commit

Permalink
Fix selftest errors (#63)
Browse files Browse the repository at this point in the history
* Fix selftests

* Pass path_save_data as a string to the ModelCheckpoint callback
  • Loading branch information
chyalexcheng authored Jan 17, 2024
1 parent 9d342ae commit 610a5b6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion grainlearning/rnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def train_without_wandb(preprocessor: Preprocessor, config = None, model: tf.ker
restore_best_weights=True,
)
checkpoint = tf.keras.callbacks.ModelCheckpoint(
path_save_data,
str(path_save_data),
monitor='val_loss',
save_best_only=True,
save_weights_only=config['save_weights_only']
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_iterative_bayesian_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_run_inference():
#: Create the iterative bayesian filter from a dictionary
ibf_cls = IterativeBayesianFilter.from_dict(
{
"inference": {"ess_target": 0.5},
"inference": {"ess_target": 0.5, "scale_cov_with_max": True},
"sampling": {"max_num_components": 5},
"initial_sampling": 'halton',
}
Expand All @@ -151,7 +151,7 @@ def test_run_inference():
#: Assert that the inference runs correctly if a proposal density is provided
ibf_cls = IterativeBayesianFilter.from_dict(
{
"inference": {"ess_target": 0.5},
"inference": {"ess_target": 0.5, "scale_cov_with_max": True},
"sampling": {"max_num_components": 5},
"initial_sampling": 'halton',
"proposal": np.array([0.5, 0.2, 0.3])
Expand Down
26 changes: 15 additions & 11 deletions tests/unit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def test_regenerate_params():
)

#: Initialize a Gaussian Mixture Model object
gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="full", random_state=100, expand_factor=2)
gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="tied", init_params="k-means++",
random_state=100, expand_factor=2)

#: Generate the initial parameter samples
system_cls.param_data = generate_params_qmc(system_cls, system_cls.num_samples)
Expand All @@ -127,10 +128,10 @@ def test_regenerate_params():
new_params,
np.array(
[
[2.50061801e+06, 1.92539376e-01],
[5.40525882e+06, 3.10537276e-01],
[3.46458943e+06, 3.76945456e-01],
[5.66254261e+06, 2.67135646e-01]
[3.50816173e+06, 2.49928167e-01],
[4.87521971e+06, 3.09866017e-01],
[2.57761751e+06, 3.26589170e-01],
[5.39169338e+06, 2.52627014e-01]
]
),
rtol=0.001,
Expand All @@ -140,7 +141,8 @@ def test_regenerate_params():
proposal = np.array([0.0, 0.5, 0.4, 0.1])

#: Initialize again a Gaussian Mixture Model object with slice sampling activated
gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="full", expand_factor=10, slice_sampling=True)
gmm_cls = GaussianMixtureModel(max_num_components=2, covariance_type="tied", init_params="k-means++",
expand_factor=10, slice_sampling=True)

#: Generate the initial parameter samples
system_cls.param_data = generate_params_qmc(system_cls, system_cls.num_samples)
Expand All @@ -155,8 +157,8 @@ def test_regenerate_params():
[
[5.50000000e+06, 2.93333333e-01],
[3.25000000e+06, 3.96666667e-01],
[6.90625000e+06, 2.47407407e-01],
[4.23437500e+06, 3.35432099e-01]
[5.07812500e+06, 3.20123457e-01],
[4.55078125e+06, 3.30329218e-01]
]
),
rtol=0.001,
Expand Down Expand Up @@ -198,9 +200,11 @@ def test_draw_samples_within_bounds():
new_params,
np.array(
[
[6.442085e+06, 3.272525e-01],
[5.109979e+06, 2.736598e-01],
[4.379674e+06, 3.722799e-01]
[3.72038687e+06, 2.50311001e-01],
[5.31282617e+06, 3.30634332e-01],
[6.04092450e+06, 2.61269286e-01],
[5.21447138e+06, 2.84730873e-01],
[3.60796701e+06, 2.81671695e-01]
]
),
rtol=0.001,
Expand Down

0 comments on commit 610a5b6

Please sign in to comment.