Package mvpa :: Package tests :: Module test_params
[hide private]
[frames | no frames]

Source Code for Module mvpa.tests.test_params

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA Parameter class.""" 
 10   
 11  import unittest, copy 
 12   
 13  import numpy as N 
 14   
 15  from mvpa.datasets import Dataset 
 16  from mvpa.misc.state import ClassWithCollections, StateVariable 
 17  from mvpa.misc.param import Parameter, KernelParameter 
 18   
 19  from tests_warehouse_clfs import SameSignClassifier 
 20   
class ParametrizedClassifier(SameSignClassifier):
    """Test fixture: a SameSignClassifier with two extra declared parameters.

    Declares one plain ``Parameter`` (``p1``) and one ``KernelParameter``
    (``kp1``) as class attributes so the tests can check how each kind is
    collected on a classifier.
    """

    # Plain parameter, constructed with the value 1.0.
    p1 = Parameter(1.0)

    # Kernel parameter, constructed with the value 100.0.
    kp1 = KernelParameter(100.0)
24
class ParametrizedClassifierExtended(ParametrizedClassifier):
    """Test fixture: adds a kernel parameter at construction time.

    Unlike the class-level declarations of the parent, ``kp2`` is added to
    ``self.kernel_params`` from inside ``__init__``.
    """

    def __init__(self):
        """Initialize the parent, then add the ``kp2`` kernel parameter."""
        ParametrizedClassifier.__init__(self)

        # Built separately from the add() call purely for readability.
        extra_param = KernelParameter(200.0,
                                      doc="Very useful param",
                                      name="kp2")
        self.kernel_params.add(extra_param)
29
class BlankClass(ClassWithCollections):
    """Test fixture: a ClassWithCollections subclass that declares nothing."""
32
class SimpleClass(ClassWithCollections):
    """Test fixture: a ClassWithCollections subclass with one parameter."""

    # The only declared parameter; constructed with value 1.0 and min=0.
    C = Parameter(
        1.0,
        min=0,
        doc="C parameter",
    )
35
class MixedClass(ClassWithCollections):
    """Test fixture: mixes two parameters with one state variable."""

    # Two parameters, both constructed with min=0.
    C = Parameter(
        1.0,
        min=0,
        doc="C parameter",
    )
    D = Parameter(
        3.0,
        min=0,
        doc="D parameter",
    )

    # A single state variable declared next to the parameters.
    state1 = StateVariable(doc="bogus")
40
class ParamsTests(unittest.TestCase):
    """Tests for parameter collections on ClassWithCollections subclasses.

    Uses the ``assert*`` method names rather than the deprecated ``fail*``
    aliases; the ``assert*`` names exist in every ``unittest`` release,
    whereas the ``fail*`` aliases were removed in Python 3.12.
    """

    def testBlank(self):
        """A class declaring nothing exposes no 'states' attribute."""
        blank = BlankClass()

        self.assertRaises(AttributeError, blank.__getattribute__, 'states')
        # An empty attribute name is expected to surface as IndexError.
        self.assertRaises(IndexError, blank.__getattribute__, '')

    def testSimple(self):
        """Single parameter: default value, assignment, and reset."""
        simple = SimpleClass()

        self.assertEqual(len(simple.params.items), 1)
        self.assertRaises(AttributeError, simple.__getattribute__, 'dummy')
        self.assertRaises(IndexError, simple.__getattribute__, '')

        # Freshly constructed: default value, nothing flagged as set.
        self.assertEqual(simple.C, 1.0)
        self.assertEqual(simple.params.isSet("C"), False)
        self.assertEqual(simple.params.isSet(), False)
        self.assertEqual(simple.params["C"].isDefault, True)
        self.assertEqual(simple.params["C"].equalDefault, True)

        simple.C = 1.0
        # we are not actually setting the value if == default
        self.assertEqual(simple.params["C"].isDefault, True)
        self.assertEqual(simple.params["C"].equalDefault, True)

        # Assigning a non-default value flips all of the flags.
        simple.C = 10.0
        self.assertEqual(simple.params.isSet("C"), True)
        self.assertEqual(simple.params.isSet(), True)
        self.assertEqual(simple.params["C"].isDefault, False)
        self.assertEqual(simple.params["C"].equalDefault, False)

        self.assertEqual(simple.C, 10.0)
        simple.params["C"].resetvalue()
        # Resetting restores the default value but still counts as "set".
        self.assertEqual(simple.params.isSet("C"), True)
        # TODO: test that isSet becomes False once a classifier is 'trained'
        self.assertEqual(simple.C, 1.0)
        self.assertRaises(AttributeError, simple.params.__getattribute__, 'B')

    def testMixed(self):
        """Parameters and state variables land in separate collections."""
        mixed = MixedClass()

        self.assertEqual(len(mixed.params.items), 2)
        self.assertEqual(len(mixed.states.items), 1)
        # No kernel parameters were declared, so no such collection exists.
        self.assertRaises(AttributeError, mixed.__getattribute__,
                          'kernel_params')

        self.assertEqual(mixed.C, 1.0)
        self.assertEqual(mixed.params.isSet("C"), False)
        self.assertEqual(mixed.params.isSet(), False)
        mixed.C = 10.0
        # Only the assigned parameter is flagged; the collection as a whole
        # reports True because one of its members is set.
        self.assertEqual(mixed.params.isSet("C"), True)
        self.assertEqual(mixed.params.isSet("D"), False)
        self.assertEqual(mixed.params.isSet(), True)
        self.assertEqual(mixed.D, 3.0)

    def testClassifier(self):
        """Classifier parameters, kernel parameters, and reset on train()."""
        clf = ParametrizedClassifier()
        self.assertEqual(len(clf.params.items), 3)  # + regression/retrainable
        self.assertEqual(len(clf.kernel_params.items), 1)

        clfe = ParametrizedClassifierExtended()
        self.assertEqual(len(clfe.params.items), 3)
        self.assertEqual(len(clfe.kernel_params.items), 2)
        self.assertEqual(len(clfe.kernel_params.listing), 2)

        # check assignment once again
        self.assertEqual(clfe.kp2, 200.0)
        clfe.kp2 = 201.0
        self.assertEqual(clfe.kp2, 201.0)
        self.assertEqual(clfe.kernel_params.isSet("kp2"), True)
        # After training, every "set" flag is expected to be cleared.
        clfe.train(Dataset(samples=[[0, 0]], labels=[1], chunks=[1]))
        self.assertEqual(clfe.kernel_params.isSet("kp2"), False)
        self.assertEqual(clfe.kernel_params.isSet(), False)
        self.assertEqual(clfe.params.isSet(), False)
116
def suite():
    """Return a test suite holding all test methods of ``ParamsTests``.

    Built with ``TestLoader.loadTestsFromTestCase`` instead of
    ``unittest.makeSuite``: the latter is deprecated and was removed in
    Python 3.13, while the loader method is available in every release.
    """
    return unittest.TestLoader().loadTestsFromTestCase(ParamsTests)


if __name__ == '__main__':
    # NOTE(review): the import is used for its side effect; presumably the
    # project's `runner` module runs the tests on import -- confirm there.
    import runner