Package mvpa :: Package tests :: Module test_gnb
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_gnb

 1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
 2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
 3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 4  # 
 5  #   See COPYING file distributed along with the PyMVPA package for the 
 6  #   copyright and license terms. 
 7  # 
 8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
 9  """Unit tests for PyMVPA GNB classifier""" 
10   
11  from mvpa.clfs.gnb import GNB 
12  from tests_warehouse import * 
13   
class GNBTests(unittest.TestCase):
    """Unit tests for the PyMVPA GNB classifier."""

    def testGNB(self):
        """Train/predict GNB under every combination of constructor options.

        For each (``common_variance``, ``prior``) pair, every combination
        of ``normalize``, ``logprob`` and enabled states must reproduce the
        same predictions.  When ``normalize`` is on and the ``values`` state
        is enabled, each row of ``gnb_.values`` (after ``exp`` if the values
        are in log space) must sum to 1 within 1e-5.
        """
        ds_tr = datasets['uni2medium_train']
        ds_te = datasets['uni2medium_test']

        bools = (True, False)
        priors = ('uniform', 'laplacian_smoothing', 'ratio')
        # The trailing comma matters: ('values',) is a one-element tuple,
        # whereas ('values') would just be the string 'values'.
        states_options = ((), ('values',))

        # Generic coverage just to assure that GNB works in all
        # possible scenarios.
        for cv in bools:                    # common_variance?
            for prior in priors:
                # Reference predictions for this (cv, prior) pair; every
                # (normalize, logprob, states) variant below must match.
                tp = None
                for n in bools:             # normalized?
                    for ls in bools:        # log space?
                        for es in states_options:
                            gnb_ = GNB(common_variance=cv,
                                       prior=prior,
                                       normalize=n,
                                       logprob=ls,
                                       enable_states=es)
                            gnb_.train(ds_tr)
                            predictions = gnb_.predict(ds_te.samples)
                            if tp is None:
                                tp = predictions
                            self.failUnless(
                                predictions == tp,
                                msg="%s failed to reproduce predictions"
                                    % gnb_)
                            # if normalized -- check if values are such
                            if n and 'values' in es:
                                v = gnb_.values
                                if ls:  # in log space -- take exp ;)
                                    v = N.exp(v)
                                d1 = N.sum(v, axis=1) - 1.0
                                self.failUnless(N.max(N.abs(d1)) < 1e-5)
def suite():
    """Build and return a test suite containing every GNBTests test method."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(GNBTests)
if __name__ == '__main__':
    # Entry point when this module is executed directly.
    # NOTE(review): presumably the `runner` module runs the test suite as a
    # side effect of being imported -- confirm against runner.py.
    import runner