Skip to content

Commit

Permalink
fix(cross): correct saving and loading of CPCCA and Rotator models (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie authored Sep 12, 2024
1 parent 5ed9753 commit 7cba749
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 4 deletions.
32 changes: 32 additions & 0 deletions tests/models/cross/test_cpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,38 @@ def test_save_load(tmp_path, engine, alpha):
assert np.allclose(XYr_o[1], XYr_l[1])


@pytest.mark.parametrize("engine", ["netcdf4", "zarr"])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_save_load_with_data(tmp_path, engine, alpha):
    """Round-trip a fitted CPCCA model, including its data, through save/load.

    The model is fitted on data from ``generate_random_data``, written below
    ``tmp_path`` with ``save_data=True`` and read back. Parameters, stored
    data and the squared covariance fraction (SCF) must agree between the
    original and the reloaded instance.
    """
    model_path = tmp_path / "cpcca"

    X = generate_random_data((200, 10), seed=123)
    Y = generate_random_data((200, 20), seed=321)

    original = CPCCA(alpha=alpha)
    original.fit(X, Y, "sample")

    # Persist the fitted model together with its data and make sure
    # something was actually written.
    original.save(model_path, engine=engine, save_data=True)
    assert model_path.exists()

    # Rebuild the model from what was stored on disk.
    loaded = CPCCA.load(model_path, engine=engine)

    # Parameters must survive the round trip unchanged.
    assert original.get_params() == loaded.get_params()

    # Every stored data entry must be present again and equal.
    missing_keys = [key for key in original.data if key not in loaded.data]
    assert not missing_keys
    for key in original.data:
        assert loaded.data[key].equals(original.data[key])

    # The reloaded model must still be able to compute the SCF.
    scf_original = original.squared_covariance_fraction()
    scf_loaded = loaded.squared_covariance_fraction()
    assert np.allclose(scf_original, scf_loaded)


def test_serialize_deserialize_dataarray(mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
model = CPCCA()
Expand Down
97 changes: 97 additions & 0 deletions tests/models/cross/test_hilbert_cpcca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import dask.array as da
import numpy as np
import pytest
import xarray as xr

from xeofs.cross import HilbertCPCCA


def generate_random_data(shape, lazy=False, seed=142):
    """Build a 2D ``(sample, feature)`` DataArray filled with random values.

    Parameters
    ----------
    shape : tuple of int
        ``(n_samples, n_features)`` of the returned array.
    lazy : bool, default False
        If True, the values are a dask array chunked as ``(5, 5)``;
        otherwise they are an in-memory NumPy array.
    seed : int, default 142
        Seed for the random generator. It is honoured in both modes, so
        repeated calls with the same arguments yield the same values.
        The eager and lazy paths use different generators, so their values
        are not expected to match each other for the same seed.

    Returns
    -------
    xr.DataArray
        Array with dims ``("sample", "feature")`` and integer range
        coordinates along both dimensions.
    """
    coords = {"sample": np.arange(shape[0]), "feature": np.arange(shape[1])}
    if lazy:
        # Seed the dask generator explicitly: the module-level
        # ``da.random.random`` draws from global state and would ignore
        # ``seed``, making lazy test data irreproducible.
        values = da.random.RandomState(seed).random_sample(shape, chunks=(5, 5))
    else:
        values = np.random.default_rng(seed).random(shape)
    return xr.DataArray(values, dims=["sample", "feature"], coords=coords)


def generate_well_conditioned_data(lazy=False):
    """Create a deterministic pair of noisy, sine-driven test DataArrays.

    Both arrays share a sine signal over 200 samples plus Gaussian noise
    (std 0.1, fixed seed 142). Some columns are transformed non-linearly
    (square, cube, square root of the absolute value).

    Parameters
    ----------
    lazy : bool, default False
        If True, both arrays are chunked with 5 samples per chunk and a
        single chunk along ``feature``.

    Returns
    -------
    tuple of xr.DataArray
        ``(X, Y)`` with shapes ``(200, 2)`` and ``(200, 3)`` and dims
        ``("sample", "feature")``.
    """
    n_samples = 200
    noise_std = 0.1
    rng = np.random.default_rng(142)

    time = np.linspace(0, 50, n_samples)
    signal = np.sin(time)[:, None]

    # Draw order matters for reproducibility: X noise first, then Y noise.
    x_values = signal + rng.normal(0, noise_std, size=(n_samples, 2))
    y_values = signal + rng.normal(0, noise_std, size=(n_samples, 3))

    # Non-linear distortions of individual feature columns.
    x_values[:, 1] = x_values[:, 1] ** 2
    y_values[:, 1] = y_values[:, 1] ** 3
    y_values[:, 2] = abs(y_values[:, 2]) ** (0.5)

    sample_coords = np.arange(len(time))
    dims = ["sample", "feature"]
    X = xr.DataArray(
        x_values,
        dims=dims,
        coords={"sample": sample_coords, "feature": [1, 2]},
    )
    Y = xr.DataArray(
        y_values,
        dims=dims,
        coords={"sample": sample_coords, "feature": [1, 2, 3]},
    )

    if lazy:
        chunking = {"sample": 5, "feature": -1}
        X = X.chunk(chunking)
        Y = Y.chunk(chunking)
    return X, Y


# Currently, netCDF4 does not support complex numbers, so skip this test
@pytest.mark.parametrize("engine", ["zarr"])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_save_load_with_data(tmp_path, engine, alpha):
    """Round-trip a fitted HilbertCPCCA model, including data, via save/load.

    After reloading, the parameters, the stored data, the squared
    covariance fraction (SCF) and the component amplitudes and phases must
    match those of the original model.
    """
    model_path = tmp_path / "cpcca"

    X = generate_random_data((200, 10), seed=123)
    Y = generate_random_data((200, 20), seed=321)

    original = HilbertCPCCA(alpha=alpha)
    original.fit(X, Y, "sample")

    # Persist the fitted HilbertCPCCA model together with its data and
    # make sure something was actually written.
    original.save(model_path, engine=engine, save_data=True)
    assert model_path.exists()

    # Rebuild the model from what was stored on disk.
    loaded = HilbertCPCCA.load(model_path, engine=engine)

    # Parameters must survive the round trip unchanged.
    assert original.get_params() == loaded.get_params()

    # Every stored data entry must be present again and equal.
    missing_keys = [key for key in original.data if key not in loaded.data]
    assert not missing_keys
    for key in original.data:
        assert loaded.data[key].equals(original.data[key])

    # The reloaded model must still be able to compute the SCF.
    scf_original = original.squared_covariance_fraction()
    scf_loaded = loaded.squared_covariance_fraction()
    assert np.allclose(scf_original, scf_loaded)

    # Component amplitudes must agree for both fields.
    amp_x_original, amp_y_original = original.components_amplitude()
    amp_x_loaded, amp_y_loaded = loaded.components_amplitude()
    assert np.allclose(amp_x_original, amp_x_loaded)
    assert np.allclose(amp_y_original, amp_y_loaded)

    # Component phases must agree for both fields.
    phase_x_original, phase_y_original = original.components_phase()
    phase_x_loaded, phase_y_loaded = loaded.components_phase()
    assert np.allclose(phase_x_original, phase_x_loaded)
    assert np.allclose(phase_y_original, phase_y_loaded)
49 changes: 49 additions & 0 deletions tests/models/cross/test_hilbert_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,52 @@ def test_scores_phase(mca_model, mock_data_array, dim):
mca_rotator = HilbertMCARotator(n_modes=2)
mca_rotator.fit(mca_model)
amps1, amps2 = mca_rotator.scores_phase()


@pytest.mark.parametrize("dim", [("time",), ("lat", "lon"), ("lon", "lat")])
# Only zarr is exercised here: netCDF4 currently does not support complex
# numbers, so that engine is left out.
@pytest.mark.parametrize("engine", ["zarr"])
def test_save_load_with_data(tmp_path, engine, mca_model):
    """Round-trip a fitted HilbertMCARotator, including data, via save/load.

    After reloading, the parameters, the stored data, the squared
    covariance fraction (SCF) and the component amplitudes and phases must
    match those of the original rotator.

    NOTE(review): ``dim`` is not an argument of this test; presumably it is
    consumed by the ``mca_model`` fixture -- confirm against the fixture
    definition.
    """
    model_path = tmp_path / "mca"

    original = HilbertMCARotator(n_modes=2)
    original.fit(mca_model)

    # Persist the fitted rotator together with its data and make sure
    # something was actually written.
    original.save(model_path, engine=engine, save_data=True)
    assert model_path.exists()

    # Rebuild the rotator from what was stored on disk.
    loaded = HilbertMCARotator.load(model_path, engine=engine)

    # Parameters must survive the round trip unchanged.
    assert original.get_params() == loaded.get_params()

    # Every stored data entry must be present again and equal.
    missing_keys = [key for key in original.data if key not in loaded.data]
    assert not missing_keys
    for key in original.data:
        assert loaded.data[key].equals(original.data[key])

    # The reloaded rotator must still be able to compute the SCF.
    scf_original = original.squared_covariance_fraction()
    scf_loaded = loaded.squared_covariance_fraction()
    assert np.allclose(scf_original, scf_loaded)

    # Component amplitudes must agree for both fields.
    amp_x_original, amp_y_original = original.components_amplitude()
    amp_x_loaded, amp_y_loaded = loaded.components_amplitude()
    assert np.allclose(amp_x_original, amp_x_loaded)
    assert np.allclose(amp_y_original, amp_y_loaded)

    # Component phases must agree for both fields.
    phase_x_original, phase_y_original = original.components_phase()
    phase_x_loaded, phase_y_loaded = loaded.components_phase()
    assert np.allclose(phase_x_original, phase_x_loaded)
    assert np.allclose(phase_y_original, phase_y_loaded)
2 changes: 2 additions & 0 deletions xeofs/cross/base_model_cross_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ def get_serialization_attrs(self) -> dict:
preprocessor2=self.preprocessor2,
whitener1=self.whitener1,
whitener2=self.whitener2,
sample_name=self.sample_name,
feature_name=self.feature_name,
)

def _augment_data(self, X: DataArray, Y: DataArray) -> tuple[DataArray, DataArray]:
Expand Down
8 changes: 4 additions & 4 deletions xeofs/cross/cpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,8 +1218,8 @@ def components_phase(self, normalized=True) -> tuple[DataObject, DataObject]:
Px = self.whitener1.inverse_transform_components(Px)
Py = self.whitener2.inverse_transform_components(Py)

Px = xr.apply_ufunc(np.angle, Px, keep_attrs=True)
Py = xr.apply_ufunc(np.angle, Py, keep_attrs=True)
Px = xr.apply_ufunc(np.angle, Px, keep_attrs=True, dask="allowed")
Py = xr.apply_ufunc(np.angle, Py, keep_attrs=True, dask="allowed")

Px.name = "components_phase_X"
Py.name = "components_phase_Y"
Expand Down Expand Up @@ -1288,8 +1288,8 @@ def scores_phase(self, normalized=False) -> tuple[DataArray, DataArray]:
Rx = self.whitener1.inverse_transform_scores(Rx)
Ry = self.whitener2.inverse_transform_scores(Ry)

Rx = xr.apply_ufunc(np.angle, Rx, keep_attrs=True)
Ry = xr.apply_ufunc(np.angle, Ry, keep_attrs=True)
Rx = xr.apply_ufunc(np.angle, Rx, keep_attrs=True, dask="allowed")
Ry = xr.apply_ufunc(np.angle, Ry, keep_attrs=True, dask="allowed")

Rx.name = "scores_phase_X"
Ry.name = "scores_phase_Y"
Expand Down
2 changes: 2 additions & 0 deletions xeofs/cross/cpcca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def get_serialization_attrs(self) -> dict:
whitener2=self.whitener2,
model=self.model,
sorted=self.sorted,
sample_name=self.sample_name,
feature_name=self.feature_name,
)

def _fit_algorithm(self, model) -> Self:
Expand Down

0 comments on commit 7cba749

Please sign in to comment.