1
2
3
4
5
6
7
8
"""Unit tests for the PyMVPA SampleGroupMapper."""
10
11
12 import unittest
13 from mvpa.support.copy import deepcopy
14 import numpy as N
15
16 from mvpa.mappers.samplegroup import SampleGroupMapper
17 from mvpa.datasets import Dataset
18
19
class SampleGroupMapperTests(unittest.TestCase):
    """Tests for `SampleGroupMapper`.

    NOTE(review): the ``class``/``def`` header lines were missing from the
    pasted source; the class and method names here are reconstructed and
    should be checked against anything that looks this class up by name.
    """

    def testSimple(self):
        """Map a small dataset directly and via ``Dataset.applyMapper``.

        The expected values are the per-(chunk, label) means of the input
        samples, with each group's label and chunk carried along.
        """
        # 8 samples x 3 features; labels alternate 0/1, the first four
        # samples form chunk 0 and the last four form chunk 1.
        data = N.arange(24).reshape(8, 3)
        labels = [0, 1] * 4
        chunks = N.repeat(N.array((0, 1)), 4)

        # Expected output: one sample per (chunk, label) group, e.g. the
        # chunk-0/label-0 group is rows 0 and 2, whose mean is [3, 4, 5].
        csamples = [[3, 4, 5], [6, 7, 8], [15, 16, 17], [18, 19, 20]]
        clabels = [0, 1, 0, 1]
        cchunks = [0, 0, 1, 1]

        ds = Dataset(samples=data, labels=labels, chunks=chunks)

        m = SampleGroupMapper()

        # An untrained mapper must refuse to map anything.
        self.assertRaises(RuntimeError, m, data)

        m.train(ds)

        # The trained mapper reduces samples, labels and chunks alike.
        self.assertTrue((m.forward(ds.samples) == csamples).all())
        self.assertTrue((m.forward(ds.labels) == clabels).all())
        self.assertTrue((m.forward(ds.chunks) == cchunks).all())

        # Same reduction through the dataset API with a fresh, untrained
        # mapper -- presumably applyMapper() trains it on ``ds``; confirm.
        mapped = ds.applyMapper(samplesmapper=SampleGroupMapper())

        self.assertEqual(mapped.nsamples, 4)
        self.assertEqual(mapped.nfeatures, 3)
        self.assertTrue((mapped.samples == csamples).all())
        self.assertTrue((mapped.labels == clabels).all())
        self.assertTrue((mapped.chunks == cchunks).all())
        # Compare against an ndarray (not a range/list) so the comparison is
        # always element-wise and yields an array that has ``.all()``.
        self.assertTrue((mapped.origids == N.arange(4)).all())


if __name__ == '__main__':
    # When run directly, hand over to the project's `runner` module; it is
    # only imported here, so presumably it runs the tests on import.
    # NOTE(review): confirm what `runner` expects to find in this module.
    import runner