Package mvpa :: Package mappers :: Module wavelet
[hide private]
[frames | no frames]

Source Code for Module mvpa.mappers.wavelet

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Wavelet mappers""" 
 10   
 11  from mvpa.base import externals 
 12   
 13  if externals.exists('pywt', raiseException=True): 
 14      # import conditional to be able to import the whole module while building 
 15      # the docs even if pywt is not installed 
 16      import pywt 
 17   
 18  import numpy as N 
 19   
 20  from mvpa.base import warning 
 21  from mvpa.mappers.base import Mapper 
 22  from mvpa.base.dochelpers import enhancedDocString 
 23   
 24  if __debug__: 
 25      from mvpa.base import debug 
 26   
 27  # WaveletPacket and WaveletTransformation mappers share lots of common 
 28  # functionality at the moment 
 29   
class _WaveletMapper(Mapper):
    """Shared base for the wavelet mappers (decomposition and packet)

    The public `forward`/`reverse` entry points convert their input to an
    array, keep track of the shapes involved and delegate the actual
    transformation to `_forward`/`_reverse`, which subclasses must provide.
    """

    def __init__(self, dim=1, wavelet='sym4', mode='per', maxlevel=None):
        """Initialize _WaveletMapper mapper

        :Parameters:
          dim : int or tuple of int
            dimensions to work across (for now only a scalar value, i.e.
            1D transformation, is supported)
          wavelet : basestring
            one of the families available within the pywt package
          mode : basestring
            periodization mode
          maxlevel : int or None
            number of levels to use. If None - automatically selected by pywt
        """
        Mapper.__init__(self)

        self._dim = dim
        """Dimension to work along"""

        self._maxlevel = maxlevel
        """Maximal level of decomposition. None for automatic"""

        # Refuse anything pywt does not list as an available wavelet
        if wavelet not in pywt.wavelist():
            raise ValueError(
                "Unknown family of wavelets '%s'. Please use one "
                "available from the list %s" % (wavelet, pywt.wavelist()))
        self._wavelet = wavelet
        """Wavelet family to use"""

        # Same kind of check for the periodization mode
        if mode not in pywt.MODES.modes:
            raise ValueError(
                "Unknown periodization mode '%s'. Please use one "
                "available from the list %s" % (mode, pywt.MODES.modes))
        self._mode = mode
        """Periodization mode"""


    def forward(self, data):
        """Map `data` via `_forward`, recording input and output shapes.

        The input shape, the length of the input along `dim` and the shape
        of the result are stored on the instance for later use.
        """
        arr = N.asanyarray(data)
        self._inshape = arr.shape
        self._intimepoints = arr.shape[self._dim]
        mapped = self._forward(arr)
        self._outshape = mapped.shape
        return mapped


    def reverse(self, data):
        """Map `data` back via `_reverse` after converting it to an array."""
        return self._reverse(N.asanyarray(data))


    def _forward(self, *args):
        """Actual forward transformation -- has to be provided by a subclass."""
        raise NotImplementedError


    def _reverse(self, *args):
        """Actual reverse transformation -- has to be provided by a subclass."""
        raise NotImplementedError


    def getInSize(self):
        """Returns the shape of the last forwarded input without its first axis."""
        return self._inshape[1:]


    def getOutSize(self):
        """Returns the shape of the last forward result without its first axis."""
        return self._outshape[1:]


    def selectOut(self, outIds):
        """Choose a subset of components...

        Not implemented here -- a MaskMapper on top is the suggested way."""
        raise NotImplementedError("Please use in conjunction with MaskMapper")


    __doc__ = enhancedDocString('_WaveletMapper', locals(), Mapper)
111 112
113 -def _getIndexes(shape, dim):
114 """Generator for coordinate tuples providing slice for all in `dim` 115 116 XXX Somewhat sloppy implementation... but works... 117 """ 118 if len(shape) < dim: 119 raise ValueError, "Dimension %d is incorrect for a shape %s" % \ 120 (dim, shape) 121 n = len(shape) 122 curindexes = [0] * n 123 curindexes[dim] = Ellipsis#slice(None) # all elements for dimension dim 124 while True: 125 yield tuple(curindexes) 126 for i in xrange(n): 127 if i == dim and dim == n-1: 128 return # we reached it -- thus time to go 129 if curindexes[i] == shape[i] - 1: 130 if i == n-1: 131 return 132 curindexes[i] = 0 133 else: 134 if i != dim: 135 curindexes[i] += 1 136 break
137 138
class WaveletPacketMapper(_WaveletMapper):
    """Convert signal into an overcomplete representation using Wavelet packet

    If `level` is given, forward mapping keeps the packet nodes of that
    single level and the `dim` axis is replaced by the shape of the stacked
    node data.  If `level` is None, the node data of all levels get
    concatenated along `dim`.  Reverse mapping is implemented only for the
    single level case.
    """

    def __init__(self, level=None, **kwargs):
        """Initialize WaveletPacketMapper mapper

        :Parameters:
          level : int or None
            What level to decompose at. If 'None' data for all levels
            is provided, but due to different sizes, they are placed
            in 1D row.
        """

        _WaveletMapper.__init__(self, **kwargs)

        self.__level = level


    # XXX too much of duplications between such methods -- it begs
    #     refactoring
    def __forwardSingleLevel(self, data):
        """Decompose every slice along `dim` and keep nodes of a single level

        Node paths of the level are remembered on the instance, since they
        are needed for reconstruction.  Returns the filled array, or None
        if there was no slice to process.
        """
        if __debug__:
            debug('MAP', "Converting signal using DWP (single level)")

        wp = None

        # local bindings
        level = self.__level
        wavelet = self._wavelet
        mode = self._mode
        dim = self._dim

        level_paths = None
        for indexes in _getIndexes(data.shape, dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            WP = pywt.WaveletPacket(
                data[indexes], wavelet=wavelet,
                mode=mode, maxlevel=level)

            level_nodes = WP.get_level(level)
            if level_paths is None:
                # Needed for reconstruction.  Collected only once, for the
                # first slice, instead of being rebuilt for every slice
                level_paths = N.array([node.path for node in level_nodes])
                self.__level_paths = level_paths
            level_datas = N.array([node.data for node in level_nodes])

            if wp is None:
                # storage can be allocated only now, since the shape of the
                # node data becomes known after the first decomposition
                newdim = data.shape
                newdim = newdim[:dim] + level_datas.shape + newdim[dim+1:]
                if __debug__:
                    debug('MAP_', "Initializing storage of size %s for single "
                          "level (%d) mapping of data of size %s"
                          % (newdim, level, data.shape))
                wp = N.empty(tuple(newdim))

            wp[indexes] = level_datas

        return wp


    def __forwardMultipleLevels(self, data):
        """Decompose every slice along `dim` and concatenate all the levels

        For each slice the node data of levels 1..maxlevel are stacked into
        a single 1D row.  Per-level lengths are stored in `levels_lengths`
        (one list of node lengths per level) and `levels_length` (total
        length per level).

        :Raises RuntimeError: if the lengths obtained for a slice differ
          from the ones obtained for the first slice.
        """
        wp = None
        levels_length = None                # total length at each level
        levels_lengths = None               # list of lengths per each level
        for indexes in _getIndexes(data.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            WP = pywt.WaveletPacket(
                data[indexes],
                wavelet=self._wavelet,
                mode=self._mode, maxlevel=self._maxlevel)

            if levels_length is None:
                # number of levels is known only after the first packet
                # got constructed
                levels_length = [None] * WP.maxlevel
                levels_lengths = [None] * WP.maxlevel

            levels_datas = []
            for level in range(WP.maxlevel):
                level_nodes = WP.get_level(level+1)
                level_datas = [node.data for node in level_nodes]

                level_lengths = [len(x) for x in level_datas]
                level_length = N.sum(level_lengths)

                # the first slice defines the reference lengths, all the
                # others must agree with them
                if levels_lengths[level] is None:
                    levels_lengths[level] = level_lengths
                elif levels_lengths[level] != level_lengths:
                    raise RuntimeError(
                        "ADs of same level of different samples should have same number of elements."
                        " Got %s, was %s" % (level_lengths, levels_lengths[level]))

                if levels_length[level] is None:
                    levels_length[level] = level_length
                elif levels_length[level] != level_length:
                    raise RuntimeError(
                        "Levels of different samples should have same number of elements."
                        " Got %d, was %d" % (level_length, levels_length[level]))

                level_data = N.hstack(level_datas)
                levels_datas.append(level_data)

            # assert(len(data) == levels_length)
            # assert(len(data) >= Ntimepoints)
            if wp is None:
                newdim = list(data.shape)
                newdim[self._dim] = N.sum(levels_length)
                wp = N.empty(tuple(newdim))
            wp[indexes] = N.hstack(levels_datas)

        self.levels_lengths, self.levels_length = levels_lengths, levels_length
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done convertion into wp. Total size %s" % str(wp.shape))
        return wp


    def _forward(self, data):
        """Dispatch to the multi- or single-level decomposition

        All levels are provided if `level` given to the constructor was
        None, otherwise only that single level.
        """
        if __debug__:
            debug('MAP', "Converting signal using DWP")

        if self.__level is None:
            return self.__forwardMultipleLevels(data)
        else:
            return self.__forwardSingleLevel(data)

    #
    # Reverse mapping
    #
    def __reverseSingleLevel(self, wp):
        """Reconstruct signals from single level wavelet packet data

        Relies on the node paths and the input shape stored by a preceding
        forward mapping.
        """

        # local bindings
        level_paths = self.__level_paths

        # define wavelet packet to use
        WP = pywt.WaveletPacket(
            data=None, wavelet=self._wavelet,
            mode=self._mode, maxlevel=self.__level)

        # prepare storage
        signal_shape = wp.shape[:1] + self.getInSize()
        signal = N.zeros(signal_shape)
        Ntime_points = self._intimepoints
        for indexes in _getIndexes(signal_shape,
                                   self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)

            # assign the data of each node under its stored path
            for path, level_data in zip(level_paths, wp[indexes]):
                WP[path] = level_data

            # keep only as many points as the original input had along dim
            signal[indexes] = WP.reconstruct(True)[:Ntime_points]

        return signal


    def _reverse(self, data):
        """Map wavelet packet data back into signal

        :Raises NotImplementedError: if all levels were requested
          (`level` is None), or if the 'pywt wp reconstruct' external is
          not available.
        """
        if __debug__:
            debug('MAP', "Converting signal back using DWP")

        if self.__level is None:
            raise NotImplementedError(
                "Reverse mapping is implemented only for a single level")
        else:
            if not externals.exists('pywt wp reconstruct'):
                raise NotImplementedError(
                    "Reconstruction for a single level for versions of "
                    "pywt < 0.1.7 (revision 103) is not supported")
            if not externals.exists('pywt wp reconstruct fixed'):
                warning("Reconstruction using available version of pywt might "
                        "result in incorrect data in the tails of the signal")
            return self.__reverseSingleLevel(data)
308 309 310 311 312
class WaveletTransformationMapper(_WaveletMapper):
    """Convert signal into wavelet representation

    Coefficients of all the decomposition levels are concatenated along
    `dim`; their lengths are stored in `lengths` by the forward mapping and
    used to split the data again in the reverse mapping.
    """

    def _forward(self, data):
        """Decompose signal into wavelet coefficients via dwt

        Returns an array shaped like `data` except along `dim`, which holds
        the concatenated coefficients.
        """
        if __debug__:
            debug('MAP', "Converting signal using DWT")
        result = None
        ref_lengths = None
        for idx in _getIndexes(data.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (idx,), lf=False, cr=True)
            coeffs = pywt.wavedec(data[idx],
                                  wavelet=self._wavelet,
                                  mode=self._mode,
                                  level=self._maxlevel)
            # Silly Yarik embedds extraction of statistics right in place
            #stats = []
            #for coeff in coeffs:
            #    stats_ = [N.std(coeff),
            #              N.sqrt(N.dot(coeff, coeff)),
            #              ]# + list(N.histogram(coeff, normed=True)[0]))
            #    stats__ = list(coeff) + stats_[:]
            #    stats__ += list(N.log(stats_))
            #    stats__ += list(N.sqrt(stats_))
            #    stats__ += list(N.array(stats_)**2)
            #    stats__ += [ N.median(coeff), N.mean(coeff), scipy.stats.kurtosis(coeff) ]
            #    stats.append(stats__)
            #coeffs = stats
            cur_lengths = N.array([len(c) for c in coeffs])
            if ref_lengths is None:
                # lengths of the first slice serve as the reference
                ref_lengths = cur_lengths
            assert((ref_lengths == cur_lengths).all())
            if result is None:
                # allocate the output once the total length is known
                out_shape = list(data.shape)
                out_shape[self._dim] = N.sum(ref_lengths)
                result = N.empty(tuple(out_shape))
            result[idx] = N.hstack(coeffs)
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done DWT. Total size %s" % str(result.shape))
        self.lengths = ref_lengths
        return result


    def _reverse(self, wd):
        """Reconstruct signal from concatenated wavelet coefficients via idwt

        Relies on `lengths` and the number of input time points stored by a
        preceding forward mapping.
        """
        if __debug__:
            debug('MAP', "Performing iDWT")
        restored = None
        # boundaries of the per-level coefficient chunks within a sample
        bounds = [0] + list(N.cumsum(self.lengths))
        nlevels = len(self.lengths)
        # due to padding iDWT sometimes returns longer sequences, so the
        # result gets truncated to the original number of time points
        npoints = self._intimepoints

        for idx in _getIndexes(wd.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (idx,), lf=False, cr=True)
            sample = wd[idx]
            # split the flat sample back into the list of coefficient arrays
            chunks = []
            for lvl in range(nlevels):
                chunks.append(sample[bounds[lvl]:bounds[lvl+1]])
            rec = pywt.waverec(chunks, wavelet=self._wavelet, mode=self._mode)
            if restored is None:
                out_shape = list(wd.shape)
                out_shape[self._dim] = npoints
                restored = N.empty(out_shape)
            restored[idx] = rec[:npoints]
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done iDWT. Total size %s" % (restored.shape, ))
        return restored
389