Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cross): correct saving and loading of CPCCA and Rotator models #225

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/models/cross/test_cpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,38 @@ def test_save_load(tmp_path, engine, alpha):
assert np.allclose(XYr_o[1], XYr_l[1])


@pytest.mark.parametrize("engine", ["netcdf4", "zarr"])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_save_load_with_data(tmp_path, engine, alpha):
"""Test save/load methods in CPCCA class, ensuring that we can
roundtrip the model and get the same results for SCF."""
X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)

original = CPCCA(alpha=alpha)
original.fit(X, Y, "sample")

# Save the CPCCA model
original.save(tmp_path / "cpcca", engine=engine, save_data=True)

# Check that the CPCCA model has been saved
assert (tmp_path / "cpcca").exists()

# Recreate the model from saved file
loaded = CPCCA.load(tmp_path / "cpcca", engine=engine)

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
for key in original.data:
assert loaded.data[key].equals(original.data[key])

# Test that the recreated model can compute the SCF
assert np.allclose(
original.squared_covariance_fraction(), loaded.squared_covariance_fraction()
)


def test_serialize_deserialize_dataarray(mock_data_array):
"""Test roundtrip serialization when the model is fit on a DataArray."""
model = CPCCA()
Expand Down
97 changes: 97 additions & 0 deletions tests/models/cross/test_hilbert_cpcca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import dask.array as da
import numpy as np
import pytest
import xarray as xr

from xeofs.cross import HilbertCPCCA


def generate_random_data(shape, lazy=False, seed=142):
rng = np.random.default_rng(seed)
if lazy:
return xr.DataArray(
da.random.random(shape, chunks=(5, 5)),
dims=["sample", "feature"],
coords={"sample": np.arange(shape[0]), "feature": np.arange(shape[1])},
)
else:
return xr.DataArray(
rng.random(shape),
dims=["sample", "feature"],
coords={"sample": np.arange(shape[0]), "feature": np.arange(shape[1])},
)


def generate_well_conditioned_data(lazy=False):
rng = np.random.default_rng(142)
t = np.linspace(0, 50, 200)
std = 0.1
x1 = np.sin(t)[:, None] + rng.normal(0, std, size=(200, 2))
x2 = np.sin(t)[:, None] + rng.normal(0, std, size=(200, 3))
x1[:, 1] = x1[:, 1] ** 2
x2[:, 1] = x2[:, 1] ** 3
x2[:, 2] = abs(x2[:, 2]) ** (0.5)
coords_time = np.arange(len(t))
coords_fx = [1, 2]
coords_fy = [1, 2, 3]
X = xr.DataArray(
x1,
dims=["sample", "feature"],
coords={"sample": coords_time, "feature": coords_fx},
)
Y = xr.DataArray(
x2,
dims=["sample", "feature"],
coords={"sample": coords_time, "feature": coords_fy},
)
if lazy:
X = X.chunk({"sample": 5, "feature": -1})
Y = Y.chunk({"sample": 5, "feature": -1})
return X, Y
else:
return X, Y


# Currently, netCDF4 does not support complex numbers, so skip this test
@pytest.mark.parametrize("engine", ["zarr"])
@pytest.mark.parametrize("alpha", [0.0, 0.5, 1.0])
def test_save_load_with_data(tmp_path, engine, alpha):
"""Test save/load methods in CPCCA class, ensuring that we can
roundtrip the model and get the same results."""
X = generate_random_data((200, 10), seed=123)
Y = generate_random_data((200, 20), seed=321)

original = HilbertCPCCA(alpha=alpha)
original.fit(X, Y, "sample")

# Save the CPCCA model
original.save(tmp_path / "cpcca", engine=engine, save_data=True)

# Check that the CPCCA model has been saved
assert (tmp_path / "cpcca").exists()

# Recreate the model from saved file
loaded = HilbertCPCCA.load(tmp_path / "cpcca", engine=engine)

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
for key in original.data:
assert loaded.data[key].equals(original.data[key])

# Test that the recreated model can compute the SCF
assert np.allclose(
original.squared_covariance_fraction(), loaded.squared_covariance_fraction()
)

# Test that the recreated model can compute the components amplitude
A1_original, A2_original = original.components_amplitude()
A1_loaded, A2_loaded = loaded.components_amplitude()
assert np.allclose(A1_original, A1_loaded)
assert np.allclose(A2_original, A2_loaded)

# Test that the recreated model can compute the components phase
P1_original, P2_original = original.components_phase()
P1_loaded, P2_loaded = loaded.components_phase()
assert np.allclose(P1_original, P1_loaded)
assert np.allclose(P2_original, P2_loaded)
49 changes: 49 additions & 0 deletions tests/models/cross/test_hilbert_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,52 @@ def test_scores_phase(mca_model, mock_data_array, dim):
mca_rotator = HilbertMCARotator(n_modes=2)
mca_rotator.fit(mca_model)
amps1, amps2 = mca_rotator.scores_phase()


@pytest.mark.parametrize(
"dim",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
],
)
# Currently, netCDF4 does not support complex numbers, so skip this test
@pytest.mark.parametrize("engine", ["zarr"])
def test_save_load_with_data(tmp_path, engine, mca_model):
"""Test save/load methods in HilbertMCARotator class, ensuring that we can
roundtrip the model and get the same results."""
original = HilbertMCARotator(n_modes=2)
original.fit(mca_model)

# Save the HilbertMCARotator model
original.save(tmp_path / "mca", engine=engine, save_data=True)

# Check that the HilbertMCARotator model has been saved
assert (tmp_path / "mca").exists()

# Recreate the model from saved file
loaded = HilbertMCARotator.load(tmp_path / "mca", engine=engine)

# Check that the params and DataContainer objects match
assert original.get_params() == loaded.get_params()
assert all([key in loaded.data for key in original.data])
for key in original.data:
assert loaded.data[key].equals(original.data[key])

# Test that the recreated model can compute the SCF
assert np.allclose(
original.squared_covariance_fraction(), loaded.squared_covariance_fraction()
)

# Test that the recreated model can compute the components amplitude
A1_original, A2_original = original.components_amplitude()
A1_loaded, A2_loaded = loaded.components_amplitude()
assert np.allclose(A1_original, A1_loaded)
assert np.allclose(A2_original, A2_loaded)

# Test that the recreated model can compute the components phase
P1_original, P2_original = original.components_phase()
P1_loaded, P2_loaded = loaded.components_phase()
assert np.allclose(P1_original, P1_loaded)
assert np.allclose(P2_original, P2_loaded)
2 changes: 2 additions & 0 deletions xeofs/cross/base_model_cross_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,8 @@ def get_serialization_attrs(self) -> dict:
preprocessor2=self.preprocessor2,
whitener1=self.whitener1,
whitener2=self.whitener2,
sample_name=self.sample_name,
feature_name=self.feature_name,
)

def _augment_data(self, X: DataArray, Y: DataArray) -> tuple[DataArray, DataArray]:
Expand Down
8 changes: 4 additions & 4 deletions xeofs/cross/cpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,8 +1218,8 @@ def components_phase(self, normalized=True) -> tuple[DataObject, DataObject]:
Px = self.whitener1.inverse_transform_components(Px)
Py = self.whitener2.inverse_transform_components(Py)

Px = xr.apply_ufunc(np.angle, Px, keep_attrs=True)
Py = xr.apply_ufunc(np.angle, Py, keep_attrs=True)
Px = xr.apply_ufunc(np.angle, Px, keep_attrs=True, dask="allowed")
Py = xr.apply_ufunc(np.angle, Py, keep_attrs=True, dask="allowed")

Px.name = "components_phase_X"
Py.name = "components_phase_Y"
Expand Down Expand Up @@ -1288,8 +1288,8 @@ def scores_phase(self, normalized=False) -> tuple[DataArray, DataArray]:
Rx = self.whitener1.inverse_transform_scores(Rx)
Ry = self.whitener2.inverse_transform_scores(Ry)

Rx = xr.apply_ufunc(np.angle, Rx, keep_attrs=True)
Ry = xr.apply_ufunc(np.angle, Ry, keep_attrs=True)
Rx = xr.apply_ufunc(np.angle, Rx, keep_attrs=True, dask="allowed")
Ry = xr.apply_ufunc(np.angle, Ry, keep_attrs=True, dask="allowed")

Rx.name = "scores_phase_X"
Ry.name = "scores_phase_Y"
Expand Down
2 changes: 2 additions & 0 deletions xeofs/cross/cpcca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def get_serialization_attrs(self) -> dict:
whitener2=self.whitener2,
model=self.model,
sorted=self.sorted,
sample_name=self.sample_name,
feature_name=self.feature_name,
)

def _fit_algorithm(self, model) -> Self:
Expand Down
Loading