1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA pattern handling"""
10
11 from mvpa.datasets.masked import MaskedDataset
12 from mvpa.datasets.splitters import NFoldSplitter, OddEvenSplitter, \
13 NoneSplitter, HalfSplitter, \
14 CustomSplitter, NGroupSplitter
15 import unittest
16 import numpy as N
17
18
20
# Fixture shared by the tests: 100 samples x 10 features of Gaussian noise,
# labels cycling through 0..3 and ten consecutive chunks of ten samples.
# Floor division keeps the chunk ids integral regardless of how the
# interpreter defines `/` (true division would yield 0.0, 0.1, ...).
self.data = MaskedDataset(
    samples=N.random.normal(size=(100, 10)),
    labels=[i % 4 for i in range(100)],
    chunks=[i // 10 for i in range(100)])
26
27
29
# Leave-one-chunk-out cross-validation (cvtype=1) over the fixture.
nfs = NFoldSplitter(cvtype=1)

# Drain the generator, unpacking every split into a (train, test) pair.
xvpat = []
for train, test in nfs(self.data):
    xvpat.append((train, test))

# Ten chunks in the fixture -> ten folds expected.
self.failUnless(len(xvpat) == 10)

for fold, pair in enumerate(xvpat):
    self.failUnless(len(pair) == 2)
    training, testing = pair
    # Nine chunks train, one chunk tests ...
    self.failUnless(training.nsamples == 90)
    self.failUnless(testing.nsamples == 10)
    # ... and the held-out chunk id is expected to follow the fold index.
    self.failUnless(testing.chunks[0] == fold)
42
43
# Default OddEvenSplitter: two complementary splits are expected.
oes = OddEvenSplitter()

splits = [(train, test) for (train, test) in oes(self.data)]

self.failUnless(len(splits) == 2)

# Every split is a pair of datasets with 50 samples each.
for pair in splits:
    self.failUnless(len(pair) == 2)
    self.failUnless(pair[0].nsamples == 50)
    self.failUnless(pair[1].nsamples == 50)

# First split: odd chunks in the second dataset, even chunks in the first;
# the second split is the mirror image.
self.failUnless((splits[0][1].uniquechunks == [1, 3, 5, 7, 9]).all())
self.failUnless((splits[0][0].uniquechunks == [0, 2, 4, 6, 8]).all())
self.failUnless((splits[1][0].uniquechunks == [1, 3, 5, 7, 9]).all())
self.failUnless((splits[1][1].uniquechunks == [0, 2, 4, 6, 8]).all())

# Splitting one of the resulting datasets again must still yield two
# real datasets per split.
moresplits = [(train, test) for (train, test) in oes(splits[0][0])]

for split in moresplits:
    # Identity test on purpose: `!= None` could be routed through
    # comparison operators defined by the dataset class.
    self.failUnless(split[0] is not None)
    self.failUnless(split[1] is not None)
67
68
# Default HalfSplitter: two complementary splits are expected.
hs = HalfSplitter()

splits = [(train, test) for (train, test) in hs(self.data)]

self.failUnless(len(splits) == 2)

# Every split is a pair of datasets with 50 samples each.
for pair in splits:
    self.failUnless(len(pair) == 2)
    self.failUnless(pair[0].nsamples == 50)
    self.failUnless(pair[1].nsamples == 50)

# First split: chunks 0-4 in the second dataset, chunks 5-9 in the first;
# the second split is the mirror image.
self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all())
self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all())

# Splitting one of the resulting datasets again must still yield two
# real datasets per split.
moresplits = [(train, test) for (train, test) in hs(splits[0][0])]

for split in moresplits:
    # Identity test on purpose: `!= None` could be routed through
    # comparison operators defined by the dataset class.
    self.failUnless(split[0] is not None)
    self.failUnless(split[1] is not None)
92
94 """Test NGroupSplitter alongside with the reversal of the
95 order of spit out datasets
96 """
97
# Two groups: the expectations below mirror the HalfSplitter checks.
# With reverse=True the two datasets of each split are expected in
# swapped order.
hs = NGroupSplitter(2)
hs_reversed = NGroupSplitter(2, reverse=True)

for isreversed, splitter in enumerate((hs, hs_reversed)):
    splits = list(splitter(self.data))
    self.failUnless(len(splits) == 2)

    for i, p in enumerate(splits):
        self.failUnless(len(p) == 2)
        self.failUnless(p[0].nsamples == 50)
        self.failUnless(p[1].nsamples == 50)

    # `isreversed` is 0 for the plain and 1 for the reversed splitter, so
    # `1-isreversed` / `isreversed` select the swapped positions.
    self.failUnless((splits[0][1-isreversed].uniquechunks == [0, 1, 2, 3, 4]).all())
    self.failUnless((splits[0][isreversed].uniquechunks == [5, 6, 7, 8, 9]).all())
    self.failUnless((splits[1][1-isreversed].uniquechunks == [5, 6, 7, 8, 9]).all())
    self.failUnless((splits[1][isreversed].uniquechunks == [0, 1, 2, 3, 4]).all())

# Splitting an already split dataset must still yield real datasets.
# NOTE(review): placed after the loop, so it re-splits a dataset produced
# by the last (reversed) run -- confirm the intended nesting.
moresplits = list(hs(splits[0][0]))

for split in moresplits:
    # Identity test on purpose: `!= None` could be routed through
    # comparison operators defined by the dataset class.
    self.failUnless(split[0] is not None)
    self.failUnless(split[1] is not None)

# Five groups over ten chunks: two chunks per group are expected.
s5 = NGroupSplitter(5)
s5_reversed = NGroupSplitter(5, reverse=True)

for isreversed, s5splitter in enumerate((s5, s5_reversed)):
    splits = list(s5splitter(self.data))

    # One split per group.
    self.failUnless(len(splits) == 5)

    # Spot-check the first two and the last split.
    self.failUnless((splits[0][1-isreversed].uniquechunks == [0, 1]).all())
    self.failUnless((splits[0][isreversed].uniquechunks == [2, 3, 4, 5, 6, 7, 8, 9]).all())
    self.failUnless((splits[1][1-isreversed].uniquechunks == [2, 3]).all())
    self.failUnless((splits[1][isreversed].uniquechunks == [0, 1, 4, 5, 6, 7, 8, 9]).all())

    self.failUnless((splits[4][1-isreversed].uniquechunks == [8, 9]).all())
    self.failUnless((splits[4][isreversed].uniquechunks == [0, 1, 2, 3, 4, 5, 6, 7]).all())
141
142
143
def splitcall(spl, dat):
    """Run splitter `spl` on `dat` and return all splits as a list of pairs.

    Every item produced by `spl(dat)` is unpacked into exactly two parts,
    so the iterable is consumed completely and any error raised while
    producing or unpacking a split propagates to the caller.
    """
    pairs = []
    for train, test in spl(dat):
        pairs.append((train, test))
    return pairs
# Requesting 20 groups from the fixture (which has 10 chunks) is expected
# to raise ValueError once the splits are actually generated.
too_many_groups = NGroupSplitter(20)
self.assertRaises(ValueError, splitcall, too_many_groups, self.data)
148
150
# Two hand-written rules reproducing the half-split layout; the first slot
# is left as None and, per the checks below, is expected to receive the
# remaining chunks.
hs = CustomSplitter([(None, [0, 1, 2, 3, 4]), (None, [5, 6, 7, 8, 9])])
splits = list(hs(self.data))
self.failUnless(len(splits) == 2)

for parts in splits:
    self.failUnless(len(parts) == 2)
    self.failUnless(parts[0].nsamples == 50)
    self.failUnless(parts[1].nsamples == 50)

first_rule, second_rule = splits[0], splits[1]
self.failUnless((first_rule[1].uniquechunks == [0, 1, 2, 3, 4]).all())
self.failUnless((first_rule[0].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((second_rule[1].uniquechunks == [5, 6, 7, 8, 9]).all())
self.failUnless((second_rule[0].uniquechunks == [0, 1, 2, 3, 4]).all())

# A single rule naming both sides explicitly: chunks that are not listed
# are expected to be absent from the resulting datasets.
cs = CustomSplitter([([0, 3, 4], [5, 9])])
splits = list(cs(self.data))
self.failUnless(len(splits) == 1)

for parts in splits:
    self.failUnless(len(parts) == 2)
    # 3 chunks and 2 chunks of 10 samples each.
    self.failUnless(parts[0].nsamples == 30)
    self.failUnless(parts[1].nsamples == 20)

only_rule = splits[0]
self.failUnless((only_rule[1].uniquechunks == [5, 9]).all())
self.failUnless((only_rule[0].uniquechunks == [0, 3, 4]).all())
178
179
# Three-way rule with a fixed number of samples per label for each part,
# repeated three times per rule.
cs = CustomSplitter(
    [([0, 3, 4], [5, 9], [2])],
    nperlabel=[3, 4, 1],
    nrunspersplit=3)
splits = list(cs(self.data))

# One rule x three runs.
self.failUnless(len(splits) == 3)

# Four labels in the fixture -> 3*4, 4*4 and 1*4 samples respectively.
expected_sizes = (12, 16, 4)
for parts in splits:
    self.failUnless(len(parts) == 3)
    for position in range(3):
        self.failUnless(parts[position].nsamples == expected_sizes[position])
191
192
193
# Fractional `nperlabel`: a per-label list of ratios for the first part, a
# single ratio for the second and 'all' for the third.  The ratios are
# defined once so the splitter configuration and the expected values
# below cannot drift apart.
ratios = [0.3, 0.6, 1.0, 0.5]
cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                    nperlabel=[list(ratios), 0.5, 'all'],
                    nrunspersplit=3)
# Reference splitter with the same rule but no sub-sampling.
csall = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                       nrunspersplit=3)

splits = list(cs(self.data))
splitsall = list(csall(self.data))

self.failUnless(len(splits) == 3)

def per_label(ds):
    # list() so that N.array always receives a real sequence of counts.
    return N.array(list(ds.samplesperlabel.values()))

# Part 0: each label scaled by its own ratio (rounded to whole samples).
self.failUnless(((per_label(splitsall[0][0]) * ratios).round().astype(int) ==
                 per_label(splits[0][0])).all())

# Part 1: every label scaled by the same ratio.
self.failUnless(((per_label(splitsall[0][1]) * 0.5).round().astype(int) ==
                 per_label(splits[0][1])).all())

# Part 2: 'all' is expected to keep the counts unchanged.
self.failUnless((per_label(splitsall[0][2]) ==
                 per_label(splits[0][2])).all())
219
220
# Default NoneSplitter: a single "split" with nothing in the first slot
# and the complete dataset in the second.
nos = NoneSplitter()
splits = [(train, test) for (train, test) in nos(self.data)]
self.failUnless(len(splits) == 1)
# Identity test: `== None` could be routed through comparison operators
# defined by the dataset class.
self.failUnless(splits[0][0] is None)
self.failUnless(splits[0][1].nsamples == 100)

# mode='first': the complete dataset goes into the first slot instead.
nos = NoneSplitter(mode='first')
splits = [(train, test) for (train, test) in nos(self.data)]
self.failUnless(len(splits) == 1)
self.failUnless(splits[0][1] is None)
self.failUnless(splits[0][0].nsamples == 100)
233
234
235
236
# Sub-sampling without splitting: three runs with 10 samples per label.
nos = NoneSplitter(nrunspersplit=3,
                   nperlabel=10)
splits = [(train, test) for (train, test) in nos(self.data)]

self.failUnless(len(splits) == 3)
for split in splits:
    # Identity test: `== None` could be routed through comparison
    # operators defined by the dataset class.
    self.failUnless(split[0] is None)
    self.failUnless(split[1].nsamples == 40)
    # list() so the comparison also holds when values() is not a list.
    self.failUnless(list(split[1].samplesperlabel.values()) == [10, 10, 10, 10])

# nperlabel='equal': the fixture is already balanced (25 per label), so
# all 100 samples are expected to be kept.
nos = NoneSplitter(nrunspersplit=3,
                   nperlabel='equal')
splits = [(train, test) for (train, test) in nos(self.data)]

self.failUnless(len(splits) == 3)
for split in splits:
    self.failUnless(split[0] is None)
    self.failUnless(split[1].nsamples == 100)
    self.failUnless(list(split[1].samplesperlabel.values()) == [25, 25, 25, 25])
257
258
# Odd/even splitting driven by the labels attribute rather than the chunks.
oes = OddEvenSplitter(attr='labels')

splits = []
for first, second in oes(self.data):
    splits.append((first, second))

# Expected unique labels of (first, second) for each of the two splits.
expected = [([0, 2], [1, 3]),
            ([1, 3], [0, 2])]
for index, (want_first, want_second) in enumerate(expected):
    got_first, got_second = splits[index]
    self.failUnless((got_first.uniquelabels == want_first).all())
    self.failUnless((got_second.uniquelabels == want_second).all())
268
269
271
# Limit the number of generated splits via `count`, for every selection
# strategy NFoldSplitter declares.
nchunks = len(self.data.uniquechunks)
for strategy in NFoldSplitter._STRATEGIES:
    # (requested count, expected number of splits); a request larger than
    # the number of chunks is expected to be capped at nchunks.
    for count, target in [(nchunks*2, nchunks),
                          (nchunks, nchunks),
                          (nchunks-1, nchunks-1),
                          (3, 3),
                          (0, 0),
                          (1, 1)]:
        nfs = NFoldSplitter(cvtype=1, count=count, strategy=strategy)
        splits = [(train, test) for (train, test) in nfs(self.data)]
        self.failUnless(len(splits) == target)
        # With cvtype=1 each second dataset holds a single chunk id.
        chosenchunks = [int(s[1].uniquechunks) for s in splits]

        nsplits = len(splits)
        if nsplits > 0:
            # The final split must be flagged as the last one ...
            for ds_ in splits[-1]:
                self.failUnless(ds_._dsattr['lastsplit'])
            # ... and it must be the only one flagged.  This used to be a
            # bare chained comparison whose result was discarded, so it
            # never checked anything.
            for isplit, split in enumerate(splits):
                islast = (isplit == nsplits - 1)
                for ds_ in split:
                    self.failUnlessEqual(bool(ds_._dsattr['lastsplit']),
                                         islast)

        # Strategy-specific expectations on which chunks were selected.
        if strategy == 'first':
            self.failUnlessEqual(chosenchunks, list(range(target)))
        elif strategy == 'equidistant':
            if target == 3:
                self.failUnlessEqual(chosenchunks, [0, 3, 7])
        elif strategy == 'random':
            # No chunk may be drawn twice.
            self.failUnless(len(set(chosenchunks)) == len(chosenchunks))
            self.failUnless(target == len(chosenchunks))
        else:
            # Forces whoever adds a strategy to extend this test as well.
            raise RuntimeError("Add unittest for strategy %s" % strategy)
310
311
# Index 0 is the reference without discarding; indices 1-4 are compared
# against it below, the remaining splitters are only run.
splitters = [
    NFoldSplitter(),
    NFoldSplitter(discard_boundary=(0, 1)),
    NFoldSplitter(discard_boundary=(1, 0)),
    NFoldSplitter(discard_boundary=(2, 0)),
    NFoldSplitter(discard_boundary=1),
    OddEvenSplitter(discard_boundary=(1, 0)),
    OddEvenSplitter(discard_boundary=(0, 1)),
    HalfSplitter(discard_boundary=(1, 0)),
]

# Generate every split of every splitter first ...
split_sets = []
for splitter in splitters:
    split_sets.append(list(splitter(self.data)))

# ... then reduce each split to (len of first chunks, len of second chunks).
counts = []
for split_set in split_sets:
    sizes = []
    for split in split_set:
        sizes.append((len(split[0].chunks), len(split[1].chunks)))
    counts.append(sizes)

reference = counts[0]
nodiscard_tr = [pair[0] for pair in reference]
nodiscard_te = [pair[1] for pair in reference]

# discard_boundary=(0, 1): first dataset untouched; the second loses two
# samples in the inner splits and one in the first and last split.
second_trimmed = counts[1]
self.failUnless(nodiscard_tr == [pair[0] for pair in second_trimmed])
self.failUnless(nodiscard_te[1:-1] ==
                [pair[1] + 2 for pair in second_trimmed[1:-1]])
self.failUnless(nodiscard_te[0] == second_trimmed[0][1] + 1)
self.failUnless(nodiscard_te[-1] == second_trimmed[-1][1] + 1)

# discard_boundary=(d, 0): second dataset untouched; the first loses d
# samples in the first and last split and 2*d in the inner ones.
for d in [1, 2]:
    first_trimmed = counts[1 + d]
    self.failUnless(nodiscard_te == [pair[1] for pair in first_trimmed])
    self.failUnless(nodiscard_tr[0] == first_trimmed[0][0] + d)
    self.failUnless(nodiscard_tr[-1] == first_trimmed[-1][0] + d)
    self.failUnless(nodiscard_tr[1:-1] ==
                    [pair[0] + d*2 for pair in first_trimmed[1:-1]])

# A scalar discard_boundary=1 is expected to combine (0, 1) and (1, 0):
# element-wise minimum of the two count lists.
counts_min = []
for c1, c2 in zip(counts[1], counts[2]):
    counts_min.append((min(c1[0], c2[0]), min(c1[1], c2[1])))
self.failUnless(counts_min == counts[4])
349
350
351
352
353
354
355
358
359
# Script entry point: when executed directly, the only action is importing
# the `runner` module (presumably it runs the suite on import -- confirm).
if __name__ == '__main__':
    import runner
362