Package mvpa :: Package tests :: Module test_clfcrossval
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_clfcrossval

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA classifier cross-validation""" 
 10   
 11  from mvpa.datasets.splitters import NFoldSplitter 
 12  from mvpa.datasets.meta import MetaDataset 
 13  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 14  from mvpa.clfs.transerror import TransferError 
 15   
 16  from tests_warehouse import * 
 17  from tests_warehouse import pureMultivariateSignal, getMVPattern 
 18  from tests_warehouse_clfs import * 
 19   
class CrossValidationTests(unittest.TestCase):
    """Tests for `CrossValidatedTransferError` on synthetic pattern datasets."""

    def _checkLayout(self, data, nfeatures):
        """Assert the sample/label/chunk layout expected by the CV tests.

        Expects 120 samples split into chunks 1..6 of 20 consecutive samples
        each, labelled 5x0, 10x1, 5x0 within every chunk, and `nfeatures`
        features.
        """
        self.failUnlessEqual(data.nsamples, 120)
        self.failUnlessEqual(data.nfeatures, nfeatures)
        self.failUnless(
            (data.labels == ([0] * 5 + [1] * 10 + [0] * 5) * 6).all())
        self.failUnless(
            (data.chunks == [k for k in range(1, 7) for _ in range(20)]).all())

    def testSimpleNMinusOneCV(self):
        """Leave-one-chunk-out CV on the 2-feature pattern from getMVPattern(3).

        Checks the dataset layout, that the cross-validated error of
        `sample_clf_nl` lies in [0, 0.2), and that the `samples_error` state
        is a dict holding exactly one error value per sample origid.
        """
        data = getMVPattern(3)
        self._checkLayout(data, nfeatures=2)

        transerror = TransferError(sample_clf_nl)
        cv = CrossValidatedTransferError(
            transerror,
            NFoldSplitter(cvtype=1),
            enable_states=['confusion', 'training_confusion',
                           'samples_error'])

        results = cv(data)
        self.failUnless(0.0 <= results < 0.2, msg="Got %s error" % results)

        # TODO: test accessibility of {training_,}confusion{,s} of
        # CrossValidatedTransferError

        self.failUnless(isinstance(cv.samples_error, dict))
        self.failUnlessEqual(len(cv.samples_error), data.nsamples)
        # one value for each origid
        self.failUnlessEqual(sorted(cv.samples_error.keys()),
                             sorted(data.origids))
        for errors in cv.samples_error.values():
            self.failUnlessEqual(len(errors), 1)

    def testNoiseClassification(self):
        """Near-zero error on high-SNR data; chance-level error once labels are permuted."""
        # get a dataset with a very high SNR
        data = getMVPattern(10)

        # do crossval with default errorfx and 'mean' combiner
        transerror = TransferError(sample_clf_nl)
        cv = CrossValidatedTransferError(transerror, NFoldSplitter(cvtype=1))

        # must return a scalar value
        result = cv(data)

        # must be (nearly) perfect
        self.failUnless(result < 0.05)

        # do crossval with permuted labels, 10 permutation runs per split
        cv = CrossValidatedTransferError(
            transerror,
            NFoldSplitter(cvtype=1, permute=True, nrunspersplit=10))
        results = cv(data)

        # must be at chance level
        pmean = N.array(results).mean()
        self.failUnless(0.42 < pmean < 0.58)

    def testHarvesting(self):
        """`harvest_attribs` records more than one value of the named attribute."""
        # get a dataset with a very high SNR
        data = getMVPattern(10)

        # do crossval with default errorfx and 'mean' combiner
        transerror = TransferError(clfswh['linear'][0])
        cv = CrossValidatedTransferError(
            transerror,
            NFoldSplitter(cvtype=1),
            harvest_attribs=['transerror.clf.training_time'])
        # only the harvested side effect matters here, not the error value
        cv(data)
        self.failUnless('transerror.clf.training_time' in cv.harvested)
        self.failUnless(len(cv.harvested['transerror.clf.training_time']) > 1)

    # NOTE(review): this method's `def` line was not present in the rendered
    # source; the name is reconstructed -- confirm against the original file.
    def testNMinusOneCVWithMetaDataset(self):
        """Leave-one-chunk-out CV on a MetaDataset of three patterns with decreasing SNR."""
        # simple datasets with decreasing SNR
        data = MetaDataset([getMVPattern(3), getMVPattern(2), getMVPattern(1)])
        self._checkLayout(data, nfeatures=6)

        transerror = TransferError(sample_clf_nl)
        cv = CrossValidatedTransferError(transerror,
                                         NFoldSplitter(cvtype=1),
                                         enable_states=['confusion',
                                                        'training_confusion'])

        results = cv(data)
        self.failUnless(0.0 <= results < 0.2,
                        msg="We should generalize while working with "
                            "metadataset. Got %s error" % results)

        # TODO: test accessibility of {training_,}confusion{,s} of
        # CrossValidatedTransferError
def suite():
    """Return a TestSuite with every ``test*`` method of CrossValidationTests."""
    # unittest.makeSuite (deprecated in 3.11, removed in 3.13) merely wraps
    # this call with the default loader settings; TestLoader works on both
    # Python 2 and Python 3.
    return unittest.TestLoader().loadTestsFromTestCase(CrossValidationTests)
if __name__ == '__main__':
    # Script entry point. Importing the project's `runner` module is
    # presumably what executes the tests (its contents are not shown here
    # -- confirm).
    import runner