From c8296054b90f86ca2fd8d08c163d80c53cac69d1 Mon Sep 17 00:00:00 2001 From: John Muradeli Date: Thu, 25 Jul 2024 04:39:07 +0400 Subject: [PATCH] Support batched `icwt`; fix `icwt` with `scaletype='linear'` --- CHANGELOG.md | 8 ++++++++ ssqueezepy/__init__.py | 2 +- ssqueezepy/_cwt.py | 25 ++++++++++++++++++------- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95773af..c52cb96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +### 0.6.6 + +#### FEATURES + - `icwt` now supports batched `Wx` (3D `Wx`, i.e. `cwt(x)` upon 2D `x`, `(n_inputs, n_times)`) + +#### FIXES + - `icwt` with `scaletype='linear'`: fix constant scaling factor + ### 0.6.5 #### FIXES diff --git a/ssqueezepy/__init__.py b/ssqueezepy/__init__.py index ce69f5d..2184efa 100644 --- a/ssqueezepy/__init__.py +++ b/ssqueezepy/__init__.py @@ -27,7 +27,7 @@ """ -__version__ = '0.6.5' +__version__ = '0.6.6-dev' __title__ = 'ssqueezepy' __author__ = 'John Muradeli' __license__ = __doc__ diff --git a/ssqueezepy/_cwt.py b/ssqueezepy/_cwt.py index eed8b05..4339b9d 100644 --- a/ssqueezepy/_cwt.py +++ b/ssqueezepy/_cwt.py @@ -327,6 +327,10 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, Wx: np.ndarray CWT computed via `ssqueezepy.cwt`. + - 2D: (n_scales, n_times) + - 3D: (n_inputs, n_scales, n_times). + Doesn't support `one_int=False`. + wavelet: str / tuple[str, dict] / `wavelets.Wavelet` Wavelet sampled in Fourier frequency domain. - str: name of builtin wavelet. `ssqueezepy.wavs()` @@ -352,6 +356,8 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, x_mean: float. mean of original `x` (not picked up in CWT since it's an infinite scale component). Default 0. + Note: if `Wx` is 3D, `x_mean` should be 1D (`x.mean()` along samples + axis). padtype: str Pad scheme to apply on input, in case of `one_int=False`. @@ -365,7 +371,9 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, # Returns: x: np.ndarray - The signal, as reconstructed from Wx. + The signal(s), as reconstructed from Wx. + + If `Wx` is 3D, `x` has shape `(n_inputs, n_times)`. # References: 1. One integral inverse CWT. John Muradeli. @@ -394,7 +402,7 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, synsq_cwt_iw.m """ #### Prepare for inversion ############################################### - na, n = Wx.shape + *_, na, n = Wx.shape x_len = x_len or n if not isinstance(scales, np.ndarray) and nv is None: nv = 32 # must match forward's; default to `cwt`'s @@ -414,8 +422,8 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, padtype=padtype, rpadded=rpadded, l1_norm=l1_norm) idx = logscale_transition_idx(scales) - x = icwt(Wx[:idx], scales=scales[:idx], **kw) - x += icwt(Wx[idx:], scales=scales[idx:], **kw) + x = icwt(Wx[..., :idx, :], scales=scales[:idx], **kw) + x += icwt(Wx[..., idx:, :], scales=scales[idx:], **kw) return x ########################################################################## @@ -423,6 +431,8 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, if one_int: x = _icwt_1int(Wx, scales, scaletype, l1_norm) else: + if Wx.ndim == 3: + raise NotImplementedError("batched `Wx` requires `one_int=True`.") x = _icwt_2int(Wx, scales, scaletype, l1_norm, wavelet, x_len, padtype, rpadded) @@ -430,11 +440,12 @@ def icwt(Wx, wavelet='gmw', scales='log-piecewise', nv=None, one_int=True, Cpsi = (adm_ssq(wavelet) if one_int else adm_cwt(wavelet)) if scaletype == 'log': - # Eq 4.67 in [1]; Theorem 4.5 in [1]; below Eq 14 in [2] + # Eq 4.67 in [3]; Theorem 4.5 in [3]; below Eq 14 in [5] # ln(2**(1/nv)) == ln(2)/nv == diff(ln(scales))[0] x *= (2 / Cpsi) * np.log(2 ** (1 / nv)) else: - x *= (2 / Cpsi) + # unclear why the `pi/4` here but it improves inversion + x *= (2 / Cpsi) * np.pi / 4 x += x_mean # CWT doesn't capture mean (infinite scale) return x @@ -466,7 +477,7 @@ def _icwt_2int(Wx, scales, scaletype, l1_norm, wavelet, x_len, def _icwt_1int(Wx, scales, scaletype, l1_norm): """One-integral iCWT; assumes analytic wavelet.""" norm = _icwt_norm(scaletype, l1_norm) - return (Wx.real / norm(scales)).sum(axis=0) + return (Wx.real / norm(scales)).sum(axis=-2) def _icwt_norm(scaletype, l1_norm):