Package mvpa :: Package tests :: Module test_ifs
[hide private]
[frames | no frames]

Source Code for Module mvpa.tests.test_ifs

 1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
 2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
 3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 4  # 
 5  #   See COPYING file distributed along with the PyMVPA package for the 
 6  #   copyright and license terms. 
 7  # 
 8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 9  """Unit tests for PyMVPA incremental feature search.""" 
10   
11  from mvpa.datasets.masked import MaskedDataset 
12  from mvpa.featsel.ifs import IFS 
13  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
14  from mvpa.clfs.transerror import TransferError 
15  from mvpa.datasets.splitters import NFoldSplitter 
16  from mvpa.featsel.helpers import FixedNElementTailSelector 
17   
18  from tests_warehouse import * 
19  from tests_warehouse_clfs import * 
class IFSTests(unittest.TestCase):
    """Unit tests for the incremental feature search (``IFS``) selector."""

    def getData(self):
        """Build a random two-class ``MaskedDataset`` used as test input.

        The dataset holds 100 samples of shape (2, 2, 2) drawn from a
        standard normal distribution.  The first 50 samples carry label 0,
        the last 50 carry label 1, and within each half the chunk ids 0..4
        are assigned to 10 consecutive samples each.
        """
        data = N.random.standard_normal((100, 2, 2, 2))
        labels = N.concatenate((N.repeat(0, 50), N.repeat(1, 50)))
        # 50 chunk ids (0..4, ten of each), duplicated to cover both labels
        chunks = N.repeat(range(5), 10)
        chunks = N.concatenate((chunks, chunks))
        return MaskedDataset(samples=data, labels=labels, chunks=chunks)

    # XXX just testing based on a single classifier. Not sure if
    # should test for every known classifier since we are simply
    # testing IFS algorithm - not sensitivities
    @sweepargs(svm=clfswh['has_sensitivity', '!meta'][:1])
    def testIFS(self, svm):
        """Run IFS on random data and on the 'dumb2' dataset.

        Checks that the input datasets keep their number of features, that
        the returned dataset has as many features as the step with the
        lowest recorded error, and that on ``datasets['dumb2']`` the first
        column of the selected samples equals the first column of the input.

        ``svm`` is taken from ``clfswh['has_sensitivity', '!meta'][:1]`` via
        the ``sweepargs`` decorator.
        """
        # data measure and transfer error quantifier use the SAME clf!
        trans_error = TransferError(svm)
        data_measure = CrossValidatedTransferError(trans_error,
                                                   NFoldSplitter(1))

        # go for lower tail selection as data_measure will return
        # errors -> low is good
        selector = FixedNElementTailSelector(1, tail='lower', mode='select')
        ifs = IFS(data_measure, trans_error, feature_selector=selector)

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getData()
        tdata_nfeatures = tdata.nfeatures

        sdata, stdata = ifs(wdata, tdata)

        # fail if orig datasets are changed
        self.assertEqual(wdata.nfeatures, wdata_nfeatures)
        self.assertEqual(tdata.nfeatures, tdata_nfeatures)

        # check that the features set with the least error is selected
        self.assertTrue(len(ifs.errors))
        e = N.array(ifs.errors)
        self.assertEqual(sdata.nfeatures, e.argmin() + 1)

        # repeat with dataset where selection order is known
        signal = datasets['dumb2']
        sdata, stdata = ifs(signal, signal)
        self.assertTrue((sdata.samples[:, 0] == signal.samples[:, 0]).all())


def suite():
    """Return a test suite containing every ``test*`` method of IFSTests."""
    # TestLoader.loadTestsFromTestCase is what unittest.makeSuite delegated
    # to; makeSuite itself is deprecated and absent from recent Pythons.
    return unittest.TestLoader().loadTestsFromTestCase(IFSTests)


if __name__ == '__main__':
    # Only when executed as a script: pull in the sibling ``runner`` module.
    # NOTE(review): presumably importing it runs the tests -- confirm.
    import runner