diff --git a/PyEMD/EMD_matlab.py b/PyEMD/EMD_matlab.py index c5ea0ca..cf1a31b 100644 --- a/PyEMD/EMD_matlab.py +++ b/PyEMD/EMD_matlab.py @@ -84,7 +84,9 @@ def extractMaxMinSpline(self, T, S): return [-1] * 4 # Extrapolation of signal (ober boundaries) - maxExtrema, minExtrema = self.preparePoints(S, T, maxPos, maxVal, minPos, minVal) + maxExtrema, minExtrema = self.preparePoints( + S, T, maxPos, maxVal, minPos, minVal + ) _, maxSpline = self.splinePoints(T, maxExtrema, self.splineKind) _, minSpline = self.splinePoints(T, minExtrema, self.splineKind) @@ -218,8 +220,12 @@ def preparePoints(self, S, T, maxPos, maxVal, minPos, minVal): minExtrema = np.array([tmin, zmin], dtype=self.DTYPE) # Make double sure, that each extremum is significant - maxExtrema = np.delete(maxExtrema, np.where(maxExtrema[0, 1:] == maxExtrema[0, :-1]), axis=1) - minExtrema = np.delete(minExtrema, np.where(minExtrema[0, 1:] == minExtrema[0, :-1]), axis=1) + maxExtrema = np.delete( + maxExtrema, np.where(maxExtrema[0, 1:] == maxExtrema[0, :-1]), axis=1 + ) + minExtrema = np.delete( + minExtrema, np.where(minExtrema[0, 1:] == minExtrema[0, :-1]), axis=1 + ) return maxExtrema, minExtrema @@ -251,7 +257,9 @@ def splinePoints(self, T, extrema, splineKind): elif kind == "cubic": if extrema.shape[1] > 3: - return t, interp1d(extrema[0], extrema[1], kind=kind)(t).astype(self.DTYPE) + return t, interp1d(extrema[0], extrema[1], kind=kind)(t).astype( + self.DTYPE + ) else: return self.cubicSpline_3points(T, extrema) @@ -435,8 +443,18 @@ def emd(self, S, T=None, maxImf=None): The decomposition is limited to maxImf imf. No limitation as default. Returns IMF functions in dic format. IMF = {0:imf0, 1:imf1...}. + *Note*: First argument `self` should be an instance of EMD class. + It should be resolved in future versions. + + For example: + ``` + emd = EMD() + emd.emd(emd, S, T, maxImf) + ``` + Input: --------- + self: Instance of EMD class. S: Signal. T: Positions of signal. If none passed numpy arange is created. maxImf: IMF number to which decomposition should be performed. @@ -457,7 +475,7 @@ def emd(self, S, T=None, maxImf=None): maxImf = -1 # Make sure same types are dealt - S, T = unify_type(S, T) + S, T = unify_types(S, T) self.DTYPE = S.dtype Res = S.astype(self.DTYPE) @@ -479,7 +497,7 @@ def emd(self, S, T=None, maxImf=None): if S.shape != T.shape: info = "Time array should be the same size as signal." - raise Exception(info) + raise ValueError(info) # Create arrays IMF = {} # Dic for imfs signals diff --git a/PyEMD/__init__.py b/PyEMD/__init__.py index 17ab5ae..babb476 100644 --- a/PyEMD/__init__.py +++ b/PyEMD/__init__.py @@ -1,6 +1,6 @@ import logging -__version__ = "1.6.3" +__version__ = "1.6.4" logger = logging.getLogger("pyemd") from PyEMD.CEEMDAN import CEEMDAN # noqa diff --git a/PyEMD/tests/test_emd_matlab.py b/PyEMD/tests/test_emd_matlab.py new file mode 100644 index 0000000..5fee65d --- /dev/null +++ b/PyEMD/tests/test_emd_matlab.py @@ -0,0 +1,49 @@ +import unittest + +import numpy as np + +from PyEMD.EMD_matlab import EMD + + +class EMDMatlabTest(unittest.TestCase): + @staticmethod + def test_default_call_EMD(): + T = np.arange(0, 1, 0.01) + S = np.cos(2 * T * 2 * np.pi) + max_imf = 2 + + emd = EMD() + emd.emd(emd, S, T, max_imf) + + def test_different_length_input(self): + T = np.arange(20) + S = np.random.random(len(T) + 7) + + emd = EMD() + with self.assertRaises(ValueError): + emd.emd(emd, S, T) + + def test_trend(self): + """ + Input is trend. Expeting no shifting process. + """ + emd = EMD() + + T = np.arange(0, 1, 0.01) + S = np.cos(2 * T * 2 * np.pi) + + # Input - linear function f(t) = 2*t + output = emd.emd(emd, S, T) + self.assertEqual(len(output), 4, "Expecting 4 outputs - IMF, EXT, ITER, imfNo") + + IMF, EXT, ITER, imfNo = output + self.assertEqual(len(IMF), 2, "Expecting single IMF + residue") + self.assertEqual(len(IMF[0]), len(S), "Expecting single IMF") + self.assertTrue(np.allclose(S, IMF[0])) + self.assertLessEqual(ITER[0], 5, "Expecting 5 iterations at most") + self.assertEqual(imfNo, 2, "Expecting 1 IMF") + self.assertEqual(EXT[0], 3, "Expecting single EXT") + + +if __name__ == "__main__": + unittest.main()