Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support numpy 2+ #162

Merged
merged 3 commits into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
LINT_TARGET_DIRS := PyEMD doc example

init:
python -m venv .venv
.venv/bin/pip install -r requirements.txt
.venv/bin/pip install -e .[dev]
@echo "Run 'source .venv/bin/activate' to activate the virtual environment"

test:
python -m PyEMD.tests.test_all

Expand All @@ -12,6 +18,7 @@ doc:

format:
python -m black $(LINT_TARGET_DIRS)
python -m isort PyEMD

lint-check:
python -m isort --check PyEMD
Expand Down
12 changes: 6 additions & 6 deletions PyEMD/EMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from scipy.interpolate import interp1d

from PyEMD.splines import akima, cubic, cubic_hermite, cubic_spline_3pts, pchip
from PyEMD.utils import get_timeline
from PyEMD.utils import deduce_common_type, get_timeline

FindExtremaOutput = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]

Expand Down Expand Up @@ -199,14 +199,14 @@ def prepare_points(
Position (1st row) and values (2nd row) of maxima.
"""
if self.extrema_detection == "parabol":
return self._prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val)
return self.prepare_points_parabol(T, S, max_pos, max_val, min_pos, min_val)
elif self.extrema_detection == "simple":
return self._prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val)
return self.prepare_points_simple(T, S, max_pos, max_val, min_pos, min_val)
else:
msg = "Incorrect extrema detection type. Please try: 'simple' or 'parabol'."
raise ValueError(msg)

def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]:
def prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> Tuple[np.ndarray, np.ndarray]:
"""
Performs mirroring on signal which extrema do not necessarily
belong on the position array.
Expand Down Expand Up @@ -324,7 +324,7 @@ def _prepare_points_parabol(self, T, S, max_pos, max_val, min_pos, min_val) -> T

return max_extrema, min_extrema

def _prepare_points_simple(
def prepare_points_simple(
self,
T: np.ndarray,
S: np.ndarray,
Expand Down Expand Up @@ -765,7 +765,7 @@ def check_imf(
@staticmethod
def _common_dtype(x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Casts inputs (x, y) into a common numpy DTYPE."""
dtype = np.find_common_type([x.dtype, y.dtype], [])
dtype = deduce_common_type(x.dtype, y.dtype)
if x.dtype != dtype:
x = x.astype(dtype)
if y.dtype != dtype:
Expand Down
8 changes: 4 additions & 4 deletions PyEMD/tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def test_whitenoise_check_rescaling_imf(self):

def test_whitenoise_check_nan_values(self):
"""whitenoise check with nan in IMF."""
S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)])
S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)])
res = whitenoise_check(S)
self.assertEqual(res, None, "Input NaN returns None")
self.assertEqual(res, None, "Input nan returns None")

def test_invalid_alpha(self):
"""Test if invalid alpha return AssertionError."""
S = np.array([np.full(100, np.NaN) for i in range(5, 0, -1)])
S = np.array([np.full(100, np.nan) for i in range(5, 0, -1)])
self.assertRaises(AssertionError, whitenoise_check, S, alpha=1)
self.assertRaises(AssertionError, whitenoise_check, S, alpha=0)
self.assertRaises(AssertionError, whitenoise_check, S, alpha=-10)
Expand All @@ -99,7 +99,7 @@ def test_invalid_test_name(self):

def test_invalid_input_type(self):
"""Test if invalid input type return AssertionError."""
S = [np.full(100, np.NaN) for i in range(5, 0, -1)]
S = [np.full(100, np.nan) for i in range(5, 0, -1)]
self.assertRaises(AssertionError, whitenoise_check, S)
self.assertRaises(AssertionError, whitenoise_check, 1)
self.assertRaises(AssertionError, whitenoise_check, 1.2)
Expand Down
15 changes: 6 additions & 9 deletions PyEMD/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from PyEMD.utils import get_timeline
from PyEMD.utils import deduce_common_type, get_timeline


class MyTestCase(unittest.TestCase):
Expand Down Expand Up @@ -31,14 +31,11 @@ def test_get_timeline_does_not_overflow_int16(self):
self.assertEqual(T[-1], len(S) - 1, "Range is kept")
self.assertEqual(T.dtype, np.uint16, "UInt16 is the min type that matches requirements")

def test_get_timeline_does_not_overflow_float16(self):
S = np.random.random(int(np.finfo(np.float16).max) + 5).astype(dtype=np.float16)
T = get_timeline(len(S), dtype=S.dtype)

self.assertGreater(len(S), np.finfo(S.dtype).max, "Length of the signal is greater than its type max value")
self.assertEqual(len(T), len(S), "Lengths must be equal")
self.assertEqual(T[-1], len(S) - 1, "Range is kept")
self.assertEqual(T.dtype, np.float32, "Float32 is the min type that matches requirements")
def test_deduce_common_types(self):
self.assertEqual(deduce_common_type(np.int16, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int16), np.int32)
self.assertEqual(deduce_common_type(np.int32, np.int32), np.int32)
self.assertEqual(deduce_common_type(np.float32, np.float64), np.float64)


if __name__ == "__main__":
Expand Down
10 changes: 5 additions & 5 deletions PyEMD/tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def test_instantiation2(self):
emd.emd(S, t)
imfs, res = emd.get_imfs_and_residue()
vis = Visualisation(emd)
self.assertTrue(np.alltrue(vis.imfs == imfs))
self.assertTrue(np.alltrue(vis.residue == res))
self.assertTrue(np.all(vis.imfs == imfs))
self.assertTrue(np.all(vis.residue == res))

def test_check_imfs(self):
vis = Visualisation()
Expand All @@ -40,7 +40,7 @@ def test_check_imfs3(self):

out_imfs, out_res = vis._check_imfs(imfs, None, False)

self.assertTrue(np.alltrue(imfs == out_imfs))
self.assertTrue(np.all(imfs == out_imfs))
self.assertIsNone(out_res)

def test_check_imfs4(self):
Expand All @@ -57,8 +57,8 @@ def test_check_imfs5(self):
imfs, res = emd.get_imfs_and_residue()
vis = Visualisation(emd)
imfs2, res2 = vis._check_imfs(imfs, res, False)
self.assertTrue(np.alltrue(imfs == imfs2))
self.assertTrue(np.alltrue(res == res2))
self.assertTrue(np.all(imfs == imfs2))
self.assertTrue(np.all(res == res2))

def test_plot_imfs(self):
vis = Visualisation()
Expand Down
17 changes: 17 additions & 0 deletions PyEMD/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import sys
from typing import Optional

import numpy as np

if sys.version_info >= (3, 9):
from functools import cache
else:
from functools import lru_cache as cache


def get_timeline(range_max: int, dtype: Optional[np.dtype] = None) -> np.ndarray:
"""Returns timeline array for requirements.
Expand Down Expand Up @@ -50,3 +56,14 @@ def smallest_inclusive_dtype(ref_dtype: np.dtype, ref_value) -> np.dtype:
raise ValueError("Requested too large integer range. Exceeds max( float64 ) == '{}.".format(max_val))

raise ValueError("Unsupported dtype '{}'. Only intX and floatX are supported.".format(ref_dtype))


@cache
def deduce_common_type(xtype: np.dtype, ytype: np.dtype) -> np.dtype:
if xtype == ytype:
return xtype
if np.version.version[0] == "1":
dtype = np.find_common_type([xtype, ytype], [])
else:
dtype = np.promote_types(xtype, ytype)
return dtype
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy>=1.12
scipy>=0.19
pathos>=0.2.1
tqdm>=4.64.0,<5.0