256 lines
8.7 KiB
Python
256 lines
8.7 KiB
Python
#!/usr/bin/env python
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from numpy.testing import assert_allclose
|
|
|
|
import pywt
|
|
from pywt import data
|
|
|
|
# tolerances used in accuracy comparisons
|
|
tol_single = 1e-6
|
|
tol_double = 1e-13
|
|
atol = 1e-7
|
|
|
|
|
|
####
|
|
# 1d mra tests
|
|
####
|
|
|
|
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
|
|
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
|
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
|
@pytest.mark.parametrize(
|
|
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
|
)
|
|
def test_mra_roundtrip(wavelet, transform, mode, dtype):
|
|
x = data.ecg()[:64].astype(dtype)
|
|
if x.dtype.kind == 'c':
|
|
# fill some data for the imaginary channel
|
|
x.imag = x[::-1].real
|
|
|
|
if transform == 'swt':
|
|
# swt mode only supports periodization
|
|
if mode != 'periodization':
|
|
with pytest.raises(ValueError):
|
|
pywt.mra(x, wavelet, transform=transform, mode=mode)
|
|
return
|
|
|
|
coeffs = pywt.mra(x, wavelet, transform=transform, mode=mode)
|
|
assert isinstance(coeffs, list)
|
|
assert isinstance(coeffs[0], np.ndarray)
|
|
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
|
|
|
y = pywt.imra(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
|
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
|
def test_mra_warns_on_non_orthogonal(wavelet, transform):
|
|
dtype = np.float64
|
|
x = data.ecg()[:64].astype(dtype)
|
|
|
|
assert not pywt.Wavelet(wavelet).orthogonal
|
|
|
|
if transform == 'swt':
|
|
# bi-orthogonal wavelets raise a warning for SWT case
|
|
msg = 'norm=True, but the wavelet is not orthogonal'
|
|
with pytest.warns(UserWarning, match=msg):
|
|
coeffs = pywt.mra(x, wavelet, transform=transform)
|
|
else:
|
|
coeffs = pywt.mra(x, wavelet, transform=transform)
|
|
|
|
y = pywt.imra(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize('axis', [0, -1, 1, 2, -3])
|
|
@pytest.mark.parametrize('ndim', [1, 2, 3])
|
|
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
|
|
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
|
|
def test_mra_axis(transform, ndim, axis, dtype):
|
|
# Test transforms over a specific axis of 1D, 2D or 3D data
|
|
if ndim == 1:
|
|
x = data.ecg()[:64]
|
|
elif ndim == 2:
|
|
x = data.camera()[:64, :32]
|
|
elif ndim == 3:
|
|
x = data.camera()[:48, :8]
|
|
x = np.stack((x,) * 8, axis=-1)
|
|
x = x.astype(dtype, copy=False)
|
|
|
|
# out of range axis
|
|
if axis < -x.ndim or axis >= x.ndim:
|
|
with pytest.raises(np.AxisError):
|
|
pywt.mra(x, 'db1', transform=transform, axis=axis)
|
|
return
|
|
|
|
coeffs = pywt.mra(x, 'db1', transform=transform, axis=axis)
|
|
y = pywt.imra(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
####
|
|
# 2d mra tests
|
|
####
|
|
|
|
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
|
|
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
|
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
|
@pytest.mark.parametrize(
|
|
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
|
)
|
|
def test_mra2_roundtrip(wavelet, transform, mode, dtype):
|
|
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
|
if x.dtype.kind == 'c':
|
|
# fill some data for the imaginary channel
|
|
x.imag = x[::-1, :].real
|
|
|
|
if transform == 'swt2':
|
|
# swt mode only supports periodization
|
|
if mode != 'periodization':
|
|
with pytest.raises(ValueError):
|
|
pywt.mra2(x, wavelet, transform=transform, mode=mode)
|
|
return
|
|
|
|
coeffs = pywt.mra2(x, wavelet, transform=transform, mode=mode)
|
|
assert isinstance(coeffs, list)
|
|
assert isinstance(coeffs[0], np.ndarray)
|
|
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
|
|
|
y = pywt.imra2(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
|
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
|
def test_mra2_warns_on_non_orthogonal(wavelet, transform):
|
|
dtype = np.float64
|
|
x = data.camera()[:32, :8].astype(dtype, copy=False)
|
|
|
|
assert not pywt.Wavelet(wavelet).orthogonal
|
|
|
|
if transform == 'swt2':
|
|
# bi-orthogonal wavelets raise a warning for SWT case
|
|
msg = 'norm=True, but the wavelets used are not orthogonal'
|
|
with pytest.warns(UserWarning, match=msg):
|
|
coeffs = pywt.mra2(x, wavelet, transform=transform)
|
|
else:
|
|
coeffs = pywt.mra2(x, wavelet, transform=transform)
|
|
|
|
y = pywt.imra2(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
|
|
@pytest.mark.parametrize('ndim', [2, 3])
|
|
@pytest.mark.parametrize('axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4)])
|
|
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
|
|
def test_mra2_axes(transform, axes, ndim, dtype):
|
|
# Test transforms over various axes of 2D or 3D data.
|
|
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
|
if ndim == 3:
|
|
x = np.stack((x,) * 8, axis=-1)
|
|
|
|
# out of range axis
|
|
if any([axis < -x.ndim or axis >= x.ndim for axis in axes]):
|
|
with pytest.raises(np.AxisError):
|
|
pywt.mra2(x, 'db1', transform=transform, axes=axes)
|
|
return
|
|
|
|
coeffs = pywt.mra2(x, 'db1', transform=transform, axes=axes)
|
|
y = pywt.imra2(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
####
|
|
# nd mra tests
|
|
####
|
|
|
|
@pytest.mark.parametrize('wavelet', ['sym2', ])
|
|
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
|
@pytest.mark.parametrize('mode', pywt.Modes.modes)
|
|
@pytest.mark.parametrize(
|
|
'dtype', ['float32', 'float64', 'complex64', 'complex128']
|
|
)
|
|
@pytest.mark.parametrize('ndim', [1, 2, 3])
|
|
def test_mran_roundtrip(wavelet, transform, mode, dtype, ndim):
|
|
if ndim == 1:
|
|
x = data.ecg()[:48].astype(dtype, copy=False)
|
|
elif ndim == 2:
|
|
x = data.camera()[:16, :8].astype(dtype, copy=False)
|
|
elif ndim == 3:
|
|
x = data.camera()[:16, :8].astype(dtype, copy=False)
|
|
x = np.stack((x,) * 8, axis=-1)
|
|
|
|
if x.dtype.kind == 'c':
|
|
# fill some data for the imaginary channel
|
|
x.imag = x[::-1, ...].real
|
|
|
|
if transform == 'swtn':
|
|
# swt mode only supports periodization
|
|
if mode != 'periodization':
|
|
with pytest.raises(ValueError):
|
|
pywt.mran(x, wavelet, transform=transform, mode=mode)
|
|
return
|
|
|
|
coeffs = pywt.mran(x, wavelet, transform=transform, mode=mode)
|
|
assert isinstance(coeffs, list)
|
|
assert isinstance(coeffs[0], np.ndarray)
|
|
# assert all(isinstance(coeffs[i], dict) for i in range(1, len(coeffs)))
|
|
|
|
y = pywt.imran(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
|
|
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
|
def test_mran_warns_on_non_orthogonal(wavelet, transform):
|
|
dtype = np.float64
|
|
x = data.camera()[:32, :8].astype(dtype, copy=False)
|
|
|
|
assert not pywt.Wavelet(wavelet).orthogonal
|
|
|
|
if transform == 'swtn':
|
|
# bi-orthogonal wavelets raise a warning for SWT case
|
|
msg = 'norm=True, but the wavelets used are not orthogonal'
|
|
with pytest.warns(UserWarning, match=msg):
|
|
coeffs = pywt.mran(x, wavelet, transform=transform)
|
|
else:
|
|
coeffs = pywt.mran(x, wavelet, transform=transform)
|
|
|
|
y = pywt.imran(coeffs)
|
|
rtol = tol_single if x.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x, y, rtol=rtol, atol=rtol)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
'axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4), (-3, -2, -1),
|
|
(0, 2, 1), (0, 5, 1), (0,), (1,), (2,), (-2,), (-3,), (-4,)])
|
|
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
|
|
def test_mran_axes(axes, transform):
|
|
# Test with transforms over 1, 2 or 3 axes of 3d data.
|
|
# Cases with out of range axes are also tested
|
|
dtype = np.float64
|
|
x = data.camera()[:32, :16].astype(dtype, copy=False)
|
|
x3d = np.stack((x,) * 8, axis=-1)
|
|
|
|
# out of range axis
|
|
if any([axis < -x.ndim or axis >= x.ndim for axis in axes]):
|
|
with pytest.raises(np.AxisError):
|
|
pywt.mran(x, 'db1', transform='dwtn', axes=axes)
|
|
return
|
|
|
|
coeffs = pywt.mran(x3d, 'db1', transform='dwtn', axes=axes)
|
|
y = pywt.imran(coeffs)
|
|
rtol = tol_single if x3d.real.dtype.kind == 'f' else tol_double
|
|
assert_allclose(x3d, y, rtol=rtol, atol=rtol)
|