Source Code for Module mvpa.tests.test_splitter

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Unit tests for PyMVPA dataset splitters"""

from mvpa.datasets.masked import MaskedDataset
from mvpa.datasets.splitters import NFoldSplitter, OddEvenSplitter, \
                                    NoneSplitter, HalfSplitter, \
                                    CustomSplitter, NGroupSplitter
import unittest
import numpy as N


class SplitterTests(unittest.TestCase):

    def setUp(self):
        self.data = \
            MaskedDataset(samples=N.random.normal(size=(100,10)),
                          labels=[ i%4 for i in range(100) ],
                          chunks=[ i/10 for i in range(100)])

    def testSimplestCVPatGen(self):
        # create the generator
        nfs = NFoldSplitter(cvtype=1)

        # now get the xval pattern sets (one-fold CV)
        xvpat = [ (train, test) for (train,test) in nfs(self.data) ]

        self.failUnless( len(xvpat) == 10 )

        for i,p in enumerate(xvpat):
            self.failUnless( len(p) == 2 )
            self.failUnless( p[0].nsamples == 90 )
            self.failUnless( p[1].nsamples == 10 )
            self.failUnless( p[1].chunks[0] == i )

    def testOddEvenSplit(self):
        oes = OddEvenSplitter()

        splits = [ (train, test) for (train, test) in oes(self.data) ]

        self.failUnless(len(splits) == 2)

        for i,p in enumerate(splits):
            self.failUnless( len(p) == 2 )
            self.failUnless( p[0].nsamples == 50 )
            self.failUnless( p[1].nsamples == 50 )

        self.failUnless((splits[0][1].uniquechunks == [1, 3, 5, 7, 9]).all())
        self.failUnless((splits[0][0].uniquechunks == [0, 2, 4, 6, 8]).all())
        self.failUnless((splits[1][0].uniquechunks == [1, 3, 5, 7, 9]).all())
        self.failUnless((splits[1][1].uniquechunks == [0, 2, 4, 6, 8]).all())

        # check if it works on pure odd and even chunk ids
        moresplits = [ (train, test) for (train, test) in oes(splits[0][0])]

        for split in moresplits:
            self.failUnless(split[0] != None)
            self.failUnless(split[1] != None)

    def testHalfSplit(self):
        hs = HalfSplitter()

        splits = [ (train, test) for (train, test) in hs(self.data) ]

        self.failUnless(len(splits) == 2)

        for i,p in enumerate(splits):
            self.failUnless( len(p) == 2 )
            self.failUnless( p[0].nsamples == 50 )
            self.failUnless( p[1].nsamples == 50 )

        self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all())
        self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all())
        self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all())
        self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all())

        # check if it also works on just half of the chunks
        moresplits = [ (train, test) for (train, test) in hs(splits[0][0])]

        for split in moresplits:
            self.failUnless(split[0] != None)
            self.failUnless(split[1] != None)

    def testNGroupSplit(self):
        """Test NGroupSplitter along with the reversal of the
        order of spit out datasets
        """
        # Test 2 groups like HalfSplitter first
        hs = NGroupSplitter(2)
        hs_reversed = NGroupSplitter(2, reverse=True)

        for isreversed, splitter in enumerate((hs, hs_reversed)):
            splits = list(splitter(self.data))
            self.failUnless(len(splits) == 2)

            for i, p in enumerate(splits):
                self.failUnless( len(p) == 2 )
                self.failUnless( p[0].nsamples == 50 )
                self.failUnless( p[1].nsamples == 50 )

            self.failUnless((splits[0][1-isreversed].uniquechunks == [0, 1, 2, 3, 4]).all())
            self.failUnless((splits[0][isreversed].uniquechunks == [5, 6, 7, 8, 9]).all())
            self.failUnless((splits[1][1-isreversed].uniquechunks == [5, 6, 7, 8, 9]).all())
            self.failUnless((splits[1][isreversed].uniquechunks == [0, 1, 2, 3, 4]).all())

        # check if it also works on just a subset of the chunks
        moresplits = list(hs(splits[0][0]))

        for split in moresplits:
            self.failUnless(split[0] != None)
            self.failUnless(split[1] != None)

        # now test more groups
        s5 = NGroupSplitter(5)
        s5_reversed = NGroupSplitter(5, reverse=True)

        # get the splits
        for isreversed, s5splitter in enumerate((s5, s5_reversed)):
            splits = list(s5splitter(self.data))

            # must have 5 splits
            self.failUnless(len(splits) == 5)

            # check split content
            self.failUnless((splits[0][1-isreversed].uniquechunks == [0, 1]).all())
            self.failUnless((splits[0][isreversed].uniquechunks == [2, 3, 4, 5, 6, 7, 8, 9]).all())
            self.failUnless((splits[1][1-isreversed].uniquechunks == [2, 3]).all())
            self.failUnless((splits[1][isreversed].uniquechunks == [0, 1, 4, 5, 6, 7, 8, 9]).all())
            # ...
            self.failUnless((splits[4][1-isreversed].uniquechunks == [8, 9]).all())
            self.failUnless((splits[4][isreversed].uniquechunks == [0, 1, 2, 3, 4, 5, 6, 7]).all())

        # Test for too many groups
        def splitcall(spl, dat):
            return [ (train, test) for (train, test) in spl(dat) ]
        s20 = NGroupSplitter(20)
        self.assertRaises(ValueError, splitcall, s20, self.data)

    def testCustomSplit(self):
        # simulate half splitter
        hs = CustomSplitter([(None, [0, 1, 2, 3, 4]), (None, [5, 6, 7, 8, 9])])
        splits = list(hs(self.data))
        self.failUnless(len(splits) == 2)

        for i,p in enumerate(splits):
            self.failUnless( len(p) == 2 )
            self.failUnless( p[0].nsamples == 50 )
            self.failUnless( p[1].nsamples == 50 )

        self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all())
        self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all())
        self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all())
        self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all())


        # check fully customized split with working and validation set specified
        cs = CustomSplitter([([0, 3, 4], [5, 9])])
        splits = list(cs(self.data))
        self.failUnless(len(splits) == 1)

        for i,p in enumerate(splits):
            self.failUnless( len(p) == 2 )
            self.failUnless( p[0].nsamples == 30 )
            self.failUnless( p[1].nsamples == 20 )

        self.failUnless((splits[0][1].uniquechunks == [5, 9]).all())
        self.failUnless((splits[0][0].uniquechunks == [0, 3, 4]).all())

        # full test with additional sampling and 3 datasets per split
        cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                            nperlabel=[3, 4, 1],
                            nrunspersplit=3)
        splits = list(cs(self.data))
        self.failUnless(len(splits) == 3)

        for i,p in enumerate(splits):
            self.failUnless( len(p) == 3 )
            self.failUnless( p[0].nsamples == 12 )
            self.failUnless( p[1].nsamples == 16 )
            self.failUnless( p[2].nsamples == 4 )

        # let's test selection of samples by ratio, also combined with
        # the other ways
        cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                            nperlabel=[[0.3, 0.6, 1.0, 0.5],
                                       0.5,
                                       'all'],
                            nrunspersplit=3)
        csall = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                               nrunspersplit=3)
        # let's craft a simpler dataset
        #ds = Dataset(samples=N.arange(12), labels=[1]*6+[2]*6, chunks=1)
        splits = list(cs(self.data))
        splitsall = list(csall(self.data))

        self.failUnless(len(splits) == 3)
        ul = self.data.uniquelabels

        self.failUnless(((N.array(splitsall[0][0].samplesperlabel.values())
                          * [0.3, 0.6, 1.0, 0.5]).round().astype(int) ==
                         N.array(splits[0][0].samplesperlabel.values())).all())

        self.failUnless(((N.array(splitsall[0][1].samplesperlabel.values()) * 0.5
                          ).round().astype(int) ==
                         N.array(splits[0][1].samplesperlabel.values())).all())

        self.failUnless((N.array(splitsall[0][2].samplesperlabel.values()) ==
                         N.array(splits[0][2].samplesperlabel.values())).all())

    def testNoneSplitter(self):
        nos = NoneSplitter()
        splits = [ (train, test) for (train, test) in nos(self.data) ]
        self.failUnless(len(splits) == 1)
        self.failUnless(splits[0][0] == None)
        self.failUnless(splits[0][1].nsamples == 100)

        nos = NoneSplitter(mode='first')
        splits = [ (train, test) for (train, test) in nos(self.data) ]
        self.failUnless(len(splits) == 1)
        self.failUnless(splits[0][1] == None)
        self.failUnless(splits[0][0].nsamples == 100)


        # test sampling tools
        # specified value
        nos = NoneSplitter(nrunspersplit=3,
                           nperlabel=10)
        splits = [ (train, test) for (train, test) in nos(self.data) ]

        self.failUnless(len(splits) == 3)
        for split in splits:
            self.failUnless(split[0] == None)
            self.failUnless(split[1].nsamples == 40)
            self.failUnless(split[1].samplesperlabel.values() == [10, 10, 10, 10])

        # auto-determined
        nos = NoneSplitter(nrunspersplit=3,
                           nperlabel='equal')
        splits = [ (train, test) for (train, test) in nos(self.data) ]

        self.failUnless(len(splits) == 3)
        for split in splits:
            self.failUnless(split[0] == None)
            self.failUnless(split[1].nsamples == 100)
            self.failUnless(split[1].samplesperlabel.values() == [25, 25, 25, 25])

    def testLabelSplitter(self):
        oes = OddEvenSplitter(attr='labels')

        splits = [ (first, second) for (first, second) in oes(self.data) ]

        self.failUnless((splits[0][0].uniquelabels == [0, 2]).all())
        self.failUnless((splits[0][1].uniquelabels == [1, 3]).all())
        self.failUnless((splits[1][0].uniquelabels == [1, 3]).all())
        self.failUnless((splits[1][1].uniquelabels == [0, 2]).all())

    def testCountedSplitting(self):
        # count > #chunks should still result in only 10 (== #chunks) splits
        nchunks = len(self.data.uniquechunks)
        for strategy in NFoldSplitter._STRATEGIES:
            for count, target in [ (nchunks*2, nchunks),
                                   (nchunks, nchunks),
                                   (nchunks-1, nchunks-1),
                                   (3, 3),
                                   (0, 0),
                                   (1, 1) ]:
                nfs = NFoldSplitter(cvtype=1, count=count, strategy=strategy)
                splits = [ (train, test) for (train, test) in nfs(self.data) ]
                self.failUnless(len(splits) == target)
                chosenchunks = [int(s[1].uniquechunks) for s in splits]

                # Check if "lastsplit" dsattr was assigned appropriately
                nsplits = len(splits)
                if nsplits > 0:
                    # dummy-proof testing of last split
                    for ds_ in splits[-1]:
                        self.failUnless(ds_._dsattr['lastsplit'])
                    # test all now
                    for isplit, split in enumerate(splits):
                        for ds_ in split:
                            self.failUnless(ds_._dsattr['lastsplit']
                                            == (isplit == nsplits-1))

                # Check results of different strategies
                if strategy == 'first':
                    self.failUnlessEqual(chosenchunks, range(target))
                elif strategy == 'equidistant':
                    if target == 3:
                        self.failUnlessEqual(chosenchunks, [0, 3, 7])
                elif strategy == 'random':
                    # none is selected twice
                    self.failUnless(len(set(chosenchunks)) == len(chosenchunks))
                    self.failUnless(target == len(chosenchunks))
                else:
                    raise RuntimeError, "Add unittest for strategy %s" \
                          % strategy

    def testDiscardedBoundaries(self):
        splitters = [NFoldSplitter(),
                     NFoldSplitter(discard_boundary=(0,1)), # discard testing
                     NFoldSplitter(discard_boundary=(1,0)), # discard training
                     NFoldSplitter(discard_boundary=(2,0)), # discard 2 from training
                     NFoldSplitter(discard_boundary=1),     # discard from both
                     OddEvenSplitter(discard_boundary=(1,0)),
                     OddEvenSplitter(discard_boundary=(0,1)),
                     HalfSplitter(discard_boundary=(1,0)),
                     ]

        split_sets = [list(s(self.data)) for s in splitters]
        counts = [[(len(s[0].chunks), len(s[1].chunks)) for s in split_set]
                  for split_set in split_sets]

        nodiscard_tr = [c[0] for c in counts[0]]
        nodiscard_te = [c[1] for c in counts[0]]

        # Discarding in testing:
        self.failUnless(nodiscard_tr == [c[0] for c in counts[1]])
        self.failUnless(nodiscard_te[1:-1] == [c[1] + 2 for c in counts[1][1:-1]])
        # at the first/last chunk only a single element gets discarded
        self.failUnless(nodiscard_te[0] == counts[1][0][1] + 1)
        self.failUnless(nodiscard_te[-1] == counts[1][-1][1] + 1)

        # Discarding in training
        for d in [1, 2]:
            self.failUnless(nodiscard_te == [c[1] for c in counts[1+d]])
            self.failUnless(nodiscard_tr[0] == counts[1+d][0][0] + d)
            self.failUnless(nodiscard_tr[-1] == counts[1+d][-1][0] + d)
            self.failUnless(nodiscard_tr[1:-1] == [c[0] + d*2
                                                   for c in counts[1+d][1:-1]])

        # Discarding in both -- should equal the element-wise minimum of
        # counts[1] and counts[2]
        counts_min = [(min(c1[0], c2[0]), min(c1[1], c2[1]))
                      for c1, c2 in zip(counts[1], counts[2])]
        self.failUnless(counts_min == counts[4])

        # TODO: test all those odd/even etc splitters... YOH: did
        # visually... looks ok ;)
        #for count in counts[5:]:
        #    print count


def suite():
    return unittest.makeSuite(SplitterTests)


if __name__ == '__main__':
    import runner
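
The __main__ block above delegates to PyMVPA's own test runner module. As a minimal sketch of an alternative, assuming only the standard library and the suite() factory defined in this module, the same tests could also be executed with the stock unittest text runner:

    # run the SplitterTests suite directly through unittest
    import unittest
    from mvpa.tests.test_splitter import suite

    unittest.TextTestRunner(verbosity=2).run(suite())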