Package mvpa :: Package tests :: Module test_knn
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_knn

 1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
 2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
 3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 4  # 
 5  #   See COPYING file distributed along with the PyMVPA package for the 
 6  #   copyright and license terms. 
 7  # 
 8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 9  """Unit tests for PyMVPA kNN classifier""" 
10   
11  from mvpa.clfs.knn import kNN 
12  from tests_warehouse import * 
13  from tests_warehouse import pureMultivariateSignal 
14  from mvpa.clfs.distance import oneMinusCorrelation 
15   
class KNNTests(unittest.TestCase):
    """Unit tests for the PyMVPA kNN classifier."""

    def testMultivariate(self):
        """kNN must exploit a purely multivariate signal.

        For 20 freshly generated train/test pairs, a kNN(k=10) trained on
        all features must reach a mean accuracy above 0.9, and must beat the
        same classifier retrained on feature 0 alone.
        """
        mv_perf = []
        uv_perf = []

        clf = kNN(k=10)
        for _ in xrange(20):
            train = pureMultivariateSignal(20, 3)
            test = pureMultivariateSignal(20, 3)

            # accuracy when all features are available
            clf.train(train)
            p_mv = clf.predict(test.samples)
            mv_perf.append(N.mean(p_mv == test.labels))

            # accuracy of the same classifier object retrained on feature 0 only
            clf.train(train.selectFeatures([0]))
            p_uv = clf.predict(test.selectFeatures([0]).samples)
            uv_perf.append(N.mean(p_uv == test.labels))

        mean_mv_perf = N.mean(mv_perf)
        mean_uv_perf = N.mean(uv_perf)

        self.failUnless(mean_mv_perf > 0.9,
                        "multivariate accuracy %.3f is not above 0.9"
                        % mean_mv_perf)
        self.failUnless(mean_uv_perf < mean_mv_perf,
                        "univariate accuracy %.3f is not below multivariate "
                        "accuracy %.3f" % (mean_uv_perf, mean_mv_perf))

    def testKNNState(self):
        """Enabled 'values'/'predictions' states are filled in by predict().

        The 'predictions' state must match what predict() returned.
        The 'values' state must provide one row per test sample and two
        columns (the column count the original test expected).
        """
        train = pureMultivariateSignal(20, 3)
        test = pureMultivariateSignal(20, 3)

        clf = kNN(k=10)
        clf.train(train)

        clf.states.enable('values')
        clf.states.enable('predictions')

        p = clf.predict(test.samples)

        # Compare as plain lists. This works whether predict() yields a list
        # or an ndarray; a bare `p == clf.predictions` on arrays would be
        # ambiguous in a boolean context. It also reports the mismatching
        # items on failure.
        self.failUnlessEqual(list(p), list(clf.predictions))
        # Derive the row count from the test data instead of hard-coding 80.
        self.failUnlessEqual(N.array(clf.values).shape,
                             (len(test.samples), 2))
def suite():
    """Build a unittest suite holding every ``test*`` method of KNNTests."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(KNNTests)
if __name__ == '__main__':
    # NOTE(review): when this module is run as a script, it only imports the
    # sibling `runner` module. Presumably that import executes the tests as a
    # side effect; confirm against tests/runner.py.
    import runner