diff --git a/Makefile b/Makefile index e8c7849..6475b0d 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,11 @@ LINT_TARGET_DIRS := PyEMD doc example +init: + python -m venv .venv + .venv/bin/pip install -r requirements.txt + .venv/bin/pip install -e .[dev] + @echo "Run 'source .venv/bin/activate' to activate the virtual environment" + test: python -m PyEMD.tests.test_all @@ -12,6 +18,7 @@ doc: format: python -m black $(LINT_TARGET_DIRS) + python -m isort PyEMD lint-check: python -m isort --check PyEMD diff --git a/PyEMD/EMD.py b/PyEMD/EMD.py index a2137b6..1d7e58f 100644 --- a/PyEMD/EMD.py +++ b/PyEMD/EMD.py @@ -13,7 +13,7 @@ from scipy.interpolate import interp1d from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip -from PyEMD.utils import get_timeline +from PyEMD.utils import deduce_common_type, get_timeline FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray] @@ -199,14 +199,14 @@ def prepare_points( Position (1st row) and values (2nd row) of maxima. """ if self.extrema_detection == "parabol": - return self._prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val) + return self.prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val) elif self.extrema_detection == "simple": - return self._prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val) + return self.prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val) else: msg = "Incorrect extrema detection type. Please try: 'simple' or 'parabol'." raise ValueError(msg) - def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]: + def prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]: """ Performs mirroring on signal which extrema do not necessarily belong on the position array. @@ -324,7 +324,7 @@ def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> T return max_extrema, min_extrema - def _prepare_points_simple( + def prepare_points_simple( self, T: np.ndarray, S: np.ndarray, @@ -765,7 +765,7 @@ def check_imf( @staticmethod def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Casts inputs (x, y) into a common numpy DTYPE.""" - dtype = np.find_common_type([x.dtype, y.dtype], []) + dtype = deduce_common_type(x.dtype, y.dtype) if x.dtype != dtype: x = x.astype(dtype) if y.dtype != dtype: diff --git a/PyEMD/tests/test_checks.py b/PyEMD/tests/test_checks.py index 874f359..c4801f1 100644 --- a/PyEMD/tests/test_checks.py +++ b/PyEMD/tests/test_checks.py @@ -77,13 +77,13 @@ def test_whitenoise_check_rescaling_imf(self): def test_whitenoise_check_nan_values(self): """whitenoise check with nan in IMF.""" - S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)]) + S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)]) res = whitenoise_check(S) - self.assertEqual(res, None, "Input NaN returns None") + self.assertEqual(res, None, "Input nan returns None") def test_invalid_alpha(self): """Test if invalid alpha return AssertionError.""" - S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)]) + S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)]) self.assertRaises(AssertionError, whitenoise_check, S, alpha=1) self.assertRaises(AssertionError, whitenoise_check, S, alpha=0) self.assertRaises(AssertionError, whitenoise_check, S, alpha=-10) @@ -99,7 +99,7 @@ def test_invalid_test_name(self): def test_invalid_input_type(self): """Test if invalid input type return AssertionError.""" - S = [np.full(100, np.NaN) for i in range(5, 0, -1)] + S = [np.full(100, np.nan) for i in range(5, 0, -1)] self.assertRaises(AssertionError, whitenoise_check, S) self.assertRaises(AssertionError, whitenoise_check, 1) self.assertRaises(AssertionError, whitenoise_check, 1.2) diff --git a/PyEMD/tests/test_utils.py b/PyEMD/tests/test_utils.py index 63fbc6e..e766ff2 100644 --- a/PyEMD/tests/test_utils.py +++ b/PyEMD/tests/test_utils.py @@ -2,7 +2,7 @@ import numpy as np -from PyEMD.utils import get_timeline +from PyEMD.utils import deduce_common_type, get_timeline class MyTestCase(unittest.TestCase): @@ -31,14 +31,11 @@ def test_get_timeline_does_not_overflow_int16(self): self.assertEqual(T[-1], len(S) - 1, "Range is kept") self.assertEqual(T.dtype, np.uint16, "UInt16 is the min type that matches requirements") - def test_get_timeline_does_not_overflow_float16(self): - S = np.random.random(int(np.finfo(np.float16).max) + 5).astype(dtype=np.float16) - T = get_timeline(len(S), dtype=S.dtype) - - self.assertGreater(len(S), np.finfo(S.dtype).max, "Length of the signal is greater than its type max value") - self.assertEqual(len(T), len(S), "Lengths must be equal") - self.assertEqual(T[-1], len(S) - 1, "Range is kept") - self.assertEqual(T.dtype, np.float32, "Float32 is the min type that matches requirements") + def test_deduce_common_types(self): + self.assertEqual(deduce_common_type(np.int16, np.int32), np.int32) + self.assertEqual(deduce_common_type(np.int32, np.int16), np.int32) + self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32) + self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64) if __name__ == "__main__": diff --git a/PyEMD/tests/test_visualization.py b/PyEMD/tests/test_visualization.py index 882e6b6..982f9b3 100644 --- a/PyEMD/tests/test_visualization.py +++ b/PyEMD/tests/test_visualization.py @@ -19,8 +19,8 @@ def test_instantiation2(self): emd.emd(S, t) imfs, res = emd.get_imfs_and_residue() vis = Visualisation(emd) - self.assertTrue(np.alltrue(vis.imfs == imfs)) - self.assertTrue(np.alltrue(vis.residue == res)) + self.assertTrue(np.all(vis.imfs == imfs)) + self.assertTrue(np.all(vis.residue == res)) def test_check_imfs(self): vis = Visualisation() @@ -40,7 +40,7 @@ def test_check_imfs3(self): out_imfs, out_res = vis._check_imfs(imfs, None, False) - self.assertTrue(np.alltrue(imfs == out_imfs)) + self.assertTrue(np.all(imfs == out_imfs)) self.assertIsNone(out_res) def test_check_imfs4(self): @@ -57,8 +57,8 @@ def test_check_imfs5(self): imfs, res = emd.get_imfs_and_residue() vis = Visualisation(emd) imfs2, res2 = vis._check_imfs(imfs, res, False) - self.assertTrue(np.alltrue(imfs == imfs2)) - self.assertTrue(np.alltrue(res == res2)) + self.assertTrue(np.all(imfs == imfs2)) + self.assertTrue(np.all(res == res2)) def test_plot_imfs(self): vis = Visualisation() diff --git a/PyEMD/utils.py b/PyEMD/utils.py index 3330ad3..4c56f99 100644 --- a/PyEMD/utils.py +++ b/PyEMD/utils.py @@ -1,7 +1,13 @@ +import sys from typing import Optional import numpy as np +if sys.version_info >= (3, 9): + from functools import cache +else: + from functools import lru_cache as cache + def get_timeline(range_max: int, dtype: Optional[np.dtype] = None) -> np.ndarray: """Returns timeline array for requirements. @@ -50,3 +56,14 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype: raise ValueError("Requested too large integer range. Exceeds max( float64 ) == '{}.".format(max_val)) raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype)) + + +@cache +def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype: + if xtype == ytype: + return xtype + if np.version.version[0] == "1": + dtype = np.find_common_type([xtype, ytype], []) + else: + dtype = np.promote_types(xtype, ytype) + return dtype diff --git a/requirements.txt b/requirements.txt index 094e97f..77c4133 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.12 scipy>=0.19 pathos>=0.2.1 +tqdm>=4.64.0,<5.0