Skip to content

Commit

Permalink
style(tests): apply Black coding format
Browse files Browse the repository at this point in the history
use Black default settings
  • Loading branch information
nicrie committed Aug 2, 2023
1 parent 759b011 commit 717f5fe
Show file tree
Hide file tree
Showing 29 changed files with 3,313 additions and 2,101 deletions.
89 changes: 53 additions & 36 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,123 +8,140 @@
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")


# =============================================================================
# Input data
# =============================================================================
@pytest.fixture
def mock_data_array():
rng = np.random.default_rng(7)
noise = rng.normal(5, 3, size=(25, 5, 4))
signal = 2*np.sin(np.linspace(0, 2*np.pi, 25))[:, None, None]
signal = 2 * np.sin(np.linspace(0, 2 * np.pi, 25))[:, None, None]
return xr.DataArray(
signal + noise,
dims=('time', 'lat', 'lon'),
coords={
'time': xr.date_range('2001', '2025', freq='YS'),
'lat': [20.0, 30.0, 40.0, 50.0, 60.0],
'lon': [-10.0, 0.0, 10.0, 20.0]
},
name='t2m',
attrs=dict(description='mock_data')
)
signal + noise,
dims=("time", "lat", "lon"),
coords={
"time": xr.date_range("2001", "2025", freq="YS"),
"lat": [20.0, 30.0, 40.0, 50.0, 60.0],
"lon": [-10.0, 0.0, 10.0, 20.0],
},
name="t2m",
attrs=dict(description="mock_data"),
)


@pytest.fixture
def mock_dataset(mock_data_array):
t2m = mock_data_array
prcp = t2m**2
return xr.Dataset({'t2m': t2m, 'prcp': prcp})
return xr.Dataset({"t2m": t2m, "prcp": prcp})


@pytest.fixture
def mock_data_array_list(mock_data_array):
da1 = mock_data_array
da2 = mock_data_array ** 2
da3 = mock_data_array ** 3
da2 = mock_data_array**2
da3 = mock_data_array**3
return [da1, da2, da3]


@pytest.fixture
def mock_data_array_isolated_nans(mock_data_array):
invalid_data = mock_data_array.copy()
invalid_data.loc[dict(time='2001', lat=30.0, lon=-10.0)] = np.nan
invalid_data.loc[dict(time="2001", lat=30.0, lon=-10.0)] = np.nan
return invalid_data


@pytest.fixture
def mock_data_array_full_dimensional_nans(mock_data_array):
valid_data = mock_data_array.copy()
valid_data.loc[dict(lat=30.0)] = np.nan
valid_data.loc[dict(time='2002')] = np.nan
valid_data.loc[dict(time="2002")] = np.nan
return valid_data


@pytest.fixture
def mock_data_array_boundary_nans(mock_data_array):
valid_data = mock_data_array.copy(deep=True)
valid_data.loc[dict(lat=30.0)] = np.nan
valid_data.loc[dict(time='2001')] = np.nan
valid_data.loc[dict(time="2001")] = np.nan
return valid_data


@pytest.fixture
def mock_dask_data_array(mock_data_array):
return mock_data_array.chunk({'lon': 2, 'lat': 2, 'time': -1})
return mock_data_array.chunk({"lon": 2, "lat": 2, "time": -1})


# =============================================================================
# Intermediate data
# =============================================================================
@pytest.fixture
def sample_input_data():
'''Create a sample input data.'''
return xr.DataArray(np.random.rand(10, 20), dims=('sample', 'feature'))
"""Create a sample input data."""
return xr.DataArray(np.random.rand(10, 20), dims=("sample", "feature"))


@pytest.fixture
def sample_components():
'''Create a sample components.'''
return xr.DataArray(np.random.rand(20, 5), dims=('feature', 'mode'))
"""Create a sample components."""
return xr.DataArray(np.random.rand(20, 5), dims=("feature", "mode"))


@pytest.fixture
def sample_scores():
'''Create a sample scores.'''
return xr.DataArray(np.random.rand(10, 5), dims=('sample', 'mode'))
"""Create a sample scores."""
return xr.DataArray(np.random.rand(10, 5), dims=("sample", "mode"))


@pytest.fixture
def sample_exp_var():
return xr.DataArray(
np.random.rand(10),
dims=('mode',),
coords={'mode': np.arange(10)},
name='explained_variance'
dims=("mode",),
coords={"mode": np.arange(10)},
name="explained_variance",
)


@pytest.fixture
def sample_total_variance(sample_exp_var):
return sample_exp_var.sum()


@pytest.fixture
def sample_idx_modes_sorted(sample_exp_var):
return sample_exp_var.argsort()[::-1]


@pytest.fixture
def sample_norm():
return xr.DataArray(np.random.rand(10), dims=('mode',))
return xr.DataArray(np.random.rand(10), dims=("mode",))


@pytest.fixture
def sample_squared_covariance():
return xr.DataArray(np.random.rand(10), dims=('mode',))
return xr.DataArray(np.random.rand(10), dims=("mode",))


@pytest.fixture
def sample_total_squared_covariance(sample_squared_covariance):
return sample_squared_covariance.sum('mode')
return sample_squared_covariance.sum("mode")


@pytest.fixture
def sample_rotation_matrix():
'''Create a sample rotation matrix.'''
return xr.DataArray(np.random.rand(5, 5), dims=('mode_m', 'mode_n'))
"""Create a sample rotation matrix."""
return xr.DataArray(np.random.rand(5, 5), dims=("mode_m", "mode_n"))


@pytest.fixture
def sample_phi_matrix():
'''Create a sample phi matrix.'''
return xr.DataArray(np.random.rand(5, 5), dims=('mode_m', 'mode_n'))
"""Create a sample phi matrix."""
return xr.DataArray(np.random.rand(5, 5), dims=("mode_m", "mode_n"))


@pytest.fixture
def sample_modes_sign():
'''Create a sample modes sign.'''
return xr.DataArray(np.random.choice([-1, 1], size=5), dims=('mode',))
"""Create a sample modes sign."""
return xr.DataArray(np.random.choice([-1, 1], size=5), dims=("mode",))
45 changes: 32 additions & 13 deletions tests/data_container/test_base_cross_model_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import xarray as xr
import numpy as np

from xeofs.data_container._base_cross_model_data_container import _BaseCrossModelDataContainer
from xeofs.data_container._base_cross_model_data_container import (
_BaseCrossModelDataContainer,
)


def test_init():
'''Test the initialization of the BaseCrossModelDataContainer.'''
"""Test the initialization of the BaseCrossModelDataContainer."""
data_container = _BaseCrossModelDataContainer()
assert data_container._input_data1 is None
assert data_container._input_data2 is None
Expand All @@ -15,19 +17,28 @@ def test_init():
assert data_container._scores1 is None
assert data_container._scores2 is None


def test_set_data(sample_input_data, sample_components, sample_scores):
'''Test the set_data() method.'''
"""Test the set_data() method."""
data_container = _BaseCrossModelDataContainer()
data_container.set_data(sample_input_data, sample_input_data, sample_components, sample_components, sample_scores, sample_scores)
data_container.set_data(
sample_input_data,
sample_input_data,
sample_components,
sample_components,
sample_scores,
sample_scores,
)
assert data_container._input_data1 is sample_input_data
assert data_container._input_data2 is sample_input_data
assert data_container._components1 is sample_components
assert data_container._components2 is sample_components
assert data_container._scores1 is sample_scores
assert data_container._scores2 is sample_scores


def test_no_data():
'''Test the data accessors without data.'''
"""Test the data accessors without data."""
data_container = _BaseCrossModelDataContainer()
with pytest.raises(ValueError):
data_container.input_data1
Expand All @@ -42,16 +53,24 @@ def test_no_data():
with pytest.raises(ValueError):
data_container.scores2
with pytest.raises(ValueError):
data_container.set_attrs({'test': 1})
data_container.set_attrs({"test": 1})
with pytest.raises(ValueError):
data_container.compute()


def test_set_attrs(sample_input_data, sample_components, sample_scores):
'''Test the set_attrs() method.'''
"""Test the set_attrs() method."""
data_container = _BaseCrossModelDataContainer()
data_container.set_data(sample_input_data, sample_input_data, sample_components, sample_components, sample_scores, sample_scores)
data_container.set_attrs({'test': 1})
assert data_container.components1.attrs['test'] == 1
assert data_container.components2.attrs['test'] == 1
assert data_container.scores1.attrs['test'] == 1
assert data_container.scores2.attrs['test'] == 1
data_container.set_data(
sample_input_data,
sample_input_data,
sample_components,
sample_components,
sample_scores,
sample_scores,
)
data_container.set_attrs({"test": 1})
assert data_container.components1.attrs["test"] == 1
assert data_container.components2.attrs["test"] == 1
assert data_container.scores1.attrs["test"] == 1
assert data_container.scores2.attrs["test"] == 1
19 changes: 11 additions & 8 deletions tests/data_container/test_base_model_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,24 @@


def test_init():
'''Test the initialization of the BaseModelDataContainer.'''
"""Test the initialization of the BaseModelDataContainer."""
data_container = _BaseModelDataContainer()
assert data_container._input_data is None
assert data_container._components is None
assert data_container._scores is None


def test_set_data(sample_input_data, sample_components, sample_scores):
'''Test the set_data() method.'''
"""Test the set_data() method."""
data_container = _BaseModelDataContainer()
data_container.set_data(sample_input_data, sample_components, sample_scores)
assert data_container._input_data is sample_input_data
assert data_container._components is sample_components
assert data_container._scores is sample_scores


def test_no_data():
'''Test the data accessors without data.'''
"""Test the data accessors without data."""
data_container = _BaseModelDataContainer()
with pytest.raises(ValueError):
data_container.input_data
Expand All @@ -30,14 +32,15 @@ def test_no_data():
with pytest.raises(ValueError):
data_container.scores
with pytest.raises(ValueError):
data_container.set_attrs({'test': 1})
data_container.set_attrs({"test": 1})
with pytest.raises(ValueError):
data_container.compute()


def test_set_attrs(sample_input_data, sample_components, sample_scores):
'''Test the set_attrs() method.'''
"""Test the set_attrs() method."""
data_container = _BaseModelDataContainer()
data_container.set_data(sample_input_data, sample_components, sample_scores)
data_container.set_attrs({'test': 1})
assert data_container.components.attrs['test'] == 1
assert data_container.scores.attrs['test'] == 1
data_container.set_attrs({"test": 1})
assert data_container.components.attrs["test"] == 1
assert data_container.scores.attrs["test"] == 1
Loading

0 comments on commit 717f5fe

Please sign in to comment.