microproduct/deformation-sentiral/smallbaselineApp/pywt/tests/test_mra.py

256 lines
8.7 KiB
Python

#!/usr/bin/env python
import numpy as np
import pytest
from numpy.testing import assert_allclose
import pywt
from pywt import data
# Accuracy tolerances for round-trip comparisons: the looser value is chosen
# for single-precision data, the tighter one for double precision.
tol_single, tol_double = 1e-6, 1e-13
# Absolute tolerance constant; the tests in this module pass their own atol.
# NOTE(review): appears unused here -- confirm before removing.
atol = 1e-7
####
# 1d mra tests
####
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
@pytest.mark.parametrize('mode', pywt.Modes.modes)
@pytest.mark.parametrize(
    'dtype', ['float32', 'float64', 'complex64', 'complex128']
)
def test_mra_roundtrip(wavelet, transform, mode, dtype):
    """Check that ``pywt.imra(pywt.mra(x))`` reproduces a 1D signal.

    Every extension mode in ``pywt.Modes.modes`` is tried. For the ``'swt'``
    transform only ``'periodization'`` is accepted; any other mode must
    raise ``ValueError``.
    """
    x = data.ecg()[:64].astype(dtype)
    if x.dtype.kind == 'c':
        # fill some data for the imaginary channel
        x.imag = x[::-1].real

    if transform == 'swt' and mode != 'periodization':
        # swt mode only supports periodization
        with pytest.raises(ValueError):
            pywt.mra(x, wavelet, transform=transform, mode=mode)
        return

    coeffs = pywt.mra(x, wavelet, transform=transform, mode=mode)
    assert isinstance(coeffs, list)
    assert isinstance(coeffs[0], np.ndarray)
    y = pywt.imra(coeffs)

    # ``x.real.dtype.kind`` is 'f' for every float/complex input, which made
    # the old check always pick tol_single. Compare the real-part dtype
    # instead so only float32/complex64 get the looser tolerance.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
def test_mra_warns_on_non_orthogonal(wavelet, transform):
    """Check 1D MRA with bi-orthogonal (non-orthogonal) wavelets.

    With ``transform='swt'`` a ``UserWarning`` about ``norm=True`` is
    expected. In both cases the decomposition must still reconstruct the
    input.
    """
    x = data.ecg()[:64].astype(np.float64)
    assert not pywt.Wavelet(wavelet).orthogonal

    if transform == 'swt':
        # bi-orthogonal wavelets raise a warning for SWT case
        msg = 'norm=True, but the wavelet is not orthogonal'
        with pytest.warns(UserWarning, match=msg):
            coeffs = pywt.mra(x, wavelet, transform=transform)
    else:
        coeffs = pywt.mra(x, wavelet, transform=transform)
    y = pywt.imra(coeffs)

    # Pick the tolerance from the real-part dtype; ``dtype.kind`` is 'f' for
    # every floating type and therefore cannot distinguish precisions.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
@pytest.mark.parametrize('axis', [0, -1, 1, 2, -3])
@pytest.mark.parametrize('ndim', [1, 2, 3])
@pytest.mark.parametrize('transform', ['dwt', 'swt'])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
def test_mra_axis(transform, ndim, axis, dtype):
    """Check 1D MRA along a single axis of 1D, 2D or 3D data.

    An out-of-range ``axis`` must raise ``AxisError``. A valid one must
    round-trip through ``pywt.mra`` / ``pywt.imra``.
    """
    # NumPy 2.0 dropped ``np.AxisError`` from the main namespace; it lives in
    # ``np.exceptions`` (available since NumPy 1.25).
    axis_error = getattr(np, 'exceptions', np).AxisError

    if ndim == 1:
        x = data.ecg()[:64]
    elif ndim == 2:
        x = data.camera()[:64, :32]
    elif ndim == 3:
        x = np.stack((data.camera()[:48, :8],) * 8, axis=-1)
    x = x.astype(dtype, copy=False)

    # out of range axis
    if not -x.ndim <= axis < x.ndim:
        with pytest.raises(axis_error):
            pywt.mra(x, 'db1', transform=transform, axis=axis)
        return

    coeffs = pywt.mra(x, 'db1', transform=transform, axis=axis)
    y = pywt.imra(coeffs)

    # Only single-precision real parts get the looser tolerance.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
####
# 2d mra tests
####
@pytest.mark.parametrize('wavelet', ['db2', 'sym4', 'coif5'])
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
@pytest.mark.parametrize('mode', pywt.Modes.modes)
@pytest.mark.parametrize(
    'dtype', ['float32', 'float64', 'complex64', 'complex128']
)
def test_mra2_roundtrip(wavelet, transform, mode, dtype):
    """Check that ``pywt.imra2(pywt.mra2(x))`` reproduces a 2D image.

    For ``transform='swt2'`` only ``'periodization'`` is accepted; any other
    mode must raise ``ValueError``.
    """
    x = data.camera()[:32, :16].astype(dtype, copy=False)
    if x.dtype.kind == 'c':
        # fill some data for the imaginary channel
        x.imag = x[::-1, :].real

    if transform == 'swt2' and mode != 'periodization':
        # swt mode only supports periodization
        with pytest.raises(ValueError):
            pywt.mra2(x, wavelet, transform=transform, mode=mode)
        return

    coeffs = pywt.mra2(x, wavelet, transform=transform, mode=mode)
    assert isinstance(coeffs, list)
    assert isinstance(coeffs[0], np.ndarray)
    y = pywt.imra2(coeffs)

    # ``x.real.dtype.kind`` is always 'f' here; compare the dtype itself so
    # double-precision inputs are held to tol_double.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
def test_mra2_warns_on_non_orthogonal(wavelet, transform):
    """Check 2D MRA with bi-orthogonal (non-orthogonal) wavelets.

    With ``transform='swt2'`` a ``UserWarning`` about ``norm=True`` is
    expected. In both cases the image must be reconstructed.
    """
    x = data.camera()[:32, :8].astype(np.float64, copy=False)
    assert not pywt.Wavelet(wavelet).orthogonal

    if transform == 'swt2':
        # bi-orthogonal wavelets raise a warning for SWT case
        msg = 'norm=True, but the wavelets used are not orthogonal'
        with pytest.warns(UserWarning, match=msg):
            coeffs = pywt.mra2(x, wavelet, transform=transform)
    else:
        coeffs = pywt.mra2(x, wavelet, transform=transform)
    y = pywt.imra2(coeffs)

    # float64 input: the double-precision tolerance is the one that applies.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
@pytest.mark.parametrize('transform', ['dwt2', 'swt2'])
@pytest.mark.parametrize('ndim', [2, 3])
@pytest.mark.parametrize('axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4)])
@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
def test_mra2_axes(transform, axes, ndim, dtype):
    """Check 2D MRA over various axis pairs of 2D or 3D data.

    If any entry of ``axes`` is out of range, ``AxisError`` must be raised.
    Otherwise the data must round-trip through ``mra2`` / ``imra2``.
    """
    # NumPy 2.0 dropped ``np.AxisError`` from the main namespace; it lives in
    # ``np.exceptions`` (available since NumPy 1.25).
    axis_error = getattr(np, 'exceptions', np).AxisError

    x = data.camera()[:32, :16].astype(dtype, copy=False)
    if ndim == 3:
        x = np.stack((x,) * 8, axis=-1)

    # out of range axis
    if any(not -x.ndim <= axis < x.ndim for axis in axes):
        with pytest.raises(axis_error):
            pywt.mra2(x, 'db1', transform=transform, axes=axes)
        return

    coeffs = pywt.mra2(x, 'db1', transform=transform, axes=axes)
    y = pywt.imra2(coeffs)

    # Only single-precision real parts get the looser tolerance.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
####
# nd mra tests
####
@pytest.mark.parametrize('wavelet', ['sym2', ])
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
@pytest.mark.parametrize('mode', pywt.Modes.modes)
@pytest.mark.parametrize(
    'dtype', ['float32', 'float64', 'complex64', 'complex128']
)
@pytest.mark.parametrize('ndim', [1, 2, 3])
def test_mran_roundtrip(wavelet, transform, mode, dtype, ndim):
    """Check that ``pywt.imran(pywt.mran(x))`` reproduces 1D, 2D or 3D data.

    For ``transform='swtn'`` only ``'periodization'`` is accepted; any other
    mode must raise ``ValueError``.
    """
    if ndim == 1:
        x = data.ecg()[:48].astype(dtype, copy=False)
    elif ndim == 2:
        x = data.camera()[:16, :8].astype(dtype, copy=False)
    elif ndim == 3:
        x = data.camera()[:16, :8].astype(dtype, copy=False)
        x = np.stack((x,) * 8, axis=-1)
    if x.dtype.kind == 'c':
        # fill some data for the imaginary channel
        x.imag = x[::-1, ...].real

    if transform == 'swtn' and mode != 'periodization':
        # swt mode only supports periodization
        with pytest.raises(ValueError):
            pywt.mran(x, wavelet, transform=transform, mode=mode)
        return

    coeffs = pywt.mran(x, wavelet, transform=transform, mode=mode)
    assert isinstance(coeffs, list)
    assert isinstance(coeffs[0], np.ndarray)
    y = pywt.imran(coeffs)

    # ``x.real.dtype.kind`` is always 'f' here; compare the dtype itself so
    # double-precision inputs are held to tol_double.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
@pytest.mark.parametrize('wavelet', ['rbio1.3', 'bior2.4'])
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
def test_mran_warns_on_non_orthogonal(wavelet, transform):
    """Check n-D MRA with bi-orthogonal (non-orthogonal) wavelets.

    With ``transform='swtn'`` a ``UserWarning`` about ``norm=True`` is
    expected. In both cases the data must be reconstructed.
    """
    x = data.camera()[:32, :8].astype(np.float64, copy=False)
    assert not pywt.Wavelet(wavelet).orthogonal

    if transform == 'swtn':
        # bi-orthogonal wavelets raise a warning for SWT case
        msg = 'norm=True, but the wavelets used are not orthogonal'
        with pytest.warns(UserWarning, match=msg):
            coeffs = pywt.mran(x, wavelet, transform=transform)
    else:
        coeffs = pywt.mran(x, wavelet, transform=transform)
    y = pywt.imran(coeffs)

    # float64 input: the double-precision tolerance is the one that applies.
    rtol = tol_single if x.real.dtype == np.float32 else tol_double
    assert_allclose(x, y, rtol=rtol, atol=rtol)
@pytest.mark.parametrize(
    'axes', [(0, 1), (-2, -1), (0, 2), (-3, 1), (0, 4), (-3, -2, -1),
             (0, 2, 1), (0, 5, 1), (0,), (1,), (2,), (-2,), (-3,), (-4,)])
@pytest.mark.parametrize('transform', ['dwtn', 'swtn'])
def test_mran_axes(axes, transform):
    """Check n-D MRA over 1, 2 or 3 axes of 3D data.

    Cases with out-of-range axes must raise ``AxisError``. All other cases
    must round-trip through ``mran`` / ``imran``.
    """
    # NumPy 2.0 dropped ``np.AxisError`` from the main namespace; it lives in
    # ``np.exceptions`` (available since NumPy 1.25).
    axis_error = getattr(np, 'exceptions', np).AxisError

    x = data.camera()[:32, :16].astype(np.float64, copy=False)
    x3d = np.stack((x,) * 8, axis=-1)

    # Validate against the 3D array that is actually transformed. The old
    # check used the 2D ``x``, so valid 3D axes such as (0, 2) or (-3,) were
    # treated as errors and never round-tripped.
    if any(not -x3d.ndim <= axis < x3d.ndim for axis in axes):
        with pytest.raises(axis_error):
            pywt.mran(x3d, 'db1', transform=transform, axes=axes)
        return

    # Use the parametrized transform; it was previously ignored in favour of
    # a hard-coded 'dwtn', so the swtn path was never exercised here.
    coeffs = pywt.mran(x3d, 'db1', transform=transform, axes=axes)
    y = pywt.imran(coeffs)

    rtol = tol_single if x3d.real.dtype == np.float32 else tol_double
    assert_allclose(x3d, y, rtol=rtol, atol=rtol)