1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA incremental feature search."""
10
11 from mvpa.datasets.masked import MaskedDataset
12 from mvpa.featsel.ifs import IFS
13 from mvpa.algorithms.cvtranserror import CrossValidatedTransferError
14 from mvpa.clfs.transerror import TransferError
15 from mvpa.datasets.splitters import NFoldSplitter
16 from mvpa.featsel.helpers import FixedNElementTailSelector
17
18 from tests_warehouse import *
19 from tests_warehouse_clfs import *
23
# Fixture data, built by what appears to be the body of `getData` (the test
# below calls `self.getData()`). The method header itself seems to be
# missing from this listing: the line numbering skips from 23 to 25.
# NOTE(review): no random seed is set in these lines, so every call yields
# different noise samples.
# 100 samples, each a 2x2x2 array drawn from a standard normal distribution.
25 data = N.random.standard_normal(( 100, 2, 2, 2 ))
# The first 50 samples get label 0 and the last 50 get label 1.
26 labels = N.concatenate( ( N.repeat( 0, 50 ),
27 N.repeat( 1, 50 ) ) )
# N.repeat repeats each element in place: 0 x10, 1 x10, ..., 4 x10
# (50 entries in total).
28 chunks = N.repeat( range(5), 10 )
# The same chunk layout is reused for both label halves, so each of the 5
# chunks holds 10 samples of label 0 and 10 samples of label 1.
29 chunks = N.concatenate( (chunks, chunks) )
# Presumably MaskedDataset flattens each 2x2x2 sample into 8 features;
# verify against mvpa.datasets.masked.
30 return MaskedDataset(samples=data, labels=labels, chunks=chunks)
31
32
33
34
35
# Run the test below once per classifier from `clfswh` selected by the tags
# 'has_sensitivity' and '!meta'. The [:1] slice keeps only the first match.
# (Tag semantics are inferred from their names; confirm in
# tests_warehouse_clfs.)
36 @sweepargs(svm=clfswh['has_sensitivity', '!meta'][:1])
# NOTE(review): the decorated `def` line (listing number 37) is missing
# here. The statements below read as that method's body, which receives
# the classifier as `svm`.
38
39
# The same TransferError instance is used in two places:
# - wrapped by CrossValidatedTransferError (NFoldSplitter(1)) to form the
#   data measure;
# - passed directly to IFS as its transfer error.
40 trans_error = TransferError(svm)
41 data_measure = CrossValidatedTransferError(trans_error,
42 NFoldSplitter(1))
43
# The selector takes one element per step from the *lower* tail. The data
# measure reports an error, so presumably lower values are better.
# The trailing backslash after `feature_selector=` is redundant, because the
# argument list is already inside the parentheses opened by `IFS(`.
44 ifs = IFS(data_measure,
45 trans_error,
46 feature_selector=\
47
48
49 FixedNElementTailSelector(1, tail='lower', mode='select'),
50 )
# Two datasets from separate getData() calls. Their feature counts are
# recorded first so the test can verify that IFS leaves its inputs unchanged.
51 wdata = self.getData()
52 wdata_nfeatures = wdata.nfeatures
53 tdata = self.getData()
54 tdata_nfeatures = tdata.nfeatures
55
# IFS is called with (wdata, tdata) and returns a pair of datasets.
56 sdata, stdata = ifs(wdata, tdata)
57
58
# The input datasets must keep their original number of features.
# NOTE(review): failUnless is a deprecated alias of assertTrue and was
# removed in Python 3.12. assertEqual would also report both values on
# failure.
59 self.failUnless(wdata.nfeatures == wdata_nfeatures)
60 self.failUnless(tdata.nfeatures == tdata_nfeatures)
61
62
# IFS must have recorded at least one error. The selected dataset must have
# as many features as the step with the lowest recorded error. argmin is
# 0-based, hence the +1; presumably errors[i] belongs to the
# (i+1)-feature set.
63 self.failUnless(len(ifs.errors))
64 e = N.array(ifs.errors)
65 self.failUnless(sdata.nfeatures == e.argmin() + 1)
66
67
68
# Repeat on the 'dumb2' dataset, used as both training and testing data.
# Column 0 of the selected samples must equal column 0 of the input
# (presumably dumb2 is built so that feature 0 carries the signal; verify).
# `stdata` is not checked here.
69 signal = datasets['dumb2']
70 sdata, stdata = ifs(signal, signal)
71 self.failUnless((sdata.samples[:,0] == signal.samples[:,0]).all())
72
76
77
# Script entry point: when this file is executed directly, it imports the
# `runner` module. Presumably `runner` discovers and runs the tests as a
# side effect of being imported; verify in the test package.
78 if __name__ == '__main__':
79 import runner
80