Home | Trees | Indices | Help |
|
---|
|
1 # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 2 # vi: set ft=python sts=4 ts=4 sw=4 et: 3 ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 4 # 5 # See COPYING file distributed along with the PyMVPA package for the 6 # copyright and license terms. 7 # 8 ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 9 """Unit tests for PyMVPA classifier cross-validation""" 10 11 from mvpa.datasets.splitters import NFoldSplitter 12 from mvpa.datasets.meta import MetaDataset 13 from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 14 from mvpa.clfs.transerror import TransferError 15 16 from tests_warehouse import * 17 from tests_warehouse import pureMultivariateSignal, getMVPattern 18 from tests_warehouse_clfs import * 1921 22118 119 # TODO: test accessibility of {training_,}confusion{,s} of 120 # CrossValidatedTransferError 121 122 12324 data = getMVPattern(3) 25 26 self.failUnless( data.nsamples == 120 ) 27 self.failUnless( data.nfeatures == 2 ) 28 self.failUnless( 29 (data.labels == \ 30 [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0] * 6).all()) 31 self.failUnless( 32 (data.chunks == \ 33 [k for k in range(1, 7) for i in range(20)]).all()) 34 35 transerror = TransferError(sample_clf_nl) 36 cv = CrossValidatedTransferError( 37 transerror, 38 NFoldSplitter(cvtype=1), 39 enable_states=['confusion', 'training_confusion', 40 'samples_error']) 41 42 results = cv(data) 43 self.failUnless( results < 0.2 and results >= 0.0 ) 44 45 # TODO: test accessibility of {training_,}confusion{,s} of 46 # CrossValidatedTransferError 47 48 self.failUnless(isinstance(cv.samples_error, dict)) 49 self.failUnless(len(cv.samples_error) == data.nsamples) 50 # one value for each origid 51 self.failUnless(sorted(cv.samples_error.keys()) == sorted(data.origids)) 52 for k, v in cv.samples_error.iteritems(): 53 self.failUnless(len(v) == 1)54 5557 # get a dataset with a very high SNR 58 data = getMVPattern(10) 59 60 # do crossval with default errorfx and 'mean' combiner 61 transerror = TransferError(sample_clf_nl) 62 cv = CrossValidatedTransferError(transerror, NFoldSplitter(cvtype=1)) 63 64 # must return a scalar value 65 result = cv(data) 66 67 # must be perfect 68 self.failUnless( result < 0.05 ) 69 70 # do crossval with permuted regressors 71 cv = CrossValidatedTransferError(transerror, 72 NFoldSplitter(cvtype=1, permute=True, nrunspersplit=10) ) 73 results = cv(data) 74 75 # must be at chance level 76 pmean = N.array(results).mean() 77 self.failUnless( pmean < 0.58 and pmean > 0.42 )78 7981 # get a dataset with a very high SNR 82 data = getMVPattern(10) 83 84 # do crossval with default errorfx and 'mean' combiner 85 transerror = TransferError(clfswh['linear'][0]) 86 cv = CrossValidatedTransferError( 87 transerror, 88 NFoldSplitter(cvtype=1), 89 harvest_attribs=['transerror.clf.training_time']) 90 result = cv(data) 91 self.failUnless(cv.harvested.has_key('transerror.clf.training_time')) 92 self.failUnless(len(cv.harvested['transerror.clf.training_time'])>1)93 9496 # simple datasets with decreasing SNR 97 data = MetaDataset([getMVPattern(3), getMVPattern(2), getMVPattern(1)]) 98 99 self.failUnless( data.nsamples == 120 ) 100 self.failUnless( data.nfeatures == 6 ) 101 self.failUnless( 102 (data.labels == \ 103 [0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0] * 6).all()) 104 self.failUnless( 105 (data.chunks == \ 106 [ k for k in range(1,7) for i in range(20) ] ).all() ) 107 108 transerror = TransferError(sample_clf_nl) 109 cv = CrossValidatedTransferError(transerror, 110 NFoldSplitter(cvtype=1), 111 enable_states=['confusion', 112 'training_confusion']) 113 114 results = cv(data) 115 self.failUnless(results < 0.2 and results >= 0.0, 116 msg="We should generalize while working with " 117 "metadataset. Got %s error" % results)125 return unittest.makeSuite(CrossValidationTests)126 127 128 if __name__ == '__main__': 129 import runner 130
Home | Trees | Indices | Help |
|
---|
Generated by Epydoc 3.0.1 on Mon Apr 23 23:09:42 2012 | http://epydoc.sourceforge.net |