Skip to content

Commit

Permalink
fix negative stride (torch); fix inferring nv in icwt
Browse files Browse the repository at this point in the history
  • Loading branch information
OverLordGoldDragon committed Nov 25, 2024
1 parent dbe8696 commit 7b72ddc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
4 changes: 2 additions & 2 deletions ssqueezepy/_cwt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
from .utils import fft, ifft, ifftshift, FFT_GLOBAL
from .utils import WARN, adm_cwt, adm_ssq, _process_fs_and_t
from .utils import WARN, adm_cwt, adm_ssq, _process_fs_and_t, is_array_or_tensor
from .utils import padsignal, process_scales, logscale_transition_idx
from .utils import backend as S
from .utils.backend import Q
Expand Down Expand Up @@ -404,7 +404,7 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True,
#### Prepare for inversion ###############################################
*_, na, n = Wx.shape
x_len = x_len or n
if not isinstance(scales, np.ndarray) and nv is None:
if not is_array_or_tensor(scales) and nv is None:
nv = 32 # must match forward's; default to `cwt`'s

wavelet = _process_gmw_wavelet(wavelet, l1_norm)
Expand Down
5 changes: 5 additions & 0 deletions ssqueezepy/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def is_tensor(*args, mode='all'):
return cond(isinstance(x, torch.Tensor) for x in args)


def is_array_or_tensor(*args, mode='all'):
cond = all if mode == 'all' else any
return cond(isinstance(x, (torch.Tensor, np.ndarray)) for x in args)


def is_dtype(x, str_dtype):
return (str_dtype in str(x.dtype) if isinstance(str_dtype, str) else
any(sd in str(x.dtype) for sd in str_dtype))
Expand Down
2 changes: 1 addition & 1 deletion ssqueezepy/utils/cwt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def integrate_analytic(int_fn, nowarn=False):
Integrates near zero separately in log space (useful for e.g. 1/x).
"""
def _est_arr(mxlim, N):
t = np.linspace(mxlim, .1, N, endpoint=False)[::-1]
t = np.linspace(mxlim, .1, N, endpoint=False)[::-1].copy()
arr = int_fn(t)

max_idx = np.argmax(arr)
Expand Down

0 comments on commit 7b72ddc

Please sign in to comment.