diff --git a/tests/models/test_decomposer.py b/tests/models/test_decomposer.py index b269428..6ff5305 100644 --- a/tests/models/test_decomposer.py +++ b/tests/models/test_decomposer.py @@ -10,12 +10,12 @@ @pytest.fixture def decomposer(): - return Decomposer(n_modes=2, n_iter=3, random_state=42, verbose=False) + return Decomposer(n_modes=2, random_state=42) @pytest.fixture def cross_decomposer(): - return CrossDecomposer(n_modes=2, n_iter=3, random_state=42, verbose=False) + return CrossDecomposer(n_modes=2, random_state=42) @pytest.fixture @@ -41,17 +41,13 @@ def test_complex_dask_data_array(mock_complex_data_array): def test_decomposer_init(decomposer): - assert decomposer.params["n_modes"] == 2 - assert decomposer.params["n_iter"] == 3 - assert decomposer.params["random_state"] == 42 - assert decomposer.params["verbose"] == False + assert decomposer.n_modes == 2 + assert decomposer.solver_kwargs["random_state"] == 42 def test_cross_decomposer_init(cross_decomposer): - assert cross_decomposer.params["n_modes"] == 2 - assert cross_decomposer.params["n_iter"] == 3 - assert cross_decomposer.params["random_state"] == 42 - assert cross_decomposer.params["verbose"] == False + assert cross_decomposer.n_modes == 2 + assert cross_decomposer.solver_kwargs["random_state"] == 42 def test_decomposer_fit(decomposer, mock_data_array): @@ -61,7 +57,10 @@ def test_decomposer_fit(decomposer, mock_data_array): assert "components_" in decomposer.__dict__ -def test_decomposer_fit_dask(decomposer, mock_dask_data_array): +def test_decomposer_fit_dask(mock_dask_data_array): + # The Dask SVD solver has no parameter 'random_state' but 'seed' instead, + # so let's create a new decomposer for this case + decomposer = Decomposer(n_modes=2, seed=42) decomposer.fit(mock_dask_data_array) assert "scores_" in decomposer.__dict__ assert "singular_values_" in decomposer.__dict__ @@ -89,7 +88,10 @@ def test_cross_decomposer_fit_complex(cross_decomposer, mock_complex_data_array) assert "singular_vectors2_" in cross_decomposer.__dict__ -def test_cross_decomposer_fit_dask(cross_decomposer, mock_dask_data_array): +def test_cross_decomposer_fit_dask(mock_dask_data_array): + # The Dask SVD solver has no parameter 'random_state' but 'seed' instead, + # so let's create a new decomposer for this case + cross_decomposer = CrossDecomposer(n_modes=2, seed=42) cross_decomposer.fit(mock_dask_data_array, mock_dask_data_array) assert "singular_vectors1_" in cross_decomposer.__dict__ assert "singular_values_" in cross_decomposer.__dict__ diff --git a/tests/models/test_eof.py b/tests/models/test_eof.py index 1e0bfaf..1af8215 100644 --- a/tests/models/test_eof.py +++ b/tests/models/test_eof.py @@ -273,7 +273,7 @@ def test_eof_transform(dim, mock_data_array): """Test projecting new unseen data onto the components (EOFs/eigenvectors)""" # Create a xarray DataArray with random data - model = EOF(n_modes=5) + model = EOF(n_modes=5, solver_kwargs={"random_state": 0}) model.fit(mock_data_array, dim) scores = model.scores() @@ -293,7 +293,7 @@ def test_eof_transform(dim, mock_data_array): # Check that the projection's data is the same as the scores np.testing.assert_allclose( - scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-2 + scores.sel(mode=slice(1, 3)), projections.sel(mode=slice(1, 3)), rtol=1e-3 ) diff --git a/tests/models/test_orthogonality.py b/tests/models/test_orthogonality.py index ca26094..98425e2 100644 --- a/tests/models/test_orthogonality.py +++ b/tests/models/test_orthogonality.py @@ -214,10 +214,10 @@ def test_mca_components(dim, use_coslat, mock_data_array): K1 = V1.T @ V1 K2 = V2.T @ V2 assert np.allclose( - K1, np.eye(K1.shape[0]), atol=1e-5 + K1, np.eye(K1.shape[0]), rtol=1e-8 ), "Left components are not orthogonal" assert np.allclose( - K2, np.eye(K2.shape[0]), atol=1e-5 + K2, np.eye(K2.shape[0]), rtol=1e-8 ), "Right components are not orthogonal" @@ -321,10 +321,10 @@ def test_rmca_components(dim, use_coslat, power, squared_loadings, mock_data_arr K1 = V1.conj().T @ V1 K2 = V2.conj().T @ V2 assert np.allclose( - np.diag(K1), np.ones(K1.shape[0]), atol=5.0e-2 + np.diag(K1), np.ones(K1.shape[0]), rtol=1e-5 ), "Components are not normalized" assert np.allclose( - np.diag(K2), np.ones(K2.shape[0]), atol=5.0e-2 + np.diag(K2), np.ones(K2.shape[0]), rtol=1e-5 ), "Components are not normalized" # Assert that off-diagonals are not zero assert not np.allclose(K1, np.eye(K1.shape[0])), "Rotated components are orthogonal" @@ -398,10 +398,10 @@ def test_crmca_components(dim, use_coslat, power, squared_loadings, mock_data_ar K1 = V1.conj().T @ V1 K2 = V2.conj().T @ V2 assert np.allclose( - np.diag(K1), np.ones(K1.shape[0]), atol=1.0e-1 + np.diag(K1), np.ones(K1.shape[0]), rtol=1e-5 ), "Components are not normalized" assert np.allclose( - np.diag(K2), np.ones(K2.shape[0]), atol=1.0e-1 + np.diag(K2), np.ones(K2.shape[0]), rtol=1e-5 ), "Components are not normalized" # Assert that off-diagonals are not zero assert not np.allclose(K1, np.eye(K1.shape[0])), "Rotated components are orthogonal" @@ -498,15 +498,23 @@ def test_ceof_transform(dim, use_coslat, mock_data_array): ) def test_reof_transform(dim, use_coslat, power, mock_data_array): """Transforming the original data results in the model scores""" - model = EOF(n_modes=5, standardize=True, use_coslat=use_coslat) + model = EOF( + n_modes=5, + standardize=True, + use_coslat=use_coslat, + solver_kwargs={"random_state": 0}, + ) model.fit(mock_data_array, dim=dim) rot = EOFRotator(n_modes=5, power=power) rot.fit(model) scores = rot.scores() pseudo_scores = rot.transform(mock_data_array) - assert np.allclose( - scores, pseudo_scores, atol=1e-3 - ), "Transformed data does not match the scores" + np.testing.assert_allclose( + scores, + pseudo_scores, + rtol=5e-3, + err_msg="Transformed data does not match the scores", + ) # Complex Rotated EOF diff --git a/xeofs/models/_base_cross_model.py b/xeofs/models/_base_cross_model.py index dbaac05..b108790 100644 --- a/xeofs/models/_base_cross_model.py +++ b/xeofs/models/_base_cross_model.py @@ -28,7 +28,12 @@ class _BaseCrossModel(ABC): """ def __init__( - self, n_modes=10, standardize=False, use_coslat=False, use_weights=False + self, + n_modes=10, + standardize=False, + use_coslat=False, + use_weights=False, + solver_kwargs={}, ): # Define model parameters self._params = { @@ -37,6 +42,7 @@ def __init__( "use_coslat": use_coslat, "use_weights": use_weights, } + self._solver_kwargs = solver_kwargs # Define analysis-relevant meta data self.attrs = {"model": "BaseCrossModel"} diff --git a/xeofs/models/_base_model.py b/xeofs/models/_base_model.py index 6a6d918..ee64a48 100644 --- a/xeofs/models/_base_model.py +++ b/xeofs/models/_base_model.py @@ -31,7 +31,12 @@ class _BaseModel(ABC): """ def __init__( - self, n_modes=10, standardize=False, use_coslat=False, use_weights=False + self, + n_modes=10, + standardize=False, + use_coslat=False, + use_weights=False, + solver_kwargs={}, ): # Define model parameters self._params = { @@ -40,6 +45,7 @@ def __init__( "use_coslat": use_coslat, "use_weights": use_weights, } + self._solver_kwargs = solver_kwargs # Define analysis-relevant meta data self.attrs = {"model": "BaseModel"} diff --git a/xeofs/models/decomposer.py b/xeofs/models/decomposer.py index e5fca65..a6f6713 100644 --- a/xeofs/models/decomposer.py +++ b/xeofs/models/decomposer.py @@ -24,48 +24,37 @@ class Decomposer: """ - def __init__(self, n_modes=100, n_iter=5, random_state=None, verbose=False): - self.params = { - "n_modes": n_modes, - "n_iter": n_iter, - "random_state": random_state, - "verbose": verbose, - } + def __init__(self, n_modes=100, **kwargs): + self.n_modes = n_modes + self.solver_kwargs = kwargs def fit(self, X): - svd_kwargs = {} - is_dask = True if isinstance(X.data, DaskArray) else False is_complex = True if np.iscomplexobj(X.data) else False if (not is_complex) and (not is_dask): - svd_kwargs.update( - { - "n_components": self.params["n_modes"], - "random_state": self.params["random_state"], - } - ) + self.solver_kwargs.update({"n_components": self.n_modes}) U, s, VT = xr.apply_ufunc( svd, X, - kwargs=svd_kwargs, + kwargs=self.solver_kwargs, input_core_dims=[["sample", "feature"]], output_core_dims=[["sample", "mode"], ["mode"], ["mode", "feature"]], ) elif is_complex and (not is_dask): # Scipy sparse version - svd_kwargs.update( + self.solver_kwargs.update( { + "k": self.n_modes, "solver": "lobpcg", - "k": self.params["n_modes"], } ) U, s, VT = xr.apply_ufunc( complex_svd, X, - kwargs=svd_kwargs, + kwargs=self.solver_kwargs, input_core_dims=[["sample", "feature"]], output_core_dims=[["sample", "mode"], ["mode"], ["mode", "feature"]], ) @@ -75,11 +64,11 @@ def fit(self, X): VT = VT[idx_sort, :] elif (not is_complex) and is_dask: - svd_kwargs.update({"k": self.params["n_modes"]}) + self.solver_kwargs.update({"k": self.n_modes}) U, s, VT = xr.apply_ufunc( dask_svd, X, - kwargs=svd_kwargs, + kwargs=self.solver_kwargs, input_core_dims=[["sample", "feature"]], output_core_dims=[["sample", "mode"], ["mode"], ["mode", "feature"]], dask="allowed", @@ -164,35 +153,29 @@ def fit(self, X1, X2): is_dask = True if isinstance(cov_matrix.data, DaskArray) else False is_complex = True if np.iscomplexobj(cov_matrix.data) else False - svd_kwargs = {} if (not is_complex) and (not is_dask): - svd_kwargs.update( - { - "n_components": self.params["n_modes"], - "random_state": self.params["random_state"], - } - ) + self.solver_kwargs.update({"n_components": self.n_modes}) U, s, VT = xr.apply_ufunc( svd, cov_matrix, - kwargs=svd_kwargs, + kwargs=self.solver_kwargs, input_core_dims=[["feature1", "feature2"]], output_core_dims=[["feature1", "mode"], ["mode"], ["mode", "feature2"]], ) elif (is_complex) and (not is_dask): # Scipy sparse version - svd_kwargs.update( + self.solver_kwargs.update( { + "k": self.n_modes, "solver": "lobpcg", - "k": self.params["n_modes"], } ) U, s, VT = xr.apply_ufunc( complex_svd, cov_matrix, - kwargs=svd_kwargs, + kwargs=self.solver_kwargs, input_core_dims=[["feature1", "feature2"]], output_core_dims=[["feature1", "mode"], ["mode"], ["mode", "feature2"]], ) @@ -202,11 +185,11 @@ def fit(self, X1, X2): VT = VT[idx_sort, :] elif (not is_complex) and (is_dask): - svd_kwargs.update({"k": self.params["n_modes"]}) + self.solver_kwargs.update({"k": self.n_modes}) U, s, VT = xr.apply_ufunc( dask_svd, cov_matrix, - kwargs=svd_kwargs, + kwargs=self.solver_kwargs, input_core_dims=[["feature1", "feature2"]], output_core_dims=[["feature1", "mode"], ["mode"], ["mode", "feature2"]], dask="allowed", diff --git a/xeofs/models/eof.py b/xeofs/models/eof.py index bf44007..97c4a6e 100644 --- a/xeofs/models/eof.py +++ b/xeofs/models/eof.py @@ -23,6 +23,8 @@ class EOF(_BaseModel): Whether to use cosine of latitude for scaling. use_weights: bool, default=False Whether to use weights. + solver_kwargs: dict, default={} + Additional keyword arguments to be passed to the SVD solver. Examples -------- @@ -33,13 +35,19 @@ class EOF(_BaseModel): """ def __init__( - self, n_modes=10, standardize=False, use_coslat=False, use_weights=False + self, + n_modes=10, + standardize=False, + use_coslat=False, + use_weights=False, + solver_kwargs={}, ): super().__init__( n_modes=n_modes, standardize=standardize, use_coslat=use_coslat, use_weights=use_weights, + solver_kwargs=solver_kwargs, ) self.attrs.update({"model": "EOF analysis"}) @@ -56,7 +64,7 @@ def fit(self, data: AnyDataObject, dim, weights=None): # Decompose the data n_modes = self._params["n_modes"] - decomposer = Decomposer(n_modes=n_modes) + decomposer = Decomposer(n_modes=n_modes, **self._solver_kwargs) decomposer.fit(input_data) singular_values = decomposer.singular_values_ @@ -256,6 +264,8 @@ class ComplexEOF(EOF): A smaller value (e.g. 0.05) is recommended for data with high variability, while a larger value (e.g. 0.2) is recommended for data with low variability. Default is 0.2. + solver_kwargs : dict, optional + Additional keyword arguments to be passed to the SVD solver. References ---------- @@ -294,7 +304,7 @@ def fit(self, data: AnyDataObject, dim, weights=None): # Decompose the complex data n_modes = self._params["n_modes"] - decomposer = Decomposer(n_modes=n_modes) + decomposer = Decomposer(n_modes=n_modes, **self._solver_kwargs) decomposer.fit(input_data) singular_values = decomposer.singular_values_ diff --git a/xeofs/models/mca.py b/xeofs/models/mca.py index 068e15c..50942c6 100644 --- a/xeofs/models/mca.py +++ b/xeofs/models/mca.py @@ -30,6 +30,8 @@ class MCA(_BaseCrossModel): Whether to use cosine of latitude for scaling. use_weights: bool, default=False Whether to use additional weights. + solver_kwargs: dict, default={} + Additional keyword arguments passed to the SVD solver. Notes ----- @@ -49,8 +51,8 @@ class MCA(_BaseCrossModel): """ - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, solver_kwargs={}, **kwargs): + super().__init__(solver_kwargs=solver_kwargs, **kwargs) self.attrs.update({"model": "MCA"}) # Initialize the DataContainer to store the results @@ -71,7 +73,9 @@ def fit( data2, dim, weights2 ) - decomposer = CrossDecomposer(n_modes=self._params["n_modes"]) + decomposer = CrossDecomposer( + n_modes=self._params["n_modes"], **self._solver_kwargs + ) decomposer.fit(data1_processed, data2_processed) # Note: @@ -466,6 +470,8 @@ class ComplexMCA(MCA): A smaller value (e.g. 0.05) is recommended for data with high variability, while a larger value (e.g. 0.2) is recommended for data with low variability. Default is 0.2. + solver_kwargs: dict, default={} + Additional keyword arguments passed to the SVD solver. Notes ----- @@ -540,7 +546,9 @@ def fit( data2_processed, dim="sample", padding=padding, decay_factor=decay_factor ) - decomposer = CrossDecomposer(n_modes=self._params["n_modes"]) + decomposer = CrossDecomposer( + n_modes=self._params["n_modes"], **self._solver_kwargs + ) decomposer.fit(data1_processed, data2_processed) # Note: