Package mvpa :: Package tests :: Module test_enet
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_enet

 1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
 2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
 3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 4  # 
 5  #   See COPYING file distributed along with the PyMVPA package for the 
 6  #   copyright and license terms. 
 7  # 
 8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 9  """Unit tests for PyMVPA elastic net (ENET) classifier""" 
10   
11  from mvpa import cfg 
12  from mvpa.clfs.enet import ENET 
13  from scipy.stats import pearsonr 
14  from tests_warehouse import * 
15  from mvpa.misc.data_generators import normalFeatureDataset 
16   
17 -class ENETTests(unittest.TestCase):
18
19 - def testENET(self):
20 # not the perfect dataset with which to test, but 21 # it will do for now. 22 #data = datasets['dumb2'] 23 # for some reason the R code fails with the dumb data 24 data = datasets['chirp_linear'] 25 26 clf = ENET() 27 28 clf.train(data) 29 30 # prediction has to be almost perfect 31 # test with a correlation 32 pre = clf.predict(data.samples) 33 cor = pearsonr(pre, data.labels) 34 if cfg.getboolean('tests', 'labile', default='yes'): 35 self.failUnless(cor[0] > .8)
36
37 - def testENETState(self):
38 #data = datasets['dumb2'] 39 # for some reason the R code fails with the dumb data 40 data = datasets['chirp_linear'] 41 42 clf = ENET() 43 44 clf.train(data) 45 46 clf.states.enable('predictions') 47 48 p = clf.predict(data.samples) 49 50 self.failUnless((p == clf.predictions).all())
51 52
53 - def testENETSensitivities(self):
54 data = normalFeatureDataset(perlabel=10, nlabels=2, nfeatures=4) 55 56 # use ENET on binary problem 57 clf = ENET() 58 clf.train(data) 59 60 # now ask for the sensitivities WITHOUT having to pass the dataset 61 # again 62 sens = clf.getSensitivityAnalyzer(force_training=False)() 63 64 self.failUnless(sens.shape == (data.nfeatures,))
65 66
def suite():
    """Return a unittest suite holding every ``test*`` method of ENETTests."""
    # unittest.makeSuite() is deprecated (Python 3.11) and removed in 3.13;
    # TestLoader.loadTestsFromTestCase() is the supported equivalent and
    # likewise collects the 'test'-prefixed methods sorted by name.
    return unittest.TestLoader().loadTestsFromTestCase(ENETTests)


if __name__ == '__main__':
    # NOTE(review): when run as a script, the module relies on importing
    # ``runner`` to execute the tests -- presumably runner.py invokes
    # unittest at import time; confirm against runner.py.
    import runner