# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9 """Base class for all classifiers.
10
11 At the moment, regressions are treated just as a special case of
12 classifier (or vise verse), so the same base class `Classifier` is
13 utilized for both kinds.
14 """

__docformat__ = 'restructuredtext'

import numpy as N

from mvpa.support.copy import deepcopy

import time

from mvpa.misc.support import idhash
from mvpa.misc.state import StateVariable, ClassWithCollections
from mvpa.misc.param import Parameter

from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics

from mvpa.base import warning

if __debug__:
    from mvpa.base import debug
36 """Base class for exceptions thrown by the learners (classifiers,
37 regressions)"""
38 pass
39
44
46 """Exception to be thrown whenever classifier fails to learn for
47 some reason"""
48 pass
49
51 """Exception to be thrown whenever classifier fails to provide predictions.
52 Usually happens if it was trained on degenerate data but without any complaints.
53 """
54 pass
55
57 """Abstract classifier class to be inherited by all classifiers
58 """
59
60
61
    _DEV__doc__ = """
    Required behavior:

    For every classifier it has to be possible to be instantiated without
    having to specify the training pattern.

    Repeated calls to the train() method with different training data have to
    result in a valid classifier, trained for the particular dataset.

    It must be possible to specify all classifier parameters as keyword
    arguments to the constructor.

    Recommended behavior:

    Derived classifiers should provide access to *values* -- i.e. that
    information that is finally used to determine the predicted class label.

    Michael: Maybe it works well if each classifier provides a 'values'
    state member. This variable is a list as long as, and in the same order
    as, Dataset.uniquelabels (training data). Each item in the list
    corresponds to the likelihood of a sample to belong to the
    respective class. However the semantics might differ between
    classifiers, e.g. kNN would probably store distances to class-
    neighbors, whereas PLR would store the raw function value of the
    logistic function. So in the case of kNN low is predictive and for
    PLR high is predictive. Don't know if there is the need to unify
    that.

    As the storage and/or computation of this information might be
    demanding, its collection should be switchable and off by default.

    Nomenclature
     * predictions : corresponds to the quantized labels if the classifier
       spits out labels by .predict()
     * values : might be different from predictions if a classifier's
       predict() makes a decision based on some internal value such as
       probability or a distance.
    """

    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_nsamples = StateVariable(enabled=True,
        doc="Number of samples it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent "
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took to train the classifier")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDs which were used for the actual training.")
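
    # States can be toggled per instance at runtime; a hypothetical usage
    # sketch (`clf` being any Classifier instance):
    #
    #     clf.states.enable('trained_dataset')  # keep the training dataset
    #     clf.states.disable('values')          # skip storing internal values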

    _clf_internals = []
    """Describes some specifics about the classifier -- e.g. whether it is
    doing regression internally...."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Whether to use the classifier as a regression. By default any
        Classifier-derived class serves as a classifier, so a
        regression-capable one does binary classification instead.""",
        index=1001)

    retrainable = Parameter(False, allowedtype='bool',
        doc="""Whether to enable retraining for a 'retrainable' classifier.""",
        index=1002)
153 """Cheap initialization.
154 """
155 ClassWithCollections.__init__(self, **kwargs)
156
157
158 self.__trainednfeatures = None
159 """Stores number of features for which classifier was trained.
160 If None -- it wasn't trained at all"""
161
162 self._setRetrainable(self.params.retrainable, force=True)
163
164 if self.params.regression:
165 for statevar in [ "trained_labels"]:
166 if self.states.isEnabled(statevar):
167 if __debug__:
168 debug("CLF",
169 "Disabling state %s since doing regression, " %
170 statevar + "not classification")
171 self.states.disable(statevar)
172 self._summaryClass = RegressionStatistics
173 else:
174 self._summaryClass = ConfusionMatrix
175 clf_internals = self._clf_internals
176 if 'regression' in clf_internals and not ('binary' in clf_internals):
177
178
179
180 self._clf_internals = clf_internals + ['binary']
181
182
183
184
185
186
187


    def __str__(self):
        if __debug__ and 'CLF_' in debug.active:
            return "%s / %s" % (repr(self), super(Classifier, self).__str__())
        else:
            return repr(self)
199 """Functionality prior to training
200 """
201
202
203 params = self.params
204 if not params.retrainable:
205 self.untrain()
206 else:
207
208 self.states.reset()
209 if not self.__changedData_isset:
210 self.__resetChangedData()
211 _changedData = self._changedData
212 __idhashes = self.__idhashes
213 __invalidatedChangedData = self.__invalidatedChangedData
214
215
216
217 if __debug__:
218 debug('CLF_', "IDHashes are %s" % (__idhashes))
219
220
221 for key, data_ in (('traindata', dataset.samples),
222 ('labels', dataset.labels)):
223 _changedData[key] = self.__wasDataChanged(key, data_)
224
225
226 if __invalidatedChangedData.get(key, False):
227 if __debug__ and not _changedData[key]:
228 debug('CLF_', 'Found that idhash for %s was '
229 'invalidated by retraining' % key)
230 _changedData[key] = True
231
232
233 for col in self._paramscols:
234 changedParams = self._collections[col].whichSet()
235 if len(changedParams):
236 _changedData[col] = changedParams
237
238 self.__invalidatedChangedData = {}
239
240 if __debug__:
241 debug('CLF_', "Obtained _changedData is %s"
242 % (self._changedData))
243
244 if not params.regression and 'regression' in self._clf_internals \
245 and not self.states.isEnabled('trained_labels'):
246
247
248 if __debug__:
249 debug("CLF", "Enabling trained_labels state since it is needed")
250 self.states.enable('trained_labels')
251
252


    def _posttrain(self, dataset):
        """Functionality post training

        For instance -- computing confusion matrix

        :Parameters:
          dataset : Dataset
            Data which was used for training
        """
        if self.states.isEnabled('trained_labels'):
            self.trained_labels = dataset.uniquelabels

        self.trained_dataset = dataset
        self.trained_nsamples = dataset.nsamples

        # needs to be set prior to calling predict() below
        self.__trainednfeatures = dataset.nfeatures

        if __debug__ and 'CHECK_TRAINED' in debug.active:
            self.__trainedidhash = dataset.idhash

        if self.states.isEnabled('training_confusion') and \
               not self.states.isSet('training_confusion'):
            # we should not store predictions for the training data --
            # they would be misleading here
            self.states._changeTemporarily(
                disable_states=["predictions"])
            if self.params.retrainable:
                # the prediction below is on the training samples, so
                # force re-evaluation of what data changed
                self.__changedData_isset = False
            predictions = self.predict(dataset.samples)
            self.states._resetEnabledTemporarily()
            self.training_confusion = self._summaryClass(
                targets=dataset.labels,
                predictions=predictions)

            try:
                self.training_confusion.labels_map = dataset.labels_map
            except:
                # not every dataset provides a labels_map
                pass

        if self.states.isEnabled('feature_ids'):
            self.feature_ids = self._getFeatureIds()
302 """Virtual method to return feature_ids used while training
303
304 Is not intended to be called anywhere but from _posttrain,
305 thus classifier is assumed to be trained at this point
306 """
307
308 return range(self.__trainednfeatures)
309
310
312 """Providing summary over the classifier"""
313
314 s = "Classifier %s" % self
315 states = self.states
316 states_enabled = states.enabled
317
318 if self.trained:
319 s += "\n trained"
320 if states.isSet('training_time'):
321 s += ' in %.3g sec' % states.training_time
322 s += ' on data with'
323 if states.isSet('trained_labels'):
324 s += ' labels:%s' % list(states.trained_labels)
325
326 nsamples, nchunks = None, None
327 if states.isSet('trained_nsamples'):
328 nsamples = states.trained_nsamples
329 if states.isSet('trained_dataset'):
330 td = states.trained_dataset
331 nsamples, nchunks = td.nsamples, len(td.uniquechunks)
332 if nsamples is not None:
333 s += ' #samples:%d' % nsamples
334 if nchunks is not None:
335 s += ' #chunks:%d' % nchunks
336
337 s += " #features:%d" % self.__trainednfeatures
338 if states.isSet('feature_ids'):
339 s += ", used #features:%d" % len(states.feature_ids)
340 if states.isSet('training_confusion'):
341 s += ", training error:%.3g" % states.training_confusion.error
342 else:
343 s += "\n not yet trained"
344
345 if len(states_enabled):
346 s += "\n enabled states:%s" % ', '.join([str(states[x])
347 for x in states_enabled])
348 return s
349
350
352 """Create full copy of the classifier.
353
354 It might require classifier to be untrained first due to
355 present SWIG bindings.
356
357 TODO: think about proper re-implementation, without enrollment of deepcopy
358 """
359 if __debug__:
360 debug("CLF", "Cloning %s#%s" % (self, id(self)))
361 try:
362 return deepcopy(self)
363 except:
364 self.untrain()
365 return deepcopy(self)
366
367
369 """Function to be actually overridden in derived classes
370 """
371 raise NotImplementedError
372
373


    def train(self, dataset):
        """Train classifier on a dataset

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so
        """
        if dataset.nfeatures == 0 or dataset.nsamples == 0:
            raise DegenerateInputError(
                "Cannot train classifier %s on degenerate data %s"
                % (self, dataset))
        if __debug__:
            debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
                  msgargs={'clf':self, 'dataset':dataset})

        self._pretrain(dataset)

        # remember the time when training started
        t0 = time.time()

        if dataset.nfeatures > 0:
            result = self._train(dataset)
        else:
            warning("Trying to train on dataset with no features present")
            if __debug__:
                debug("CLF",
                      "No features present for training, no actual training "
                      "is called")
            result = None

        self.training_time = time.time() - t0
        self._posttrain(dataset)
        return result
409 """Functionality prior prediction
410 """
411 if not ('notrain2predict' in self._clf_internals):
412
413 if not self.trained:
414 raise ValueError, \
415 "Classifier %s wasn't yet trained, therefore can't " \
416 "predict" % self
417 nfeatures = data.shape[1]
418
419
420 if nfeatures != self.__trainednfeatures:
421 raise ValueError, \
422 "Classifier %s was trained on data with %d features, " % \
423 (self, self.__trainednfeatures) + \
424 "thus can't predict for %d features" % nfeatures
425
426
427 if self.params.retrainable:
428 if not self.__changedData_isset:
429 self.__resetChangedData()
430 _changedData = self._changedData
431 _changedData['testdata'] = \
432 self.__wasDataChanged('testdata', data)
433 if __debug__:
434 debug('CLF_', "prepredict: Obtained _changedData is %s"
435 % (_changedData))
436
437


    def _postpredict(self, data, result):
        """Functionality after prediction is computed
        """
        self.predictions = result
        if self.params.retrainable:
            self.__changedData_isset = False
446 """Actual prediction
447 """
448 raise NotImplementedError
449
450
452 """Predict classifier on data
453
454 Shouldn't be overridden in subclasses unless explicitly needed
455 to do so. Also subclasses trying to call super class's predict
456 should call _predict if within _predict instead of predict()
457 since otherwise it would loop
458 """
459 data = N.asarray(data)
460 if __debug__:
461 debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
462 msgargs={'clf':self, 'data':data.shape})
463
464
465 t0 = time.time()
466
467 states = self.states
468
469
470 states.reset(['values', 'predictions'])
471
472 self._prepredict(data)
473
474 if self.__trainednfeatures > 0 \
475 or 'notrain2predict' in self._clf_internals:
476 result = self._predict(data)
477 else:
478 warning("Trying to predict using classifier trained on no features")
479 if __debug__:
480 debug("CLF",
481 "No features were present for training, prediction is " \
482 "bogus")
483 result = [None]*data.shape[0]
484
485 states.predicting_time = time.time() - t0
486
487 if 'regression' in self._clf_internals and not self.params.regression:
488
489
490
491
492
493
494
495
496 result_ = N.array(result)
497 if states.isEnabled('values'):
498
499
500 if not states.isSet('values'):
501 states.values = result_.copy()
502 else:
503
504
505
506 states.values = N.array(states.values, copy=True)
507
508 trained_labels = self.trained_labels
509 for i, value in enumerate(result):
510 dists = N.abs(value - trained_labels)
511 result[i] = trained_labels[N.argmin(dists)]
512
513 if __debug__:
514 debug("CLF_", "Converted regression result %(result_)s "
515 "into labels %(result)s for %(self_)s",
516 msgargs={'result_':result_, 'result':result,
517 'self_': self})
518
519 self._postpredict(data, result)
520 return result
521
522
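
    # Worked illustration of the regression-to-label conversion performed
    # in predict() above: with trained_labels == N.array([0, 1]) and a raw
    # regression output of 0.3, dists = N.abs(0.3 - trained_labels) is
    # [0.3, 0.7], N.argmin(dists) is 0, so label 0 is predicted; a raw
    # output of 0.8 would map to label 1 the same way.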
524 """Either classifier was already trained.
525
526 MUST BE USED WITH CARE IF EVER"""
527 if dataset is None:
528
529 return not self.__trainednfeatures is None
530 else:
531 res = (self.__trainednfeatures == dataset.nfeatures)
532 if __debug__ and 'CHECK_TRAINED' in debug.active:
533 res2 = (self.__trainedidhash == dataset.idhash)
534 if res2 != res:
535 raise RuntimeError, \
536 "isTrained is weak and shouldn't be relied upon. " \
537 "Got result %b although comparing of idhash says %b" \
538 % (res, res2)
539 return res
540
541
543 """Some classifiers like BinaryClassifier can't be used for
544 regression"""
545
546 if self.params.regression:
547 raise ValueError, "Regression mode is meaningless for %s" % \
548 self.__class__.__name__ + " thus don't enable it"
549
550


    @property
    def trained(self):
        """Whether the classifier was already trained"""
        return self.isTrained()


    def untrain(self):
        """Reset trained state"""
        self.__trainednfeatures = None
        # also reset the collections/states kept by the parent class
        super(Classifier, self).reset()
569 """Factory method to return an appropriate sensitivity analyzer for
570 the respective classifier."""
571 raise NotImplementedError
572
573
574
575
576
578 """Assign value of retrainable parameter
579
580 If retrainable flag is to be changed, classifier has to be
581 untrained. Also internal attributes such as _changedData,
582 __changedData_isset, and __idhashes should be initialized if
583 it becomes retrainable
584 """
585 pretrainable = self.params['retrainable']
586 if (force or value != pretrainable.value) \
587 and 'retrainable' in self._clf_internals:
588 if __debug__:
589 debug("CLF_", "Setting retrainable to %s" % value)
590 if 'meta' in self._clf_internals:
591 warning("Retrainability is not yet crafted/tested for "
592 "meta classifiers. Unpredictable behavior might occur")
593
594 if self.trained:
595 self.untrain()
596 states = self.states
597 if not value and states.isKnown('retrained'):
598 states.remove('retrained')
599 states.remove('repredicted')
600 if value:
601 if not 'retrainable' in self._clf_internals:
602 warning("Setting of flag retrainable for %s has no effect"
603 " since classifier has no such capability. It would"
604 " just lead to resources consumption and slowdown"
605 % self)
606 states.add(StateVariable(enabled=True,
607 name='retrained',
608 doc="Either retrainable classifier was retrained"))
609 states.add(StateVariable(enabled=True,
610 name='repredicted',
611 doc="Either retrainable classifier was repredicted"))
612
613 pretrainable.value = value
614
615
616 if value:
617 self.__idhashes = {'traindata': None, 'labels': None,
618 'testdata': None}
619 if __debug__ and 'CHECK_RETRAIN' in debug.active:
620
621
622
623
624 self.__trained = self.__idhashes.copy()
625 self.__resetChangedData()
626 self.__invalidatedChangedData = {}
627 elif 'retrainable' in self._clf_internals:
628
629 self.__changedData_isset = False
630 self._changedData = None
631 self.__idhashes = None
632 if __debug__ and 'CHECK_RETRAIN' in debug.active:
633 self.__trained = None
634
636 """For retrainable classifier we keep track of what was changed
637 This function resets that dictionary
638 """
639 if __debug__:
640 debug('CLF_',
641 'Retrainable: resetting flags on either data was changed')
642 keys = self.__idhashes.keys() + self._paramscols
643
644
645
646
647
648 self._changedData = dict(zip(keys, [False]*len(keys)))
649 self.__changedData_isset = False
650
651
653 """Check if given entry was changed from what known prior.
654
655 If so -- store only the ones needed for retrainable beastie
656 """
657 idhash_ = idhash(entry)
658 __idhashes = self.__idhashes
659
660 changed = __idhashes[key] != idhash_
661 if __debug__ and 'CHECK_RETRAIN' in debug.active:
662 __trained = self.__trained
663 changed2 = entry != __trained[key]
664 if isinstance(changed2, N.ndarray):
665 changed2 = changed2.any()
666 if changed != changed2 and not changed:
667 raise RuntimeError, \
668 'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
669 'values %s!=%s %s' % \
670 (key, idhash_, __idhashes[key], changed,
671 entry, __trained[key], changed2)
672 if update:
673 __trained[key] = entry
674
675 if __debug__ and changed:
676 debug('CLF_', "Changed %s from %s to %s.%s"
677 % (key, __idhashes[key], idhash_,
678 ('','updated')[int(update)]))
679 if update:
680 __idhashes[key] = idhash_
681
682 return changed
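
    # Sketch of the idhash-based change detection used above, assuming
    # idhash() fingerprints an array by identity and content:
    #
    #     data = N.zeros((2, 2))
    #     h0 = idhash(data)
    #     data[0, 0] = 1.                 # in-place modification
    #     changed = (idhash(data) != h0)  # True -> entry marked as changed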


    def retrain(self, dataset, **kwargs):
        """Helper to avoid checking if the data was actually changed

        Useful if just some aspects of the classifier were changed since
        its previous training. For instance if the dataset wasn't changed
        but only classifier parameters, then the kernel matrix does not
        have to be computed.

        Words of caution: the classifier must be previously trained, and
        results should always first be compared to the results of the
        non-'retrainable' classifier (without calling retrain). Some
        additional checks are enabled if debug id 'CHECK_RETRAIN' is
        enabled, to guard against obvious mistakes.

        :Parameters:
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        # Note that this also demolishes anything for repredicting,
        # which should be fine in general
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

            if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
                raise ValueError, \
                      "Retraining for changed params is not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)
        # invalidate the entries explicitly marked as changed, so a
        # subsequent train() does not consider them unchanged
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # check that we are not being fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although wasn't " \
                              "labeled as such" % key

        # Should be superseded by the check above, thus should never
        # occur; remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whereas previously it was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)
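
    # Hypothetical usage sketch for retrain(): after clf.train(dataset),
    #
    #     clf.retrain(dataset, labels=True)
    #
    # declares that only the labels changed, so a retrainable backend may
    # reuse e.g. a kernel matrix computed from the unchanged samples.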
786 """Helper to avoid check if data was changed actually changed
787
788 Useful if classifier was (re)trained but with the same data
789 (so just parameters were changed), so that it could be
790 repredicted easily (on the same data as before) without
791 recomputing for instance train/test kernel matrix. Should be
792 used with caution and always compared to the results on not
793 'retrainable' classifier. Some additional checks are enabled
794 if debug id 'CHECK_RETRAIN' is enabled, to guard against
795 obvious mistakes.
796
797 :Parameters:
798 data
799 data which is conventionally given to predict
800 kwargs
801 that is what _changedData gets updated with. So, smth like
802 ``(params=['C'], labels=True)`` if parameter C and labels
803 got changed
804 """
805 if len(kwargs)>0:
806 raise RuntimeError, \
807 "repredict for now should be used without params since " \
808 "it makes little sense to repredict if anything got changed"
809 if __debug__ and not self.params.retrainable:
810 raise RuntimeError, \
811 "Do not use retrain/repredict on non-retrainable classifiers"
812
813 self.__resetChangedData()
814 chd = self._changedData
815 chd.update(**kwargs)
816 self.__changedData_isset = True
817
818
819
820 if __debug__ and 'CHECK_RETRAIN' in debug.active:
821 for key, data_ in (('testdata', data),):
822
823
824 if self.__wasDataChanged(key, data_, update=False):
825 raise RuntimeError, \
826 "Data %s found changed although wasn't " \
827 "labeled as such" % key
828
829
830
831 if __debug__ and 'CHECK_RETRAIN' in debug.active \
832 and not self._changedData['testdata'] \
833 and self.__trained['testdata'].shape != data.shape:
834 raise ValueError, "In repredict got dataset with %s size, " \
835 "whenever previously was trained on %s size" \
836 % (data.shape, self.__trained['testdata'].shape)
837
838 return self.predict(data)
839
840
841
842
843
844
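
# A minimal end-to-end sketch of the retrainable workflow. `SomeClf`,
# `ds` and `ds2` are hypothetical: any Classifier subclass advertising
# 'retrainable' in its _clf_internals, and two Datasets differing only
# in labels:
#
#     clf = SomeClf(retrainable=True)
#     clf.train(ds)                    # full training
#     p1 = clf.predict(ds.samples)
#     clf.retrain(ds2, labels=True)    # only labels changed -> cheaper
#     p2 = clf.repredict(ds.samples)   # same test data -> cheaper too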