172 lines
6.1 KiB
Python
172 lines
6.1 KiB
Python
#!/usr/bin/env python
|
|
|
|
from __future__ import division, print_function, absolute_import
|
|
|
|
from itertools import product
|
|
from functools import reduce
|
|
import operator
|
|
import numpy as np
|
|
from numpy.testing import (assert_allclose, assert_, assert_raises,
|
|
assert_equal)
|
|
|
|
import pywt
|
|
|
|
|
|
def test_traversing_tree_nd():
|
|
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
|
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
|
|
|
assert_(np.all(wp.data == x))
|
|
assert_(wp.path == '')
|
|
assert_(wp.level == 0)
|
|
assert_(wp.maxlevel == 3)
|
|
|
|
assert_allclose(wp['aa'].data, np.array([[3., 7., 11., 15.]] * 4),
|
|
rtol=1e-12)
|
|
assert_allclose(wp['da'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
|
assert_allclose(wp['ad'].data, -np.ones((4, 4)), rtol=1e-12, atol=1e-14)
|
|
assert_allclose(wp['dd'].data, np.zeros((4, 4)), rtol=1e-12, atol=1e-14)
|
|
|
|
assert_allclose(wp['aa'*2].data, np.array([[10., 26.]] * 2), rtol=1e-12)
|
|
# __getitem__ using a tuple acces instead
|
|
assert_allclose(wp[('aa', 'aa')].data, np.array([[10., 26.]] * 2),
|
|
rtol=1e-12)
|
|
|
|
assert_(wp['aa']['aa'].data is wp['aa'*2].data)
|
|
assert_allclose(wp['aa'*3].data, np.array([[36.]]), rtol=1e-12)
|
|
|
|
assert_raises(IndexError, lambda: wp['aa'*(wp.maxlevel+1)])
|
|
assert_raises(ValueError, lambda: wp['f'])
|
|
|
|
# getitem input must be a string or tuple of strings
|
|
assert_raises(TypeError, wp.__getitem__, (5, 3))
|
|
assert_raises(TypeError, wp.__getitem__, 5)
|
|
|
|
|
|
def test_accessing_node_attributes_nd():
|
|
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
|
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
|
|
|
assert_allclose(wp['aa'+'ad'].data, np.zeros((2, 2)) - 4, rtol=1e-12)
|
|
assert_(wp['aa'+'ad'].path == 'aa'+'ad')
|
|
assert_(wp['aa'+'ad'].node_name == 'ad')
|
|
assert_(wp['aa'+'ad'].parent.path == 'aa')
|
|
|
|
assert_allclose(wp['aa'+'ad'].parent.data,
|
|
np.array([[3., 7., 11., 15.]] * 4), rtol=1e-12)
|
|
# can also index via a tuple instead of concatenated strings
|
|
assert_(wp[('aa', 'ad')].level == 2)
|
|
assert_(wp[('aa', 'ad')].maxlevel == 3)
|
|
assert_(wp[('aa', 'ad')].mode == 'symmetric')
|
|
|
|
# can access a node's path as either a single string or in tuple form
|
|
node = wp[('ad', 'dd')]
|
|
assert_(node.path == 'addd')
|
|
assert_(node.path_tuple == ('ad', 'dd'))
|
|
|
|
|
|
def test_collecting_nodes_nd():
|
|
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
|
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
|
|
|
assert_(len(wp.get_level(0)) == 1)
|
|
assert_(wp.get_level(0)[0].path == '')
|
|
|
|
# First level
|
|
assert_(len(wp.get_level(1)) == 4)
|
|
assert_(
|
|
[node.path for node in wp.get_level(1)] == ['aa', 'ad', 'da', 'dd'])
|
|
|
|
# Second and third levels
|
|
for lev in [2, 3]:
|
|
assert_(len(wp.get_level(lev)) == (2**x.ndim)**lev)
|
|
paths = [node.path for node in wp.get_level(lev)]
|
|
expected_paths = [
|
|
reduce(operator.add, p) for
|
|
p in sorted(product(['aa', 'ad', 'da', 'dd'], repeat=lev))]
|
|
assert_(paths == expected_paths)
|
|
|
|
|
|
def test_data_reconstruction_delete_nodes_nd():
|
|
x = np.array([[1, 2, 3, 4, 5, 6, 7, 8]] * 8, dtype=np.float64)
|
|
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
|
|
|
# The user must supply either data or axes
|
|
assert_raises(ValueError, pywt.WaveletPacketND, data=None, wavelet='db1',
|
|
axes=None)
|
|
|
|
new_wp = pywt.WaveletPacketND(data=None, wavelet='db1', mode='symmetric',
|
|
axes=range(x.ndim))
|
|
|
|
new_wp['ad'+'da'] = wp['ad'+'da'].data
|
|
new_wp['ad'*2] = wp['ad'+'da'].data
|
|
new_wp['ad'+'dd'] = np.zeros((2, 2), dtype=np.float64)
|
|
new_wp['aa'] = [[3.0, 7.0, 11.0, 15.0]] * 4
|
|
new_wp['dd'] = np.zeros((4, 4), dtype=np.float64)
|
|
new_wp['da'] = wp['da'] # all zeros
|
|
|
|
assert_allclose(new_wp.reconstruct(update=False),
|
|
np.array([[1.5, 1.5, 3.5, 3.5, 5.5, 5.5, 7.5, 7.5]] * 8),
|
|
rtol=1e-12)
|
|
|
|
new_wp['ad'+'aa'] = wp['ad'+'aa'].data
|
|
assert_allclose(new_wp.reconstruct(update=False), x, rtol=1e-12)
|
|
|
|
del(new_wp['ad'+'aa'])
|
|
# TypeError on accessing deleted node
|
|
assert_raises(TypeError, lambda: new_wp['ad'+'aa'])
|
|
|
|
new_wp['ad'+'aa'] = wp['ad'+'aa'].data
|
|
assert_(new_wp.data is None)
|
|
|
|
assert_allclose(new_wp.reconstruct(update=True), x, rtol=1e-12)
|
|
assert_allclose(new_wp.data, x, rtol=1e-12)
|
|
|
|
# TODO: decompose=True
|
|
|
|
|
|
def test_wavelet_packet_dtypes():
|
|
shape = (16, 8, 8)
|
|
for dtype in [np.float32, np.float64, np.complex64, np.complex128]:
|
|
x = np.random.randn(*shape).astype(dtype)
|
|
if np.iscomplexobj(x):
|
|
x = x + 1j*np.random.randn(*shape).astype(x.real.dtype)
|
|
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric')
|
|
# no unnecessary copy made
|
|
assert_(wp.data is x)
|
|
|
|
# full decomposition
|
|
wp.get_level(wp.maxlevel)
|
|
|
|
# reconstruction from coefficients should preserve dtype
|
|
r = wp.reconstruct(False)
|
|
assert_equal(r.dtype, x.dtype)
|
|
assert_allclose(r, x, atol=1e-6, rtol=1e-6)
|
|
|
|
|
|
def test_wavelet_packet_axes():
|
|
rstate = np.random.RandomState(0)
|
|
shape = (32, 16, 8)
|
|
x = rstate.standard_normal(shape)
|
|
for axes in [(0, 1), 1, (-3, -2, -1), (0, 2), (1, )]:
|
|
wp = pywt.WaveletPacketND(data=x, wavelet='db1', mode='symmetric',
|
|
axes=axes)
|
|
|
|
# partial decomposition
|
|
nodes = wp.get_level(1)
|
|
# size along the transformed axes has changed
|
|
for ax2 in range(x.ndim):
|
|
if ax2 in tuple(np.atleast_1d(axes) % x.ndim):
|
|
nodes[0].data.shape[ax2] < x.shape[ax2]
|
|
else:
|
|
nodes[0].data.shape[ax2] == x.shape[ax2]
|
|
|
|
# recontsruction from coefficients should preserve dtype
|
|
r = wp.reconstruct(False)
|
|
assert_equal(r.dtype, x.dtype)
|
|
assert_allclose(r, x, atol=1e-12, rtol=1e-12)
|
|
|
|
# must have non-duplicate axes
|
|
assert_raises(ValueError, pywt.WaveletPacketND, data=x, wavelet='db1',
|
|
axes=(0, 0))
|