1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA classifier cross-validation"""
10
11 import unittest
12 from mvpa.support.copy import copy
13
14 from mvpa.base import externals
15 from mvpa.datasets import Dataset
16 from mvpa.datasets.splitters import OddEvenSplitter
17
18 from mvpa.clfs.meta import MulticlassClassifier
19 from mvpa.clfs.transerror import \
20 TransferError, ConfusionMatrix, ConfusionBasedError
21 from mvpa.algorithms.cvtranserror import CrossValidatedTransferError
22
23 from mvpa.clfs.stats import MCNullDist
24
25 from mvpa.misc.exceptions import UnknownStateError
26
27 from tests_warehouse import datasets, sweepargs
28 from tests_warehouse_clfs import *
31
# NOTE(review): this copy of the file lost the ``def`` line of this test
# method and its indentation; the nesting below is reconstructed.  Restore
# the method header before running.
reg = [1,1,1,2,2,2,3,3,3]
regl = [1,2,1,2,2,2,3,2,1]
correct_cm = [[2,0,1],[1,3,1],[0,0,1]]

# The matrix must not depend on the sequence type (list, tuple or numpy
# array) used for targets and predictions.
for t in [reg, tuple(reg), list(reg), N.array(reg)]:
    for p in [regl, tuple(regl), list(regl), N.array(regl)]:
        cm = ConfusionMatrix(targets=t, predictions=p)
        self.failUnless((cm.matrix == correct_cm).all())

# No samples yet -- percentCorrect must raise.
cm = ConfusionMatrix()
self.failUnlessRaises(ZeroDivisionError, lambda x:x.percentCorrect, cm)

cm.add(reg, regl)

self.failUnlessEqual(len(cm.sets), 1,
                     msg="Should have a single set so far")
self.failUnlessEqual(cm.matrix.shape, (3,3),
                     msg="should be square matrix (len(reglabels) x len(reglabels)")

# ConfusionMatrix must complain if the numbers of targets and predictions
# differ ...
self.failUnlessRaises(ValueError, cm.add, reg, N.array([1]))
# ... and the rejected add must leave the matrix untouched.
self.failUnless((cm.matrix == correct_cm).all())

# A second set introducing a previously unseen label (4).
cm.add(reg, N.array([1,4,1,2,2,2,4,2,1]))

self.failUnlessEqual(cm.labels, [1,2,3,4],
                     msg="We should have gotten 4th label")

matrices = cm.matrices
self.failUnlessEqual(len(matrices), 2,
                     msg="Have gotten two splits")

self.failUnless((matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
                msg="Total votes should match the sum across split CMs")

# Textual renderings must be non-trivial.
self.failUnless(len(cm.asstring(
    header=True, summary=True,
    description=True))>100)
self.failUnless(len(str(cm))>100)

self.failUnless(len(cm.asstring(summary=True,
                                header=False))>100)

# In-place addition concatenates the sets.
cm += cm
self.failUnlessEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

# Binary addition returns a new matrix with the sets of both operands.
cm2 = cm + cm
self.failUnlessEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
self.failUnlessEqual(cm2.percentCorrect, cm.percentCorrect,
                     msg="Percent of corrrect should remain the same ;-)")

# error and percentCorrect go through different float arithmetic
# (the latter is scaled by 100), so compare approximately.
self.failUnlessAlmostEqual(cm2.error, 1.0-cm.percentCorrect/100.0,
                           msg="Test if we get proper error value")
99
101
102
103
# NOTE(review): this copy of the file lost the ``def`` line of this test
# method and its indentation; the nesting below is reconstructed.  Restore
# the method header before running.
#
# A split may contain only one distinct target label; ConfusionMatrix must
# cope with such degenerate input, render it, and report 100% accuracy
# (targets, predictions and values are identical here).
for labels in ([1], [1, 1], [0], [0, 0]):
    degenerate = ConfusionMatrix(targets=labels, predictions=labels,
                                 values=labels)
    str(degenerate)  # rendering must not raise
    self.failUnless(degenerate.stats['ACC%'] == 100)
109
110
# NOTE(review): this copy of the file lost the ``def`` line of this test
# method and its indentation.  Restore the method header before running.
#
# Two of the four predictions match their targets, so the textual summary
# must report 50% accuracy.
# NOTE(review): whitespace inside the literal below may have been collapsed
# in this copy -- confirm it matches ConfusionMatrix's output format.
targets = [0,0,1,1]
predictions = [1,0,1,0]
summary = str(ConfusionMatrix(targets=targets, predictions=predictions))
self.failUnless('ACC% 50' in summary)
116
117
# NOTE(review): this copy of the file lost the ``def`` line of this test
# method and its indentation; the nesting below is reconstructed.  Restore
# the method header before running.
reg = [1,1,1,2,2,2,3,3,3]
regl = [1,2,1,2,2,2,3,2,1]
correct_cm = [[2,0,1], [1,3,1], [0,0,1]]
# Several names may map onto the same numeric label ('apple' and
# 'shitty apple' are both 1).
lm = {'apple':1, 'orange':2, 'shitty apple':1, 'candy':3}
cm = ConfusionMatrix(targets=reg, predictions=regl,
                     labels_map=lm)

# A labels_map must not change the matrix itself.
self.failUnless((cm.matrix == correct_cm).all())

# Every mapped name must show up in the textual representation.
s = str(cm)
for name in lm:
    self.failUnless(name in s,
                    msg="Label %r is missing from str(cm)" % (name,))
132
133
134
@sweepargs(l_clf=clfswh['linear', 'svm'])
# NOTE(review): this copy of the file lost the ``def`` line of this
# decorated test (it receives the swept ``l_clf`` argument) and its
# indentation; the nesting below is reconstructed.  Restore the method
# header before running.
train = datasets['uni2medium_train']
# presumably a 3-label dataset (name 'uni3...'), used below on a classifier
# trained on ``train`` -- TODO confirm
test3 = datasets['uni3medium_train']
err = ConfusionBasedError(clf=l_clf)
terr = TransferError(clf=l_clf)

self.failUnlessRaises(UnknownStateError, err, None)
"""Shouldn't be able to access the state yet"""

# After training, the confusion-based error must equal the transfer error
# measured on the training data itself.
l_clf.train(train)
e, te = err(None), terr(train)
self.failUnless(abs(e-te) < 1e-10,
    msg="ConfusionBasedError (%.2g) should be equal to TransferError "
        "(%.2g) on traindataset" % (e, te))

# Evaluating on ``test3`` must produce a (non-None) result.
self.failIf(terr(test3) is None)

# Smoke test: the TransferError must be copyable (copy() must not raise).
terr_copy = copy(terr)
158
159
@sweepargs(l_clf=clfswh['linear', 'svm'])
# NOTE(review): this copy of the file lost the ``def`` line of this
# decorated test (it receives the swept ``l_clf`` argument) and its
# indentation; the nesting below is reconstructed.  Restore the method
# header before running.
train = datasets['uni2medium']

# Number of Monte-Carlo permutations for the null distributions below.
num_perm = 10

# Transfer error with a left-tailed Monte-Carlo null distribution.
terr = TransferError(
    clf=l_clf,
    null_dist=MCNullDist(permutations=num_perm,
                         tail='left'))

# Both arguments are ``train``; the resulting error must be below 0.4.
err = terr(train, train)
self.failUnless(err < 0.4)

# Same through odd/even cross-validation; 'dist_samples' is enabled so the
# number of null-distribution samples can be checked at the end.
cvte = CrossValidatedTransferError(
    TransferError(clf=l_clf),
    OddEvenSplitter(),
    null_dist=MCNullDist(permutations=num_perm,
                         tail='left',
                         enable_states=['dist_samples']))
cv_err = cvte(train)

# Significance checks are statistical, hence guarded by the 'labile'
# config option.
# NOTE(review): indentation was lost; whether the second (cvte) check was
# also inside this ``if`` is reconstructed -- confirm against upstream.
null_prob = terr.states.null_prob
if cfg.getboolean('tests', 'labile', default='yes'):
    self.failUnless(null_prob <= 0.1,
        msg="Failed to check that the result is highly significant "
            "(got %f) since we know that the data has signal"
            % null_prob)

    self.failUnless(cvte.states.null_prob <= 0.1,
        msg="Failed to check that the result is highly significant "
            "(got p(cvte)=%f) since we know that the data has signal"
            % cvte.states.null_prob)

# One stored null-distribution sample per permutation.
self.failUnlessEqual(len(cvte.null_dist.states.dist_samples),
                     num_perm)
203
204
205 @sweepargs(l_clf=clfswh['linear', 'svm'])
219
220
221 @sweepargs(clf=clfswh['multiclass'])
223 """Test AUC computation
224 """
225 if isinstance(clf, MulticlassClassifier):
226
227 return
228 clf.states._changeTemporarily(enable_states = ['values'])
229
230 ds2 = datasets['uni2small'].copy()
231 ds2.labels = 1 - ds2.labels
232
233 ds3 = datasets['uni3small'].copy()
234 ul = ds3.uniquelabels
235 nl = ds3.labels.copy()
236 for l in xrange(3):
237 nl[ds3.labels == ul[l]] = ul[(l+1)%3]
238 ds3.labels = nl
239 for ds in [datasets['uni2small'], ds2,
240 datasets['uni3small'], ds3]:
241 cv = CrossValidatedTransferError(
242 TransferError(clf),
243 OddEvenSplitter(),
244 enable_states=['confusion', 'training_confusion'])
245 cverror = cv(ds)
246 stats = cv.confusion.stats
247 Nlabels = len(ds.uniquelabels)
248
249 self.failUnless(stats['ACC'] > 1.2 / Nlabels)
250 auc = stats['AUC']
251 if (Nlabels == 2) or (Nlabels > 2 and auc[0] is not N.nan):
252 mauc = N.min(stats['AUC'])
253 if cfg.getboolean('tests', 'labile', default='yes'):
254 self.failUnless(mauc > 0.55,
255 msg='All AUCs must be above chance. Got minimal '
256 'AUC=%.2g among %s' % (mauc, stats['AUC']))
257 clf.states._resetEnabledTemporarily()
258
259
260
261
"""Based on existing cell dataset results.

Left in for possible future testing, but is not a part of the
unittests suite
"""
# NOTE(review): this copy of the file lost the ``def`` line of this test
# method and its indentation.  Restore the method header before running.

# Local aliases so the numpy ``repr`` output pasted below is valid code.
array = N.array
uint8 = N.uint8
# Eight pairs of uint8 label-code arrays (38..47), presumably
# (targets, predictions) per split -- TODO confirm against ConfusionMatrix.
sets = [
    (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44], dtype=uint8),
     array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46,
            45, 38, 44, 39, 46, 38, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
            40, 47, 43, 45, 41, 44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39,
            46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45,
            41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46,
            42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41,
            44, 47, 46, 42, 47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40,
            43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
            47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38, 38, 43, 45, 41, 44,
            40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 42, 43,
            45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
    (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
     array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40,
            46, 42, 43, 39, 38, 43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47, 40,
            40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44, 40, 46,
            42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41,
            44, 40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
            43, 45, 41, 44, 47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
            38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 46, 47, 38, 39, 43,
            45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39,
            39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
    (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
     array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40,
            46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45,
            40, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
            40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40, 40,
            42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38,
            43, 45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42,
            47, 40, 39, 43, 45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44,
            40, 40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
            45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38, 44, 40, 46, 42, 39,
            39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43, 45, 41, 38, 40,
            46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
    (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
     array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 38, 43, 47, 38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39,
            38, 43, 45, 41, 44, 43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45, 40,
            44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42,
            38, 39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39,
            40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43,
            45, 41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
            38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 46, 44, 38, 46,
            42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
            44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
    (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
     array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45,
            42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
            40, 43, 45, 41, 44, 41, 46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41,
            44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
            43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39, 46, 42,
            47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43,
            45, 41, 44, 40, 38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46,
            46, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
            38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 47, 45,
            46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
            38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
    (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
     array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40,
            38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38,
            42, 40, 38, 39, 43, 45, 41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
            44, 40, 46, 42, 38, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39,
            43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45, 41, 44,
            40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43,
            45, 41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40,
            38, 38, 40, 45, 41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40,
            40, 42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
            41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44, 39, 43, 42, 47, 39,
            40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45, 41, 44, 38, 46,
            42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
    (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
     array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46,
            42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38,
            43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
            39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39,
            39, 47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
            46, 42, 46, 47, 39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
            38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43, 47, 40, 44, 40, 46,
            42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
            44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
    (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
            47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
            40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
            45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
            39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
            46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
            41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
            38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
            42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
     array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41,
            44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46,
            43, 45, 41, 44, 41, 41, 42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42,
            39, 39, 40, 43, 45, 41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44,
            39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39, 43,
            45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46, 42, 47,
            41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41,
            46, 42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
            41, 44, 38, 46, 42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38,
            38, 43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
            42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41, 39, 39, 43, 45, 41,
            44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46, 42, 47, 39, 38,
            43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
# NOTE(review): these statements belong to a test method whose ``def`` line
# and indentation were lost in this copy; the nesting below is
# reconstructed.  Restore the method header before running.
#
# Names for the numeric label codes used in ``sets``.
labels_map = {'12kHz': 40,
              '20kHz': 41,
              '30kHz': 42,
              '3kHz': 38,
              '7kHz': 39,
              'song1': 43,
              'song2': 44,
              'song3': 45,
              'song4': 46,
              'song5': 47}
# Construction must succeed.  Catch Exception (not a bare ``except:``,
# which would also swallow KeyboardInterrupt/SystemExit) and report the
# actual error instead of failing without any explanation.
try:
    cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
except Exception:
    import traceback
    self.fail("ConfusionMatrix(sets=..., labels_map=...) raised:\n%s"
              % traceback.format_exc())
self.failUnless('3kHz / 38' in cm.asstring())

if externals.exists("pylab plottable"):
    import pylab as P
    P.figure()
    labels_order = ("3kHz", "7kHz", "12kHz", "20kHz","30kHz", None,
                    "song1","song2","song3","song4","song5")

    # Plot with the first two labels swapped ("7kHz" before "3kHz"); the
    # top-left 2x2 corner of the plotted matrix must be cm.matrix's corner
    # with rows and columns swapped.
    fig, im, cb = cm.plot(labels=labels_order[1:2] + labels_order[:1]
                                 + labels_order[2:], numbers=True)
    self.failUnless(cm._plotted_confusionmatrix[0,0] == cm.matrix[1,1])
    self.failUnless(cm._plotted_confusionmatrix[0,1] == cm.matrix[1,0])
    self.failUnless(cm._plotted_confusionmatrix[1,1] == cm.matrix[0,0])
    self.failUnless(cm._plotted_confusionmatrix[1,0] == cm.matrix[0,1])
    P.close(fig)
    # Plotting in the natural order must also succeed.
    fig, im, cb = cm.plot(labels=labels_order, numbers=True)
    P.close(fig)
520
521
"""Based on a sample confusion which plots incorrectly

"""
# NOTE(review): this copy of the file lost the ``def`` line of this test
# method and its indentation; the nesting below is reconstructed.  Restore
# the method header before running.

# Local alias so the numpy ``repr`` output pasted below is valid code.
array = N.array
# Eight 3-tuples of arrays, presumably (targets, predictions, values) per
# split -- TODO confirm against ConfusionMatrix.  The last split holds a
# single sample.
sets = [(array([1, 2]), array([1, 1]),
         array([[ 0.54343765, 0.45656235],
                [ 0.92395853, 0.07604147]])),
        (array([1, 2]), array([1, 1]),
         array([[ 0.98030832, 0.01969168],
                [ 0.78998763, 0.21001237]])),
        (array([1, 2]), array([1, 1]),
         array([[ 0.86125263, 0.13874737],
                [ 0.83674113, 0.16325887]])),
        (array([1, 2]), array([1, 1]),
         array([[ 0.57870383, 0.42129617],
                [ 0.59702509, 0.40297491]])),
        (array([1, 2]), array([1, 1]),
         array([[ 0.89530255, 0.10469745],
                [ 0.69373919, 0.30626081]])),
        (array([1, 2]), array([1, 1]),
         array([[ 0.75015218, 0.24984782],
                [ 0.9339767 , 0.0660233 ]])),
        (array([1, 2]), array([1, 2]),
         array([[ 0.97826616, 0.02173384],
                [ 0.38620638, 0.61379362]])),
        (array([2]), array([2]),
         array([[ 0.46893776, 0.53106224]]))]
# Construction must succeed.  Catch Exception (not a bare ``except:``,
# which would also swallow KeyboardInterrupt/SystemExit) and report the
# actual error instead of failing without any explanation.
try:
    cm = ConfusionMatrix(sets=sets)
except Exception:
    import traceback
    self.fail("ConfusionMatrix(sets=...) raised:\n%s"
              % traceback.format_exc())
if externals.exists("pylab plottable"):
    import pylab as P
    # With origin='lower' the plotted matrix must still equal cm.matrix.
    fig, im, cb = cm.plot(origin='lower', numbers=True)
    self.failUnless((cm._plotted_confusionmatrix == cm.matrix).all())
    P.close(fig)
564
565
566
567
568
569 -def suite():
571
572
if __name__ == '__main__':
    # NOTE(review): the ``runner`` module is not visible here; presumably
    # importing it runs this module's tests -- confirm.
    import runner
575