1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA GNB classifier"""
10
11 from mvpa.clfs.gnb import GNB
12 from tests_warehouse import *
13
15
# NOTE(review): the statements below use ``self`` and therefore look like the
# body of a unittest.TestCase method whose ``class``/``def`` header is not
# part of this fragment -- confirm against the complete file.

# Construct GNB with default arguments and with each non-default flag.
# These instances are not used further below; they only verify that
# construction itself succeeds.
gnb = GNB()
gnb_nc = GNB(common_variance=False)
gnb_n = GNB(normalize=True)
gnb_n_nc = GNB(normalize=True, common_variance=False)

# Training and testing datasets, looked up by name in ``datasets``.
ds_tr = datasets['uni2medium_train']
ds_te = datasets['uni2medium_test']

# Generic coverage: train and predict with every combination of the
# constructor options.
bools = (True, False)

for cv in bools:                        # common_variance?
    for prior in ('uniform', 'laplacian_smoothing', 'ratio'):
        # Reference predictions for this (cv, prior) pair: taken from the
        # first classifier trained below; every other combination of
        # normalize/logprob/enable_states must reproduce them.
        tp = None

        for n in bools:                 # normalize?
            for ls in bools:            # logprob (log space)?
                # The trailing comma matters: ('values') is just the string
                # 'values', whereas ('values',) is a one-element tuple, so
                # ``'values' in es`` below is a real membership test.
                for es in ((), ('values',)):
                    gnb_ = GNB(common_variance=cv,
                               prior=prior,
                               normalize=n,
                               logprob=ls,
                               enable_states=es)
                    gnb_.train(ds_tr)
                    predictions = gnb_.predict(ds_te.samples)
                    if tp is None:
                        tp = predictions
                    self.failUnless((predictions == tp),
                                    msg="%s failed to reproduce predictions" %
                                    gnb_)

                    # When normalization is on and 'values' were requested,
                    # each row of the values must sum to 1 (within 1e-5).
                    if n and 'values' in es:
                        v = gnb_.values
                        if ls:
                            # Values are in log space -- exponentiate first.
                            v = N.exp(v)
                        d1 = N.sum(v, axis=1) - 1.0
                        self.failUnless(N.max(N.abs(d1)) < 1e-5)


if __name__ == '__main__':
    # NOTE(review): presumably importing ``runner`` is what executes the
    # tests when this file is run as a script -- verify against that module.
    import runner
62