Skip to content

Commit

Permalink
MAINT: add input validation to scales argument to cwt (#703)
Browse files Browse the repository at this point in the history
Co-authored-by: Ralf Gommers <ralf.gommers@gmail.com>
  • Loading branch information
cyschneck and rgommers authored Mar 8, 2024
1 parent c998ce3 commit ec9338f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pywt/_cwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,11 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
dt_cplx = np.result_type(dt, np.complex64)
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
wavelet = DiscreteContinuousWavelet(wavelet)
if np.isscalar(scales):
scales = np.array([scales])

scales = np.atleast_1d(scales)
if np.any(scales <= 0):
raise ValueError("`scales` must only include positive values")

if not np.isscalar(axis):
raise np.AxisError("axis must be a scalar.")

Expand Down
14 changes: 14 additions & 0 deletions pywt/tests/test_cwt_wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,20 @@ def test_cwt_small_scales():
# extremely short scale factors raise a ValueError
assert_raises(ValueError, pywt.cwt, data, scales=0.01, wavelet='mexh')

def test_cwt_zero_scale():
data = np.zeros(32)
scales = np.arange(0, 4)

# scale that includes 0 throws ValueError to prevent IndexError
assert_raises(ValueError, pywt.cwt, data, scales=scales, wavelet='morl')

def test_cwt_negative_scale():
data = np.zeros(32)
scales = np.asarray([-1, -2, -3])

# scale that includes negative values throws ValueError to prevent IndexError
assert_raises(ValueError, pywt.cwt, data, scales=scales, wavelet='morl')


def test_cwt_method_fft():
rstate = np.random.RandomState(1)
Expand Down

0 comments on commit ec9338f

Please sign in to comment.