Package mvpa :: Package tests :: Module test_transerror
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_transerror

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA classifier cross-validation""" 
 10   
 11  import unittest 
 12  from mvpa.support.copy import copy 
 13   
 14  from mvpa.base import externals 
 15  from mvpa.datasets import Dataset 
 16  from mvpa.datasets.splitters import OddEvenSplitter 
 17   
 18  from mvpa.clfs.meta import MulticlassClassifier 
 19  from mvpa.clfs.transerror import \ 
 20       TransferError, ConfusionMatrix, ConfusionBasedError 
 21  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 22   
 23  from mvpa.clfs.stats import MCNullDist 
 24   
 25  from mvpa.misc.exceptions import UnknownStateError 
 26   
 27  from tests_warehouse import datasets, sweepargs 
 28  from tests_warehouse_clfs import * 
class ErrorsTests(unittest.TestCase):
    """Tests for ConfusionMatrix, ConfusionBasedError, TransferError and
    CrossValidatedTransferError (including Monte-Carlo null distributions).
    """

    def testConfusionMatrix(self):
        """ConfusionMatrix: input types, accumulation, printing and +/+=."""
        reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
        regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
        correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
        # Check if we are ok with any input type - either list, or N.array,
        # or tuple
        for t in [reg, tuple(reg), list(reg), N.array(reg)]:
            for p in [regl, tuple(regl), list(regl), N.array(regl)]:
                cm = ConfusionMatrix(targets=t, predictions=p)
                # check table content
                self.failUnless((cm.matrix == correct_cm).all())

        # Do a bit more thorough checking
        cm = ConfusionMatrix()
        # No samples yet -- asking for percentCorrect must raise
        self.failUnlessRaises(ZeroDivisionError,
                              lambda x: x.percentCorrect, cm)

        cm.add(reg, regl)

        self.failUnlessEqual(len(cm.sets), 1,
                             msg="Should have a single set so far")
        self.failUnlessEqual(cm.matrix.shape, (3, 3),
            msg="should be square matrix (len(reglabels) x len(reglabels)")

        # ConfusionMatrix must complain if the number of samples differs
        self.failUnlessRaises(ValueError, cm.add, reg, N.array([1]))

        # check table content
        self.failUnless((cm.matrix == correct_cm).all())

        # lets add with new labels (not yet known)
        cm.add(reg, N.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

        self.failUnlessEqual(cm.labels, [1, 2, 3, 4],
                             msg="We should have gotten 4th label")

        matrices = cm.matrices          # separate CM per each given set
        self.failUnlessEqual(len(matrices), 2,
                             msg="Have gotten two splits")

        self.failUnless(
            (matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
            msg="Total votes should match the sum across split CMs")

        # check pretty print
        # just a silly test to make sure that printing works
        self.failUnless(len(cm.asstring(header=True, summary=True,
                                        description=True)) > 100)
        self.failUnless(len(str(cm)) > 100)
        # and that it knows some parameters for printing
        self.failUnless(len(cm.asstring(summary=True, header=False)) > 100)

        # lets check iadd -- just itself to itself
        cm += cm
        self.failUnlessEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

        # lets check add -- just itself to itself
        cm2 = cm + cm
        self.failUnlessEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
        self.failUnlessEqual(cm2.percentCorrect, cm.percentCorrect,
                             msg="Percent of corrrect should remain the same ;-)")

        self.failUnlessEqual(cm2.error, 1.0 - cm.percentCorrect / 100.0,
                             msg="Test if we get proper error value")


    def testDegenerateConfusion(self):
        """Single-label splits must not break ConfusionMatrix."""
        # We must not just puke -- some testing splits might
        # have just a single target label
        for orig in ([1], [1, 1], [0], [0, 0]):
            cm = ConfusionMatrix(targets=orig, predictions=orig, values=orig)
            # string conversion must not raise
            str(cm)
            self.failUnlessEqual(cm.stats['ACC%'], 100)


    def testConfusionMatrixACC(self):
        """Half-correct predictions must be reported as 50% accuracy."""
        reg = [0, 0, 1, 1]
        regl = [1, 0, 1, 0]
        cm = ConfusionMatrix(targets=reg, predictions=regl)
        self.failUnless('ACC% 50' in str(cm))


    # NOTE(review): the ``def`` line of this method (source line 118) is
    # missing from the rendering this file was recovered from; the name is
    # a reconstruction -- confirm against the original file.
    def testConfusionMatrixWithMappings(self):
        """Labels from labels_map must all appear in the printed matrix."""
        reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
        regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
        correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
        lm = {'apple': 1, 'orange': 2, 'shitty apple': 1, 'candy': 3}
        cm = ConfusionMatrix(targets=reg, predictions=regl,
                             labels_map=lm)
        # check table content
        self.failUnless((cm.matrix == correct_cm).all())
        # assure that all labels are somewhere listed ;-)
        s = str(cm)
        for label in lm:
            self.failUnless(label in s)


    @sweepargs(l_clf=clfswh['linear', 'svm'])
    def testConfusionBasedError(self, l_clf):
        """ConfusionBasedError must match TransferError on training data."""
        train = datasets['uni2medium_train']
        # to check if we fail to classify for 3 labels
        test3 = datasets['uni3medium_train']
        err = ConfusionBasedError(clf=l_clf)
        terr = TransferError(clf=l_clf)

        # Shouldn't be able to access the state yet (clf not trained)
        self.failUnlessRaises(UnknownStateError, err, None)

        l_clf.train(train)
        e, te = err(None), terr(train)
        self.failUnless(abs(e-te) < 1e-10,
            msg="ConfusionBasedError (%.2g) should be equal to TransferError "
                "(%.2g) on traindataset" % (e, te))

        # this will print nasty WARNING but it is ok -- it is just checking code
        # NB warnings are not printed while doing whole testing
        self.failIf(terr(test3) is None)

        # try copying the beast -- only checks that copying does not raise
        copy(terr)


    @sweepargs(l_clf=clfswh['linear', 'svm'])
    def testNullDistProb(self, l_clf):
        """Errors on signal-bearing data must be significant vs. the null."""
        train = datasets['uni2medium']

        num_perm = 10
        # define class to estimate NULL distribution of errors
        # use left tail of the distribution since the error is assumed to be
        # a mismatch rate (TransferError's default error function -- lower is
        # better)
        terr = TransferError(
            clf=l_clf,
            null_dist=MCNullDist(permutations=num_perm,
                                 tail='left'))

        # check reasonable error range
        err = terr(train, train)
        self.failUnless(err < 0.4)

        # Lets do the same for CVTE
        cvte = CrossValidatedTransferError(
            TransferError(clf=l_clf),
            OddEvenSplitter(),
            null_dist=MCNullDist(permutations=num_perm,
                                 tail='left',
                                 enable_states=['dist_samples']))
        # only the states filled in by the call (null_prob, dist_samples)
        # are inspected below
        cvte(train)

        # check that the result is highly significant since we know that the
        # data has signal
        null_prob = terr.states.null_prob
        # NOTE(review): indentation was lost in the recovered source; both
        # significance checks are assumed to be guarded as labile.
        if cfg.getboolean('tests', 'labile', default='yes'):
            self.failUnless(null_prob <= 0.1,
                msg="Failed to check that the result is highly significant "
                    "(got %f) since we know that the data has signal"
                    % null_prob)

            self.failUnless(cvte.states.null_prob <= 0.1,
                msg="Failed to check that the result is highly significant "
                    "(got p(cvte)=%f) since we know that the data has signal"
                    % cvte.states.null_prob)

        # and we should be able to access the actual samples of the distribution
        self.failUnlessEqual(len(cvte.null_dist.states.dist_samples),
                             num_perm)


    @sweepargs(l_clf=clfswh['linear', 'svm'])
    def testPerSampleError(self, l_clf):
        """Per-sample errors: one per sample, each either 0 or 1."""
        train = datasets['uni2medium']
        terr = TransferError(clf=l_clf, enable_states=['samples_error'])
        terr(train, train)
        se = terr.samples_error

        # one error per sample
        self.failUnlessEqual(len(se), train.nsamples)
        # for this simple test it can only be correct or misclassified
        # (boolean).  Compare against {0, 1} directly: the former check,
        # sum(float values - int8-cast values) == 0, accepted any integral
        # value and let positive/negative differences cancel out.
        errors = N.asarray(list(se.values()), dtype='float')
        self.failUnless(N.all((errors == 0) | (errors == 1)))


    @sweepargs(clf=clfswh['multiclass'])
    def testAUC(self, clf):
        """Test AUC computation
        """
        if isinstance(clf, MulticlassClassifier):
            # TODO: handle those values correctly
            return
        clf.states._changeTemporarily(enable_states=['values'])
        try:
            # uni2 dataset with reordered labels
            ds2 = datasets['uni2small'].copy()
            ds2.labels = 1 - ds2.labels     # revert labels
            # same with uni3: rotate every label to the next unique label
            ds3 = datasets['uni3small'].copy()
            ul = ds3.uniquelabels
            nl = ds3.labels.copy()
            for l in xrange(3):
                nl[ds3.labels == ul[l]] = ul[(l+1) % 3]
            ds3.labels = nl
            for ds in [datasets['uni2small'], ds2,
                       datasets['uni3small'], ds3]:
                cv = CrossValidatedTransferError(
                    TransferError(clf),
                    OddEvenSplitter(),
                    enable_states=['confusion', 'training_confusion'])
                cv(ds)
                stats = cv.confusion.stats
                Nlabels = len(ds.uniquelabels)
                # so we at least do slightly above chance
                self.failUnless(stats['ACC'] > 1.2 / Nlabels)
                auc = stats['AUC']
                # ``x is N.nan`` only matches that very object; a NaN that
                # was computed is a different object, so use isnan
                if (Nlabels == 2) or (Nlabels > 2 and not N.isnan(auc[0])):
                    mauc = N.min(auc)
                    if cfg.getboolean('tests', 'labile', default='yes'):
                        self.failUnless(mauc > 0.55,
                             msg='All AUCs must be above chance. Got minimal '
                                 'AUC=%.2g among %s' % (mauc, auc))
        finally:
            # restore the classifier's states even when an assertion above
            # fails -- clf comes from the shared clfswh warehouse
            clf.states._resetEnabledTemporarily()


    def testConfusionPlot(self):
        """Based on existing cell dataset results.

        Builds a ConfusionMatrix from recorded (targets, predictions) sets
        with a labels_map and checks its textual summary.  When pylab is
        plottable, it also checks that plotting with the first two labels
        swapped permutes the plotted matrix accordingly.
        """
        array = N.array
        uint8 = N.uint8
        sets = [
            (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44], dtype=uint8),
             array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46,
                    45, 38, 44, 39, 46, 38, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
                    40, 47, 43, 45, 41, 44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39,
                    46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45,
                    41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46,
                    42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41,
                    44, 47, 46, 42, 47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40,
                    43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
                    47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38, 38, 43, 45, 41, 44,
                    40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 42, 43,
                    45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
            (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
             array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40,
                    46, 42, 43, 39, 38, 43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47, 40,
                    40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44, 40, 46,
                    42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41,
                    44, 40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                    43, 45, 41, 44, 47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                    38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 46, 47, 38, 39, 43,
                    45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39,
                    39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
            (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
             array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40,
                    46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45,
                    40, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                    40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40, 40,
                    42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38,
                    43, 45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42,
                    47, 40, 39, 43, 45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44,
                    40, 40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
                    45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38, 44, 40, 46, 42, 39,
                    39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43, 45, 41, 38, 40,
                    46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
            (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
             array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 38, 43, 47, 38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39,
                    38, 43, 45, 41, 44, 43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45, 40,
                    44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42,
                    38, 39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39,
                    40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43,
                    45, 41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
                    38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 46, 44, 38, 46,
                    42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                    44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
            (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
             array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45,
                    42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                    40, 43, 45, 41, 44, 41, 46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41,
                    44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                    43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39, 46, 42,
                    47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43,
                    45, 41, 44, 40, 38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46,
                    46, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
                    38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 47, 45,
                    46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
                    38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
            (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
             array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40,
                    38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38,
                    42, 40, 38, 39, 43, 45, 41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                    44, 40, 46, 42, 38, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39,
                    43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45, 41, 44,
                    40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43,
                    45, 41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40,
                    38, 38, 40, 45, 41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40,
                    40, 42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
                    41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44, 39, 43, 42, 47, 39,
                    40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45, 41, 44, 38, 46,
                    42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
            (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
             array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46,
                    42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38,
                    43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
                    39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39,
                    39, 47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
                    46, 42, 46, 47, 39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
                    38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43, 47, 40, 44, 40, 46,
                    42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
                    44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
            (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
             array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46,
                    43, 45, 41, 44, 41, 41, 42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42,
                    39, 39, 40, 43, 45, 41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44,
                    39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39, 43,
                    45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46, 42, 47,
                    41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41,
                    46, 42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
                    41, 44, 38, 46, 42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38,
                    38, 43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
                    42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41, 39, 39, 43, 45, 41,
                    44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
        labels_map = {'12kHz': 40,
                      '20kHz': 41,
                      '30kHz': 42,
                      '3kHz': 38,
                      '7kHz': 39,
                      'song1': 43,
                      'song2': 44,
                      'song3': 45,
                      'song4': 46,
                      'song5': 47}
        # Let construction errors propagate: unittest then reports the real
        # exception with its traceback (a bare ``except: self.fail()`` hid it)
        cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
        self.failUnless('3kHz / 38' in cm.asstring())

        if externals.exists("pylab plottable"):
            import pylab as P
            P.figure()
            labels_order = ("3kHz", "7kHz", "12kHz", "20kHz", "30kHz", None,
                            "song1", "song2", "song3", "song4", "song5")
            # Plot with the first two labels swapped: rows/columns 0 and 1
            # of the plotted matrix must be those of cm.matrix, swapped
            fig, im, cb = cm.plot(labels=labels_order[1:2] + labels_order[:1]
                                         + labels_order[2:], numbers=True)
            try:
                pcm = cm._plotted_confusionmatrix
                self.failUnless(pcm[0, 0] == cm.matrix[1, 1])
                self.failUnless(pcm[0, 1] == cm.matrix[1, 0])
                self.failUnless(pcm[1, 1] == cm.matrix[0, 0])
                self.failUnless(pcm[1, 0] == cm.matrix[0, 1])
            finally:
                # close the figure even if an assertion above failed
                P.close(fig)
            fig, im, cb = cm.plot(labels=labels_order, numbers=True)
            P.close(fig)

    def testConfusionPlot2(self):
        """Based on a sample confusion which plots incorrectly

        Each set is (targets, predictions, values); the plotted matrix must
        equal cm.matrix when plotting with origin='lower'.
        """
        array = N.array
        sets = [(array([1, 2]), array([1, 1]),
                 array([[0.54343765, 0.45656235],
                        [0.92395853, 0.07604147]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.98030832, 0.01969168],
                        [0.78998763, 0.21001237]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.86125263, 0.13874737],
                        [0.83674113, 0.16325887]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.57870383, 0.42129617],
                        [0.59702509, 0.40297491]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.89530255, 0.10469745],
                        [0.69373919, 0.30626081]])),
                (array([1, 2]), array([1, 1]),
                 array([[0.75015218, 0.24984782],
                        [0.9339767, 0.0660233]])),
                (array([1, 2]), array([1, 2]),
                 array([[0.97826616, 0.02173384],
                        [0.38620638, 0.61379362]])),
                (array([2]), array([2]),
                 array([[0.46893776, 0.53106224]]))]
        # Let construction errors propagate with their traceback instead of
        # a bare ``except: self.fail()``
        cm = ConfusionMatrix(sets=sets)
        if externals.exists("pylab plottable"):
            import pylab as P
            fig, im, cb = cm.plot(origin='lower', numbers=True)
            try:
                self.failUnless(
                    (cm._plotted_confusionmatrix == cm.matrix).all())
            finally:
                # close the figure even if the assertion failed
                P.close(fig)
def suite():
    """Return a TestSuite containing every ``test*`` method of ErrorsTests."""
    # TestLoader.loadTestsFromTestCase builds the same suite as
    # unittest.makeSuite, which is deprecated since Python 3.11 and removed
    # in 3.13; the loader method has been available throughout Python 2.x.
    return unittest.TestLoader().loadTestsFromTestCase(ErrorsTests)


if __name__ == '__main__':
    # NOTE(review): presumably importing the project's ``runner`` module
    # runs the tests -- confirm against tests/runner.py
    import runner