Skip to content

Commit

Permalink
refactor(decomposer): provide kwargs
Browse files Browse the repository at this point in the history
before only specific parameter were allowed.
Now one can set any parameter available in the given solver.
  • Loading branch information
nicrie committed Aug 2, 2023
1 parent 717f5fe commit a75c805
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 67 deletions.
26 changes: 14 additions & 12 deletions tests/models/test_decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

@pytest.fixture
def decomposer():
return Decomposer(n_modes=2, n_iter=3, random_state=42, verbose=False)
return Decomposer(n_modes=2, random_state=42)


@pytest.fixture
def cross_decomposer():
return CrossDecomposer(n_modes=2, n_iter=3, random_state=42, verbose=False)
return CrossDecomposer(n_modes=2, random_state=42)


@pytest.fixture
Expand All @@ -41,17 +41,13 @@ def test_complex_dask_data_array(mock_complex_data_array):


def test_decomposer_init(decomposer):
assert decomposer.params["n_modes"] == 2
assert decomposer.params["n_iter"] == 3
assert decomposer.params["random_state"] == 42
assert decomposer.params["verbose"] == False
assert decomposer.n_modes == 2
assert decomposer.solver_kwargs["random_state"] == 42


def test_cross_decomposer_init(cross_decomposer):
assert cross_decomposer.params["n_modes"] == 2
assert cross_decomposer.params["n_iter"] == 3
assert cross_decomposer.params["random_state"] == 42
assert cross_decomposer.params["verbose"] == False
assert cross_decomposer.n_modes == 2
assert cross_decomposer.solver_kwargs["random_state"] == 42


def test_decomposer_fit(decomposer, mock_data_array):
Expand All @@ -61,7 +57,10 @@ def test_decomposer_fit(decomposer, mock_data_array):
assert "components_" in decomposer.__dict__


def test_decomposer_fit_dask(decomposer, mock_dask_data_array):
def test_decomposer_fit_dask(mock_dask_data_array):
# The Dask SVD solver has no parameter 'random_state' but 'seed' instead,
# so let's create a new decomposer for this case
decomposer = Decomposer(n_modes=2, seed=42)
decomposer.fit(mock_dask_data_array)
assert "scores_" in decomposer.__dict__
assert "singular_values_" in decomposer.__dict__
Expand Down Expand Up @@ -89,7 +88,10 @@ def test_cross_decomposer_fit_complex(cross_decomposer, mock_complex_data_array)
assert "singular_vectors2_" in cross_decomposer.__dict__


def test_cross_decomposer_fit_dask(cross_decomposer, mock_dask_data_array):
def test_cross_decomposer_fit_dask(mock_dask_data_array):
# The Dask SVD solver has no parameter 'random_state' but 'seed' instead,
# so let's create a new decomposer for this case
cross_decomposer = CrossDecomposer(n_modes=2, seed=42)
cross_decomposer.fit(mock_dask_data_array, mock_dask_data_array)
assert "singular_vectors1_" in cross_decomposer.__dict__
assert "singular_values_" in cross_decomposer.__dict__
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def test_eof_transform(dim, mock_data_array):
"""Test projecting new unseen data onto the components (EOFs/eigenvectors)"""

# Create a xarray DataArray with random data
model = EOF(n_modes=5)
model = EOF(n_modes=5, solver_kwargs={"random_state": 0})
model.fit(mock_data_array, dim)
scores = model.scores()

Expand All @@ -293,7 +293,7 @@ def test_eof_transform(dim, mock_data_array):

# Check that the projection's data is the same as the scores
np.testing.assert_allclose(
scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-2
scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3
)


Expand Down
28 changes: 18 additions & 10 deletions tests/models/test_orthogonality.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ def test_mca_components(dim, use_coslat, mock_data_array):
K1 = V1.T @ V1
K2 = V2.T @ V2
assert np.allclose(
K1, np.eye(K1.shape[0]), atol=1e-5
K1, np.eye(K1.shape[0]), rtol=1e-8
), "Left components are not orthogonal"
assert np.allclose(
K2, np.eye(K2.shape[0]), atol=1e-5
K2, np.eye(K2.shape[0]), rtol=1e-8
), "Right components are not orthogonal"


Expand Down Expand Up @@ -321,10 +321,10 @@ def test_rmca_components(dim, use_coslat, power, squared_loadings, mock_data_arr
K1 = V1.conj().T @ V1
K2 = V2.conj().T @ V2
assert np.allclose(
np.diag(K1), np.ones(K1.shape[0]), atol=5.0e-2
np.diag(K1), np.ones(K1.shape[0]), rtol=1e-5
), "Components are not normalized"
assert np.allclose(
np.diag(K2), np.ones(K2.shape[0]), atol=5.0e-2
np.diag(K2), np.ones(K2.shape[0]), rtol=1e-5
), "Components are not normalized"
# Assert that off-diagonals are not zero
assert not np.allclose(K1, np.eye(K1.shape[0])), "Rotated components are orthogonal"
Expand Down Expand Up @@ -398,10 +398,10 @@ def test_crmca_components(dim, use_coslat, power, squared_loadings, mock_data_ar
K1 = V1.conj().T @ V1
K2 = V2.conj().T @ V2
assert np.allclose(
np.diag(K1), np.ones(K1.shape[0]), atol=1.0e-1
np.diag(K1), np.ones(K1.shape[0]), rtol=1e-5
), "Components are not normalized"
assert np.allclose(
np.diag(K2), np.ones(K2.shape[0]), atol=1.0e-1
np.diag(K2), np.ones(K2.shape[0]), rtol=1e-5
), "Components are not normalized"
# Assert that off-diagonals are not zero
assert not np.allclose(K1, np.eye(K1.shape[0])), "Rotated components are orthogonal"
Expand Down Expand Up @@ -498,15 +498,23 @@ def test_ceof_transform(dim, use_coslat, mock_data_array):
)
def test_reof_transform(dim, use_coslat, power, mock_data_array):
"""Transforming the original data results in the model scores"""
model = EOF(n_modes=5, standardize=True, use_coslat=use_coslat)
model = EOF(
n_modes=5,
standardize=True,
use_coslat=use_coslat,
solver_kwargs={"random_state": 0},
)
model.fit(mock_data_array, dim=dim)
rot = EOFRotator(n_modes=5, power=power)
rot.fit(model)
scores = rot.scores()
pseudo_scores = rot.transform(mock_data_array)
assert np.allclose(
scores, pseudo_scores, atol=1e-3
), "Transformed data does not match the scores"
np.testing.assert_allclose(
scores,
pseudo_scores,
rtol=5e-3,
err_msg="Transformed data does not match the scores",
)


# Complex Rotated EOF
Expand Down
8 changes: 7 additions & 1 deletion xeofs/models/_base_cross_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ class _BaseCrossModel(ABC):
"""

def __init__(
self, n_modes=10, standardize=False, use_coslat=False, use_weights=False
self,
n_modes=10,
standardize=False,
use_coslat=False,
use_weights=False,
solver_kwargs={},
):
# Define model parameters
self._params = {
Expand All @@ -37,6 +42,7 @@ def __init__(
"use_coslat": use_coslat,
"use_weights": use_weights,
}
self._solver_kwargs = solver_kwargs

# Define analysis-relevant meta data
self.attrs = {"model": "BaseCrossModel"}
Expand Down
8 changes: 7 additions & 1 deletion xeofs/models/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ class _BaseModel(ABC):
"""

def __init__(
self, n_modes=10, standardize=False, use_coslat=False, use_weights=False
self,
n_modes=10,
standardize=False,
use_coslat=False,
use_weights=False,
solver_kwargs={},
):
# Define model parameters
self._params = {
Expand All @@ -40,6 +45,7 @@ def __init__(
"use_coslat": use_coslat,
"use_weights": use_weights,
}
self._solver_kwargs = solver_kwargs

# Define analysis-relevant meta data
self.attrs = {"model": "BaseModel"}
Expand Down
51 changes: 17 additions & 34 deletions xeofs/models/decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,48 +24,37 @@ class Decomposer:
"""

def __init__(self, n_modes=100, n_iter=5, random_state=None, verbose=False):
self.params = {
"n_modes": n_modes,
"n_iter": n_iter,
"random_state": random_state,
"verbose": verbose,
}
def __init__(self, n_modes=100, **kwargs):
self.n_modes = n_modes
self.solver_kwargs = kwargs

def fit(self, X):
svd_kwargs = {}

is_dask = True if isinstance(X.data, DaskArray) else False
is_complex = True if np.iscomplexobj(X.data) else False

if (not is_complex) and (not is_dask):
svd_kwargs.update(
{
"n_components": self.params["n_modes"],
"random_state": self.params["random_state"],
}
)
self.solver_kwargs.update({"n_components": self.n_modes})

U, s, VT = xr.apply_ufunc(
svd,
X,
kwargs=svd_kwargs,
kwargs=self.solver_kwargs,
input_core_dims=[["sample", "feature"]],
output_core_dims=[["sample", "mode"], ["mode"], ["mode", "feature"]],
)

elif is_complex and (not is_dask):
# Scipy sparse version
svd_kwargs.update(
self.solver_kwargs.update(
{
"k": self.n_modes,
"solver": "lobpcg",
"k": self.params["n_modes"],
}
)
U, s, VT = xr.apply_ufunc(
complex_svd,
X,
kwargs=svd_kwargs,
kwargs=self.solver_kwargs,
input_core_dims=[["sample", "feature"]],
output_core_dims=[["sample", "mode"], ["mode"], ["mode", "feature"]],
)
Expand All @@ -75,11 +64,11 @@ def fit(self, X):
VT = VT[idx_sort, :]

elif (not is_complex) and is_dask:
svd_kwargs.update({"k": self.params["n_modes"]})
self.solver_kwargs.update({"k": self.n_modes})
U, s, VT = xr.apply_ufunc(
dask_svd,
X,
kwargs=svd_kwargs,
kwargs=self.solver_kwargs,
input_core_dims=[["sample", "feature"]],
output_core_dims=[["sample", "mode"], ["mode"], ["mode", "feature"]],
dask="allowed",
Expand Down Expand Up @@ -164,35 +153,29 @@ def fit(self, X1, X2):
is_dask = True if isinstance(cov_matrix.data, DaskArray) else False
is_complex = True if np.iscomplexobj(cov_matrix.data) else False

svd_kwargs = {}
if (not is_complex) and (not is_dask):
svd_kwargs.update(
{
"n_components": self.params["n_modes"],
"random_state": self.params["random_state"],
}
)
self.solver_kwargs.update({"n_components": self.n_modes})

U, s, VT = xr.apply_ufunc(
svd,
cov_matrix,
kwargs=svd_kwargs,
kwargs=self.solver_kwargs,
input_core_dims=[["feature1", "feature2"]],
output_core_dims=[["feature1", "mode"], ["mode"], ["mode", "feature2"]],
)

elif (is_complex) and (not is_dask):
# Scipy sparse version
svd_kwargs.update(
self.solver_kwargs.update(
{
"k": self.n_modes,
"solver": "lobpcg",
"k": self.params["n_modes"],
}
)
U, s, VT = xr.apply_ufunc(
complex_svd,
cov_matrix,
kwargs=svd_kwargs,
kwargs=self.solver_kwargs,
input_core_dims=[["feature1", "feature2"]],
output_core_dims=[["feature1", "mode"], ["mode"], ["mode", "feature2"]],
)
Expand All @@ -202,11 +185,11 @@ def fit(self, X1, X2):
VT = VT[idx_sort, :]

elif (not is_complex) and (is_dask):
svd_kwargs.update({"k": self.params["n_modes"]})
self.solver_kwargs.update({"k": self.n_modes})
U, s, VT = xr.apply_ufunc(
dask_svd,
cov_matrix,
kwargs=svd_kwargs,
kwargs=self.solver_kwargs,
input_core_dims=[["feature1", "feature2"]],
output_core_dims=[["feature1", "mode"], ["mode"], ["mode", "feature2"]],
dask="allowed",
Expand Down
16 changes: 13 additions & 3 deletions xeofs/models/eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class EOF(_BaseModel):
Whether to use cosine of latitude for scaling.
use_weights: bool, default=False
Whether to use weights.
solver_kwargs: dict, default={}
Additional keyword arguments to be passed to the SVD solver.
Examples
--------
Expand All @@ -33,13 +35,19 @@ class EOF(_BaseModel):
"""

def __init__(
self, n_modes=10, standardize=False, use_coslat=False, use_weights=False
self,
n_modes=10,
standardize=False,
use_coslat=False,
use_weights=False,
solver_kwargs={},
):
super().__init__(
n_modes=n_modes,
standardize=standardize,
use_coslat=use_coslat,
use_weights=use_weights,
solver_kwargs=solver_kwargs,
)
self.attrs.update({"model": "EOF analysis"})

Expand All @@ -56,7 +64,7 @@ def fit(self, data: AnyDataObject, dim, weights=None):
# Decompose the data
n_modes = self._params["n_modes"]

decomposer = Decomposer(n_modes=n_modes)
decomposer = Decomposer(n_modes=n_modes, **self._solver_kwargs)
decomposer.fit(input_data)

singular_values = decomposer.singular_values_
Expand Down Expand Up @@ -256,6 +264,8 @@ class ComplexEOF(EOF):
A smaller value (e.g. 0.05) is recommended for
data with high variability, while a larger value (e.g. 0.2) is recommended
for data with low variability. Default is 0.2.
solver_kwargs : dict, optional
Additional keyword arguments to be passed to the SVD solver.
References
----------
Expand Down Expand Up @@ -294,7 +304,7 @@ def fit(self, data: AnyDataObject, dim, weights=None):
# Decompose the complex data
n_modes = self._params["n_modes"]

decomposer = Decomposer(n_modes=n_modes)
decomposer = Decomposer(n_modes=n_modes, **self._solver_kwargs)
decomposer.fit(input_data)

singular_values = decomposer.singular_values_
Expand Down
Loading

0 comments on commit a75c805

Please sign in to comment.