1
2
3
4
5
6
7
8
9 """Wavelet mappers"""
10
11 from mvpa.base import externals
12
13 if externals.exists('pywt', raiseException=True):
14
15
16 import pywt
17
18 import numpy as N
19
20 from mvpa.base import warning
21 from mvpa.mappers.base import Mapper
22 from mvpa.base.dochelpers import enhancedDocString
23
24 if __debug__:
25 from mvpa.base import debug
26
27
28
29
# NOTE(review): this listing drops the statement that owns the docstring
# below (listing line 30). It is presumably `class _WaveletMapper(Mapper):`,
# judging from the `Mapper.__init__(self)` call in __init__ and the
# `enhancedDocString('_WaveletMapper', locals(), Mapper)` line -- confirm
# against the full source.
31 """Generic class for Wavelet mappers (decomposition and packet)
32 """
33
def __init__(self, dim=1, wavelet='sym4', mode='per', maxlevel=None):
    """Initialize _WaveletMapper mapper

    :Parameters:
      dim : int or tuple of int
        dimension to work across (for now only a scalar value, i.e. a 1D
        transformation, is supported)
      wavelet : basestring
        one of the wavelet families available within the pywt package
      mode : basestring
        periodization mode
      maxlevel : int or None
        number of levels to use. If None - automatically selected by pywt

    :Raises:
      ValueError
        if `wavelet` is not in ``pywt.wavelist()`` or `mode` is not in
        ``pywt.MODES.modes``
    """
    Mapper.__init__(self)

    # Dimension (axis) to work along
    self._dim = dim

    # Maximal level of decomposition; None for automatic selection
    self._maxlevel = maxlevel

    # Query the list once so the check and the error message agree.
    known_wavelets = pywt.wavelist()
    if wavelet not in known_wavelets:
        raise ValueError(
            "Unknown family of wavelets '%s'. Please use one "
            "available from the list %s" % (wavelet, known_wavelets))
    # Wavelet family to use
    self._wavelet = wavelet

    known_modes = pywt.MODES.modes
    if mode not in known_modes:
        raise ValueError(
            "Unknown periodization mode '%s'. Please use one "
            "available from the list %s" % (mode, known_modes))
    # Periodization mode
    self._mode = mode
69
70
# NOTE(review): the `def` line of this method (listing line 71) is missing
# from this listing. The body:
#   - converts the input with N.asanyarray;
#   - records the full input shape;
#   - records the length of axis self._dim, which the single-level reverse
#     code uses to trim reconstructed signals;
#   - runs the subclass hook self._forward;
#   - records the output shape.
72 data = N.asanyarray(data)
73 self._inshape = data.shape
74 self._intimepoints = data.shape[self._dim]
75 res = self._forward(data)
76 self._outshape = res.shape
77 return res
78
79
83
84
# NOTE(review): these are the bodies of two abstract hooks whose `def` lines
# (listing lines 85 and 89) are missing. They presumably correspond to
# `_forward` (called by the forward code) and `_reverse`, which subclasses
# override -- confirm against the full source.
86 raise NotImplementedError
87
88
90 raise NotImplementedError
91
92
# NOTE(review): the `def` lines of these two accessors (listing lines 93
# and 98) are missing. The first is presumably getInSize, since the reverse
# code calls self.getInSize().
# Both return shape tuples, not counts. Both read shapes recorded by the
# forward mapping; no visible code sets them earlier, so calling them before
# any forward mapping presumably raises AttributeError.
94 """Return the per-sample input shape (input shape without axis 0)."""
95 return self._inshape[1:]
96
97
99 """Return the per-sample output shape (output shape without axis 0)."""
100 return self._outshape[1:]
101
102
# NOTE(review): the `def` line (listing line 103) is missing. Component
# selection is intentionally left unimplemented; the error message points
# callers to MaskMapper.
104 """Choose a subset of components...
105
106 just use MaskMapper on top?"""
107 raise NotImplementedError, "Please use in conjunction with MaskMapper"
108
109
# Rebuild the class docstring through enhancedDocString, passing the class
# name, the current namespace (presumably the class body) and the Mapper
# base class.
110 __doc__ = enhancedDocString('_WaveletMapper', locals(), Mapper)
111
112
114 """Generator for coordinate tuples providing slice for all in `dim`
115
116 XXX Somewhat sloppy implementation... but works...
117 """
118 if len(shape) < dim:
119 raise ValueError, "Dimension %d is incorrect for a shape %s" % \
120 (dim, shape)
121 n = len(shape)
122 curindexes = [0] * n
123 curindexes[dim] = Ellipsis
124 while True:
125 yield tuple(curindexes)
126 for i in xrange(n):
127 if i == dim and dim == n-1:
128 return
129 if curindexes[i] == shape[i] - 1:
130 if i == n-1:
131 return
132 curindexes[i] = 0
133 else:
134 if i != dim:
135 curindexes[i] += 1
136 break
137
138
# NOTE(review): the class statement that owns this docstring (listing line
# 139) is missing. It is presumably
# `class WaveletPacketMapper(_WaveletMapper):`, given the "Initialize
# WaveletPacketMapper mapper" docstring and the `_WaveletMapper.__init__`
# call in its __init__ -- confirm.
140 """Convert signal into an overcomplete representation using Wavelet packet
141 """
142
def __init__(self, level=None, **kwargs):
    """Set up a wavelet-packet mapper.

    :Parameters:
      level : int or None
        Decomposition level to extract. With None, coefficients of all
        levels are provided; since the levels differ in size, they are
        placed in a single 1D row.

    Any remaining keyword arguments go unchanged to
    `_WaveletMapper.__init__` (dim, wavelet, mode, maxlevel).
    """
    # The generic wavelet mapper validates and stores the shared settings.
    _WaveletMapper.__init__(self, **kwargs)

    # Requested decomposition level; None means "all levels".
    self.__level = level
156
157
158
159
# NOTE(review): the `def` line of this method is missing from the listing
# (around listing line 160). Judging from the debug message, this is the
# single-level forward path:
#   - every 1D slice along self._dim is decomposed with
#     pywt.WaveletPacket(maxlevel=self.__level);
#   - the nodes of that level are taken;
#   - their coefficients are stacked into level_datas, which replaces axis
#     `dim` in the output.
# level_datas is presumably 2D (nodes x coefficients), so the result has
# one more axis than `data`.
# NOTE(review): WP.get_level(level) receives self.__level directly.
# Presumably the (missing) dispatcher routes here only when level is not
# None -- confirm.
161 if __debug__:
162 debug('MAP', "Converting signal using DWP (single level)")
163
164 wp = None
165
166 level = self.__level
167 wavelet = self._wavelet
168 mode = self._mode
169 dim = self._dim
170
# NOTE(review): level_paths is never reassigned below, so the
# `if level_paths is None` test is true for every slice. self.__level_paths
# (read by the single-level reverse code) is therefore recomputed per slice;
# presumably the paths are identical for every slice. Indentation was lost
# in this listing, so it cannot be told from here whether the level_datas
# line also sits under that `if`.
171 level_paths = None
172 for indexes in _getIndexes(data.shape, self._dim):
173 if __debug__:
174 debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
175 WP = pywt.WaveletPacket(
176 data[indexes], wavelet=wavelet,
177 mode=mode, maxlevel=level)
178
179 level_nodes = WP.get_level(level)
180 if level_paths is None:
181
182 self.__level_paths = N.array([node.path for node in level_nodes])
183 level_datas = N.array([node.data for node in level_nodes])
184
# Allocate the output lazily from the first slice: axis `dim` of the input
# is replaced by the full shape of level_datas.
185 if wp is None:
186 newdim = data.shape
187 newdim = newdim[:dim] + level_datas.shape + newdim[dim+1:]
188 if __debug__:
189 debug('MAP_', "Initializing storage of size %s for single "
190 "level (%d) mapping of data of size %s" % (newdim, level, data.shape))
191 wp = N.empty( tuple(newdim) )
192
193 wp[indexes] = level_datas
194
195 return wp
196
197
# NOTE(review): the `def` line of this method is missing from the listing
# (around listing line 198). This is the all-levels forward path:
#   - each 1D slice along self._dim is decomposed with
#     pywt.WaveletPacket(maxlevel=self._maxlevel);
#   - for every level 1..WP.maxlevel, the node coefficients are concatenated;
#   - all levels are concatenated again into one row, which replaces axis
#     self._dim in the output.
# Per-level bookkeeping, filled from the first slice and checked against
# every later slice (both lists are stored on self at the end):
#   levels_length[l]  -- total coefficient count of level l+1
#   levels_lengths[l] -- list of per-node coefficient counts of level l+1
199 wp = None
200 levels_length = None
201 levels_lengths = None
202 for indexes in _getIndexes(data.shape, self._dim):
203 if __debug__:
204 debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
205 WP = pywt.WaveletPacket(
206 data[indexes],
207 wavelet=self._wavelet,
208 mode=self._mode, maxlevel=self._maxlevel)
209
# Size the bookkeeping from the first slice's decomposition depth.
210 if levels_length is None:
211 levels_length = [None] * WP.maxlevel
212 levels_lengths = [None] * WP.maxlevel
213
214 levels_datas = []
215 for level in xrange(WP.maxlevel):
216 level_nodes = WP.get_level(level+1)
217 level_datas = [node.data for node in level_nodes]
218
219 level_lengths = [len(x) for x in level_datas]
220 level_length = N.sum(level_lengths)
221
# Every slice must produce identical node sizes per level; otherwise the
# rows could not share one output array.
222 if levels_lengths[level] is None:
223 levels_lengths[level] = level_lengths
224 elif levels_lengths[level] != level_lengths:
225 raise RuntimeError, \
226 "ADs of same level of different samples should have same number of elements." \
227 " Got %s, was %s" % (level_lengths, levels_lengths[level])
228
229 if levels_length[level] is None:
230 levels_length[level] = level_length
231 elif levels_length[level] != level_length:
232 raise RuntimeError, \
233 "Levels of different samples should have same number of elements." \
234 " Got %d, was %d" % (level_length, levels_length[level])
235
236 level_data = N.hstack(level_datas)
237 levels_datas.append(level_data)
238
239
240
# Allocate the output lazily: axis self._dim becomes the total coefficient
# count over all levels.
241 if wp is None:
242 newdim = list(data.shape)
243 newdim[self._dim] = N.sum(levels_length)
244 wp = N.empty( tuple(newdim) )
245 wp[indexes] = N.hstack(levels_datas)
246
# NOTE(review): if the loop above never ran, wp is still None here and the
# debug line below would fail on wp.shape.
247 self.levels_lengths, self.levels_length = levels_lengths, levels_length
248 if __debug__:
249 debug('MAP_', "")
250 debug('MAP', "Done convertion into wp. Total size %s" % str(wp.shape))
251 return wp
252
253
262
263
264
265
267
268
# NOTE(review): the `def` line of this method is missing from the listing.
# It is presumably __reverseSingleLevel (the reverse dispatcher calls
# self.__reverseSingleLevel(data)), taking the array named `wp` here --
# confirm.
# Rebuilds signals from single-level coefficients, using the node paths
# saved in self.__level_paths. The single-level forward code is the only
# visible place that sets that attribute, so the forward mapping
# presumably must have run first.
269 level_paths = self.__level_paths
270
271
# One packet tree, reused; its nodes are overwritten for every slice.
272 WP = pywt.WaveletPacket(
273 data=None, wavelet=self._wavelet,
274 mode=self._mode, maxlevel=self.__level)
275
276
# NOTE(review): this treats axis 0 of `wp` as the sample axis. The
# single-level forward code inserts its coefficient axes at position
# self._dim, so this only holds for self._dim >= 1 -- confirm that dim=0
# is unsupported.
277 signal_shape = wp.shape[:1] + self.getInSize()
278 signal = N.zeros(signal_shape)
279 Ntime_points = self._intimepoints
# Indexes are generated for the output shape. When `wp` carries the extra
# coefficient axis, the Ellipsis at position dim spans both of its
# coefficient axes.
280 for indexes in _getIndexes(signal_shape,
281 self._dim):
282 if __debug__:
283 debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
284
285 for path, level_data in zip(level_paths, wp[indexes]):
286 WP[path] = level_data
287
# Trim to the axis length recorded by the forward mapping; the reconstruction
# may come out longer than the original signal.
288 signal[indexes] = WP.reconstruct(True)[:Ntime_points]
289
290 return signal
291
292
# NOTE(review): the `def` line of this method is missing from the listing;
# it is presumably the `_reverse` hook. Only single-level reconstruction is
# implemented:
#   - with self.__level None it raises NotImplementedError;
#   - it also raises NotImplementedError when the 'pywt wp reconstruct'
#     feature is unavailable;
#   - it warns when 'pywt wp reconstruct fixed' is unavailable.
294 if __debug__:
295 debug('MAP', "Converting signal back using DWP")
296
297 if self.__level is None:
298 raise NotImplementedError
299 else:
300 if not externals.exists('pywt wp reconstruct'):
301 raise NotImplementedError, \
302 "Reconstruction for a single level for versions of " \
303 "pywt < 0.1.7 (revision 103) is not supported"
304 if not externals.exists('pywt wp reconstruct fixed'):
305 warning("Reconstruction using available version of pywt might "
306 "result in incorrect data in the tails of the signal")
307 return self.__reverseSingleLevel(data)
308
309
310
311
312
389