Package mvpa :: Package tests :: Module test_searchlight
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_searchlight

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA searchlight algorithm""" 
 10   
 11  from mvpa.base import externals 
 12  from mvpa.datasets.masked import MaskedDataset 
 13  from mvpa.measures.searchlight import Searchlight 
 14  from mvpa.datasets.splitters import NFoldSplitter 
 15  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 16  from mvpa.clfs.transerror import TransferError 
 17   
 18  from tests_warehouse import * 
 19  from tests_warehouse_clfs import * 
 20   
class SearchlightTests(unittest.TestCase):
    """Unit tests for the `Searchlight` measure."""

    def setUp(self):
        """Use the shared '3dlarge' dataset from the tests warehouse."""
        self.dataset = datasets['3dlarge']

    def testSearchlight(self):
        """Run a full radius-1 searchlight with N-1 cross-validation per sphere."""
        # compute N-1 cross-validation for each sphere
        transerror = TransferError(sample_clf_lin)
        cv = CrossValidatedTransferError(
                transerror,
                NFoldSplitter(cvtype=1))
        # construct radius 1 searchlight
        sl = Searchlight(cv, radius=1.0, transformer=N.array,
                         enable_states=['spheresizes', 'raw_results'])

        # run searchlight
        results = sl(self.dataset)

        # check for correct number of spheres
        self.failUnlessEqual(len(results), 106)

        # verify that results can be mapped back into the original space;
        # only the absence of an exception is checked here
        self.dataset.mapper.reverse(results)

        # check for chance-level performance across all spheres
        self.failUnless(0.4 < results.mean() < 0.6)

        # check reasonable sphere sizes
        self.failUnlessEqual(len(sl.spheresizes), 106)
        self.failUnlessEqual(max(sl.spheresizes), 7)
        self.failUnlessEqual(min(sl.spheresizes), 4)

        # check base-class state
        self.failUnlessEqual(len(sl.raw_results), 106)

    def testPartialSearchlightWithFullReport(self):
        """Run a searchlight on two sphere centers only, keeping every CV-fold error."""
        # compute N-1 cross-validation for each sphere; combiner=N.array
        # makes the error of each CV fold appear in the result (see the
        # shape check below)
        transerror = TransferError(sample_clf_lin)
        cv = CrossValidatedTransferError(
                transerror,
                NFoldSplitter(cvtype=1),
                combiner=N.array)
        # construct radius 1 searchlight, restricted to two centers
        sl = Searchlight(cv, radius=1.0, transformer=N.array,
                         center_ids=[3, 50])

        # run searchlight
        results = sl(self.dataset)

        # only two spheres but error for all CV-folds
        self.failUnlessEqual(results.shape,
                             (2, len(self.dataset.uniquechunks)))

    def testChiSquareSearchlight(self):
        """Run a searchlight whose per-sphere measure is a chi-square of the
        cross-validation confusion matrix."""
        # `chisquare` is only imported when scipy is available; without
        # scipy the test returns early (and therefore passes)
        if not externals.exists('scipy'):
            return

        from mvpa.misc.stats import chisquare

        transerror = TransferError(sample_clf_lin)
        cv = CrossValidatedTransferError(
                transerror,
                NFoldSplitter(cvtype=1),
                enable_states=['confusion'])

        def getconfusion(data):
            """Cross-validate on `data` and return the first element of
            `chisquare` applied to the resulting confusion matrix
            (presumably the chi-square statistic)."""
            cv(data)
            return chisquare(cv.confusion.matrix)[0]

        # construct radius 1 searchlight; only do two spheres to save time
        sl = Searchlight(getconfusion, radius=1.0,
                         center_ids=[3, 50])

        # run searchlight
        results = sl(self.dataset)

        self.failUnlessEqual(len(results), 2)
def suite():
    """Return a test suite containing all tests of `SearchlightTests`."""
    # TestLoader.loadTestsFromTestCase is the non-deprecated equivalent of
    # unittest.makeSuite (deprecated in Python 3.11, removed in 3.13); the
    # default loader uses the same 'test' prefix and TestSuite class.
    return unittest.TestLoader().loadTestsFromTestCase(SearchlightTests)
if __name__ == '__main__':
    # When this file is executed as a script, import the package's shared
    # test `runner` module. NOTE(review): presumably importing `runner` is
    # what executes the test suites -- confirm in mvpa/tests/runner.py.
    import runner