Package mvpa :: Package clfs :: Module base
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.clfs.base

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Base class for all classifiers. 
 10   
 11  At the moment, regressions are treated just as a special case of 
 12  classifier (or vice versa), so the same base class `Classifier` is 
 13  utilized for both kinds. 
 14  """ 
 15   
 16  __docformat__ = 'restructuredtext' 
 17   
 18  import numpy as N 
 19   
 20  from mvpa.support.copy import deepcopy 
 21   
 22  import time 
 23   
 24  from mvpa.misc.support import idhash 
 25  from mvpa.misc.state import StateVariable, ClassWithCollections 
 26  from mvpa.misc.param import Parameter 
 27   
 28  from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics 
 29   
 30  from mvpa.base import warning 
 31   
 32  if __debug__: 
 33      from mvpa.base import debug 
class LearnerError(Exception):
    """Base class for all exceptions raised by learners
    (classifiers and regressions).
    """
39
class DegenerateInputError(LearnerError):
    """Raised by learners when input data is degenerate, i.e. has
    no features or no samples.
    """
44
class FailedToTrainError(LearnerError):
    """Raised whenever a classifier fails to learn for some reason.
    """
49
class FailedToPredictError(LearnerError):
    """Raised whenever a classifier fails to provide predictions.

    Usually happens if it was trained on degenerate data but without
    any complaints.
    """
55
class Classifier(ClassWithCollections):
    """Abstract classifier class to be inherited by all classifiers
    """

    # Kept separate from doc to don't pollute help(clf), especially if
    # we including help for the parent class
    _DEV__doc__ = """
    Required behavior:

    For every classifier is has to be possible to be instantiated without
    having to specify the training pattern.

    Repeated calls to the train() method with different training data have to
    result in a valid classifier, trained for the particular dataset.

    It must be possible to specify all classifier parameters as keyword
    arguments to the constructor.

    Recommended behavior:

    Derived classifiers should provide access to *values* -- i.e. that
    information that is finally used to determine the predicted class label.

    Michael: Maybe it works well if each classifier provides a 'values'
    state member. This variable is a list as long as and in same order
    as Dataset.uniquelabels (training data). Each item in the list
    corresponds to the likelyhood of a sample to belong to the
    respective class. However the semantics might differ between
    classifiers, e.g. kNN would probably store distances to class-
    neighbors, where PLR would store the raw function value of the
    logistic function. So in the case of kNN low is predictive and for
    PLR high is predictive. Don't know if there is the need to unify
    that.

    As the storage and/or computation of this information might be
    demanding its collection should be switchable and off be default.

    Nomenclature
     * predictions : corresponds to the quantized labels if classifier spits
       out labels by .predict()
     * values : might be different from predictions if a classifier's predict()
       makes a decision based on some internal value such as
       probability or a distance.
    """
    # Dict that contains the parameters of a classifier.
    # This shall provide an interface to plug generic parameter optimizer
    # on all classifiers (e.g. grid- or line-search optimizer)
    # A dictionary is used because Michael thinks that access by name is nicer.
    # Additionally Michael thinks ATM that additional information might be
    # necessary in some situations (e.g. reasonably predefined parameter range,
    # minimal iteration stepsize, ...), therefore the value to each key should
    # also be a dict or we should use mvpa.misc.param.Parameter'...

    # State variables below record facts about the most recent training
    # and prediction; the 'enabled' flag controls whether each one is
    # actually collected (disabled ones save memory/time).
    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_nsamples = StateVariable(enabled=True,
        doc="Number of samples it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent " +
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) which took classifier to train")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) which took classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDS which were used for the actual training.")

    _clf_internals = []
    """Describes some specifics about the classifier -- is that it is
    doing regression for instance...."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Either to use 'regression' as regression. By default any
        Classifier-derived class serves as a classifier, so regression
        does binary classification.""", index=1001)

    # TODO: make it available only for actually retrainable classifiers
    retrainable = Parameter(False, allowedtype='bool',
        doc="""Either to enable retraining for 'retrainable' classifier.""",
        index=1002)
    def __init__(self, **kwargs):
        """Cheap initialization.

        :Parameters:
          kwargs
            Passed through to `ClassWithCollections.__init__` (parameters
            such as ``regression`` and ``retrainable`` end up in
            ``self.params``).
        """
        ClassWithCollections.__init__(self, **kwargs)


        self.__trainednfeatures = None
        """Stores number of features for which classifier was trained.
        If None -- it wasn't trained at all"""

        # Initialize (or tear down) the retrainability machinery; force=True
        # so the internal tracking attributes are set up on first call
        self._setRetrainable(self.params.retrainable, force=True)

        if self.params.regression:
            # In regression mode, classification-only states make no sense
            for statevar in [ "trained_labels"]: #, "training_confusion" ]:
                if self.states.isEnabled(statevar):
                    if __debug__:
                        debug("CLF",
                              "Disabling state %s since doing regression, " %
                              statevar + "not classification")
                    self.states.disable(statevar)
            self._summaryClass = RegressionStatistics
        else:
            self._summaryClass = ConfusionMatrix
            clf_internals = self._clf_internals
            if 'regression' in clf_internals and not ('binary' in clf_internals):
                # regressions are used as binary classifiers if not
                # asked to perform regression explicitly
                # We need a copy of the list, so we don't override class-wide
                self._clf_internals = clf_internals + ['binary']
181 182 # deprecate 183 #self.__trainedidhash = None 184 #"""Stores id of the dataset on which it was trained to signal 185 #in trained() if it was trained already on the same dataset""" 186 187
188 - def __str__(self):
189 if __debug__ and 'CLF_' in debug.active: 190 return "%s / %s" % (repr(self), super(Classifier, self).__str__()) 191 else: 192 return repr(self)
193
194 - def __repr__(self, prefixes=[]):
195 return super(Classifier, self).__repr__(prefixes=prefixes)
196 197
    def _pretrain(self, dataset):
        """Functionality prior to training

        Resets state variables and, for retrainable classifiers, figures
        out which pieces of data/parameters changed since the previous
        training, recording that in ``self._changedData``.

        :Parameters:
          dataset : Dataset
            Data the classifier is about to be trained on
        """
        # So we reset all state variables and may be free up some memory
        # explicitly
        params = self.params
        if not params.retrainable:
            self.untrain()
        else:
            # just reset the states, do not untrain
            self.states.reset()
            if not self.__changedData_isset:
                self.__resetChangedData()
                # local bindings to the change-tracking structures
                _changedData = self._changedData
                __idhashes = self.__idhashes
                __invalidatedChangedData = self.__invalidatedChangedData

                # if we don't know what was changed we need to figure
                # them out
                if __debug__:
                    debug('CLF_', "IDHashes are %s" % (__idhashes))

                # Look at the data if any was changed
                for key, data_ in (('traindata', dataset.samples),
                                   ('labels', dataset.labels)):
                    _changedData[key] = self.__wasDataChanged(key, data_)
                    # if those idhashes were invalidated by retraining
                    # we need to adjust _changedData accordingly
                    if __invalidatedChangedData.get(key, False):
                        if __debug__ and not _changedData[key]:
                            debug('CLF_', 'Found that idhash for %s was '
                                  'invalidated by retraining' % key)
                        _changedData[key] = True

                # Look at the parameters
                for col in self._paramscols:
                    changedParams = self._collections[col].whichSet()
                    if len(changedParams):
                        _changedData[col] = changedParams

            self.__invalidatedChangedData = {} # reset it on training

            if __debug__:
                debug('CLF_', "Obtained _changedData is %s"
                      % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
           and not self.states.isEnabled('trained_labels'):
            # if classifier internally does regression we need to have
            # labels it was trained on
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')
251 252
253 - def _posttrain(self, dataset):
254 """Functionality post training 255 256 For instance -- computing confusion matrix 257 :Parameters: 258 dataset : Dataset 259 Data which was used for training 260 """ 261 if self.states.isEnabled('trained_labels'): 262 self.trained_labels = dataset.uniquelabels 263 264 self.trained_dataset = dataset 265 self.trained_nsamples = dataset.nsamples 266 267 # needs to be assigned first since below we use predict 268 self.__trainednfeatures = dataset.nfeatures 269 270 if __debug__ and 'CHECK_TRAINED' in debug.active: 271 self.__trainedidhash = dataset.idhash 272 273 if self.states.isEnabled('training_confusion') and \ 274 not self.states.isSet('training_confusion'): 275 # we should not store predictions for training data, 276 # it is confusing imho (yoh) 277 self.states._changeTemporarily( 278 disable_states=["predictions"]) 279 if self.params.retrainable: 280 # we would need to recheck if data is the same, 281 # XXX think if there is a way to make this all 282 # efficient. For now, probably, retrainable 283 # classifiers have no chance but not to use 284 # training_confusion... sad 285 self.__changedData_isset = False 286 predictions = self.predict(dataset.samples) 287 self.states._resetEnabledTemporarily() 288 self.training_confusion = self._summaryClass( 289 targets=dataset.labels, 290 predictions=predictions) 291 292 try: 293 self.training_confusion.labels_map = dataset.labels_map 294 except: 295 pass 296 297 if self.states.isEnabled('feature_ids'): 298 self.feature_ids = self._getFeatureIds()
299 300
    def _getFeatureIds(self):
        """Virtual method to return feature_ids used while training

        Is not intended to be called anywhere but from _posttrain,
        thus classifier is assumed to be trained at this point

        :Returns:
          Sequence of all feature indices ``0..nfeatures-1`` -- i.e. by
          default every feature is considered used; feature-selecting
          classifiers override this.
        """
        # By default all features are used
        return range(self.__trainednfeatures)
309 310
    def summary(self):
        """Providing summary over the classifier

        :Returns:
          str -- a multi-line, human-readable description covering training
          status, training-data characteristics (when the corresponding
          state variables are set) and currently enabled states.
        """

        s = "Classifier %s" % self
        states = self.states
        states_enabled = states.enabled

        if self.trained:
            s += "\n trained"
            if states.isSet('training_time'):
                s += ' in %.3g sec' % states.training_time
            s += ' on data with'
            if states.isSet('trained_labels'):
                s += ' labels:%s' % list(states.trained_labels)

            nsamples, nchunks = None, None
            if states.isSet('trained_nsamples'):
                nsamples = states.trained_nsamples
            # trained_dataset, when available, provides both counts and
            # takes precedence over trained_nsamples
            if states.isSet('trained_dataset'):
                td = states.trained_dataset
                nsamples, nchunks = td.nsamples, len(td.uniquechunks)
            if nsamples is not None:
                s += ' #samples:%d' % nsamples
            if nchunks is not None:
                s += ' #chunks:%d' % nchunks

            s += " #features:%d" % self.__trainednfeatures
            if states.isSet('feature_ids'):
                s += ", used #features:%d" % len(states.feature_ids)
            if states.isSet('training_confusion'):
                s += ", training error:%.3g" % states.training_confusion.error
        else:
            s += "\n not yet trained"

        if len(states_enabled):
            s += "\n enabled states:%s" % ', '.join([str(states[x])
                                                     for x in states_enabled])
        return s
349 350
351 - def clone(self):
352 """Create full copy of the classifier. 353 354 It might require classifier to be untrained first due to 355 present SWIG bindings. 356 357 TODO: think about proper re-implementation, without enrollment of deepcopy 358 """ 359 if __debug__: 360 debug("CLF", "Cloning %s#%s" % (self, id(self))) 361 try: 362 return deepcopy(self) 363 except: 364 self.untrain() 365 return deepcopy(self)
366 367
368 - def _train(self, dataset):
369 """Function to be actually overridden in derived classes 370 """ 371 raise NotImplementedError
372 373
374 - def train(self, dataset):
375 """Train classifier on a dataset 376 377 Shouldn't be overridden in subclasses unless explicitly needed 378 to do so 379 """ 380 if dataset.nfeatures == 0 or dataset.nsamples == 0: 381 raise DegenerateInputError( 382 "Cannot train classifier %s on degenerate data %s" 383 % (self, dataset)) 384 if __debug__: 385 debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s", 386 msgargs={'clf':self, 'dataset':dataset}) 387 388 self._pretrain(dataset) 389 390 # remember the time when started training 391 t0 = time.time() 392 393 if dataset.nfeatures > 0: 394 result = self._train(dataset) 395 else: 396 warning("Trying to train on dataset with no features present") 397 if __debug__: 398 debug("CLF", 399 "No features present for training, no actual training " \ 400 "is called") 401 result = None 402 403 self.training_time = time.time() - t0 404 self._posttrain(dataset) 405 return result
406 407
    def _prepredict(self, data):
        """Functionality prior prediction

        Validates that the classifier is trained and that the number of
        features matches the training data (unless the classifier declares
        'notrain2predict').  For retrainable classifiers also records
        whether the test data changed.

        :Parameters:
          data : array-like
            2D samples array to be predicted on (samples x features)
        """
        if not ('notrain2predict' in self._clf_internals):
            # check if classifier was trained if that is needed
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            # check if number of features is the same as in the data
            # it was trained on
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures


        if self.params.retrainable:
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                                        self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))
436 437
438 - def _postpredict(self, data, result):
439 """Functionality after prediction is computed 440 """ 441 self.predictions = result 442 if self.params.retrainable: 443 self.__changedData_isset = False
444
445 - def _predict(self, data):
446 """Actual prediction 447 """ 448 raise NotImplementedError
449 450
    def predict(self, data):
        """Predict classifier on data

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so. Also subclasses trying to call super class's predict
        should call _predict if within _predict instead of predict()
        since otherwise it would loop

        :Parameters:
          data : array-like
            2D samples array (samples x features); coerced via N.asarray.

        :Returns:
          Sequence of predictions, one per sample.  For an internally
          regressing classifier used as a classifier, raw regression
          outputs are quantized to the nearest trained label.
        """
        data = N.asarray(data)
        if __debug__:
            debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
                  msgargs={'clf':self, 'data':data.shape})

        # remember the time when started computing predictions
        t0 = time.time()

        states = self.states
        # to assure that those are reset (could be set due to testing
        # post-training)
        states.reset(['values', 'predictions'])

        self._prepredict(data)

        if self.__trainednfeatures > 0 \
               or 'notrain2predict' in self._clf_internals:
            result = self._predict(data)
        else:
            warning("Trying to predict using classifier trained on no features")
            if __debug__:
                debug("CLF",
                      "No features were present for training, prediction is " \
                      "bogus")
            result = [None]*data.shape[0]

        states.predicting_time = time.time() - t0

        if 'regression' in self._clf_internals and not self.params.regression:
            # We need to convert regression values into labels
            # XXX unify may be labels -> internal_labels conversion.
            #if len(self.trained_labels) != 2:
            #    raise RuntimeError, "Ask developer to implement for " \
            #        "multiclass mapping from regression into classification"

            # must be N.array so we copy it to assign labels directly
            # into labels, or should we just recreate "result"???
            result_ = N.array(result)
            if states.isEnabled('values'):
                # values could be set by now so assigning 'result' would
                # be misleading
                if not states.isSet('values'):
                    states.values = result_.copy()
                else:
                    # it might be the values are pointing to result at
                    # the moment, so lets assure this silly way that
                    # they do not overlap
                    states.values = N.array(states.values, copy=True)

            # quantize each raw output to the closest trained label
            trained_labels = self.trained_labels
            for i, value in enumerate(result):
                dists = N.abs(value - trained_labels)
                result[i] = trained_labels[N.argmin(dists)]

            if __debug__:
                debug("CLF_", "Converted regression result %(result_)s "
                      "into labels %(result)s for %(self_)s",
                      msgargs={'result_':result_, 'result':result,
                               'self_': self})

        self._postpredict(data, result)
        return result
521 522 # deprecate ???
523 - def isTrained(self, dataset=None):
524 """Either classifier was already trained. 525 526 MUST BE USED WITH CARE IF EVER""" 527 if dataset is None: 528 # simply return if it was trained on anything 529 return not self.__trainednfeatures is None 530 else: 531 res = (self.__trainednfeatures == dataset.nfeatures) 532 if __debug__ and 'CHECK_TRAINED' in debug.active: 533 res2 = (self.__trainedidhash == dataset.idhash) 534 if res2 != res: 535 raise RuntimeError, \ 536 "isTrained is weak and shouldn't be relied upon. " \ 537 "Got result %b although comparing of idhash says %b" \ 538 % (res, res2) 539 return res
540 541
542 - def _regressionIsBogus(self):
543 """Some classifiers like BinaryClassifier can't be used for 544 regression""" 545 546 if self.params.regression: 547 raise ValueError, "Regression mode is meaningless for %s" % \ 548 self.__class__.__name__ + " thus don't enable it"
    @property
    def trained(self):
        """Either classifier was already trained

        Read-only convenience flag; simply delegates to `isTrained()`
        without a dataset argument.
        """
        return self.isTrained()
555
556 - def untrain(self):
557 """Reset trained state""" 558 self.__trainednfeatures = None 559 # probably not needed... retrainable shouldn't be fully untrained 560 # or should be??? 561 #if self.params.retrainable: 562 # # ??? don't duplicate the code ;-) 563 # self.__idhashes = {'traindata': None, 'labels': None, 564 # 'testdata': None, 'testtraindata': None} 565 super(Classifier, self).reset()
566 567
568 - def getSensitivityAnalyzer(self, **kwargs):
569 """Factory method to return an appropriate sensitivity analyzer for 570 the respective classifier.""" 571 raise NotImplementedError
572 573 574 # 575 # Methods which are needed for retrainable classifiers 576 #
    def _setRetrainable(self, value, force=False):
        """Assign value of retrainable parameter

        If retrainable flag is to be changed, classifier has to be
        untrained.  Also internal attributes such as _changedData,
        __changedData_isset, and __idhashes should be initialized if
        it becomes retrainable

        :Parameters:
          value : bool
            New value for the retrainable parameter
          force : bool
            If True, (re)initialize the tracking machinery even when the
            value did not change (used from __init__)
        """
        pretrainable = self.params['retrainable']
        if (force or value != pretrainable.value) \
               and 'retrainable' in self._clf_internals:
            if __debug__:
                debug("CLF_", "Setting retrainable to %s" % value)
            if 'meta' in self._clf_internals:
                warning("Retrainability is not yet crafted/tested for "
                        "meta classifiers. Unpredictable behavior might occur")
            # assure that we don't drag anything behind
            if self.trained:
                self.untrain()
            states = self.states
            if not value and states.isKnown('retrained'):
                states.remove('retrained')
                states.remove('repredicted')
            if value:
                # NOTE(review): this warning can never fire -- the enclosing
                # 'if' already requires 'retrainable' in self._clf_internals;
                # presumably kept as a safety net -- confirm before removal
                if not 'retrainable' in self._clf_internals:
                    warning("Setting of flag retrainable for %s has no effect"
                            " since classifier has no such capability. It would"
                            " just lead to resources consumption and slowdown"
                            % self)
                states.add(StateVariable(enabled=True,
                        name='retrained',
                        doc="Either retrainable classifier was retrained"))
                states.add(StateVariable(enabled=True,
                        name='repredicted',
                        doc="Either retrainable classifier was repredicted"))

            pretrainable.value = value

            # if retrainable we need to keep track of things
            if value:
                self.__idhashes = {'traindata': None, 'labels': None,
                                   'testdata': None} #, 'testtraindata': None}
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    # ??? it is not clear though if idhash is faster than
                    # simple comparison of (dataset != __traineddataset).any(),
                    # but if we like to get rid of __traineddataset then we
                    # should use idhash anyways
                    self.__trained = self.__idhashes.copy() # just same Nones
                self.__resetChangedData()
                self.__invalidatedChangedData = {}
            # NOTE(review): when this elif is reached, the outer 'if' already
            # guaranteed 'retrainable' in self._clf_internals, so the
            # condition is always true here -- confirm intended semantics
            elif 'retrainable' in self._clf_internals:
                #self.__resetChangedData()
                self.__changedData_isset = False
                self._changedData = None
                self.__idhashes = None
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    self.__trained = None
634
635 - def __resetChangedData(self):
636 """For retrainable classifier we keep track of what was changed 637 This function resets that dictionary 638 """ 639 if __debug__: 640 debug('CLF_', 641 'Retrainable: resetting flags on either data was changed') 642 keys = self.__idhashes.keys() + self._paramscols 643 # we might like to just reinit values to False??? 644 #_changedData = self._changedData 645 #if isinstance(_changedData, dict): 646 # for key in _changedData.keys(): 647 # _changedData[key] = False 648 self._changedData = dict(zip(keys, [False]*len(keys))) 649 self.__changedData_isset = False
650 651
    def __wasDataChanged(self, key, entry, update=True):
        """Check if given entry was changed from what known prior.

        If so -- store only the ones needed for retrainable beastie

        :Parameters:
          key : str
            Which tracked item ('traindata', 'labels', 'testdata', ...)
          entry
            Current data to compare against the stored idhash
          update : bool
            Whether to record the new idhash (and, under CHECK_RETRAIN,
            the raw entry) for future comparisons

        :Returns:
          bool -- whether the entry's idhash differs from the stored one
        """
        idhash_ = idhash(entry)
        __idhashes = self.__idhashes

        changed = __idhashes[key] != idhash_
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            # cross-validate the idhash verdict against a full value
            # comparison with the remembered entry
            __trained = self.__trained
            changed2 = entry != __trained[key]
            if isinstance(changed2, N.ndarray):
                changed2 = changed2.any()
            if changed != changed2 and not changed:
                raise RuntimeError, \
                  'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
                  'values %s!=%s %s' % \
                  (key, idhash_, __idhashes[key], changed,
                   entry, __trained[key], changed2)
            if update:
                __trained[key] = entry

        if __debug__ and changed:
            debug('CLF_', "Changed %s from %s to %s.%s"
                  % (key, __idhashes[key], idhash_,
                     ('','updated')[int(update)]))
        if update:
            __idhashes[key] = idhash_

        return changed
683 684 685 # def __updateHashIds(self, key, data): 686 # """Is twofold operation: updates hashid if was said that it changed. 687 # 688 # or if it wasn't said that data changed, but CHECK_RETRAIN and it found 689 # to be changed -- raise Exception 690 # """ 691 # 692 # check_retrain = __debug__ and 'CHECK_RETRAIN' in debug.active 693 # chd = self._changedData 694 # 695 # # we need to updated idhashes 696 # if chd[key] or check_retrain: 697 # keychanged = self.__wasDataChanged(key, data) 698 # if check_retrain and keychanged and not chd[key]: 699 # raise RuntimeError, \ 700 # "Data %s found changed although wasn't " \ 701 # "labeled as such" % key 702 703 704 # 705 # Additional API which is specific only for retrainable classifiers. 706 # For now it would just puke if asked from not retrainable one. 707 # 708 # Might come useful and efficient for statistics testing, so if just 709 # labels of dataset changed, then 710 # self.retrain(dataset, labels=True) 711 # would cause efficient retraining (no kernels recomputed etc) 712 # and subsequent self.repredict(data) should be also quite fase ;-) 713
    def retrain(self, dataset, **kwargs):
        """Helper to avoid check if data was changed actually changed

        Useful if just some aspects of classifier were changed since
        its previous training. For instance if dataset wasn't changed
        but only classifier parameters, then kernel matrix does not
        have to be computed.

        Words of caution: classifier must be previously trained,
        results always should first be compared to the results on not
        'retrainable' classifier (without calling retrain). Some
        additional checks are enabled if debug id 'CHECK_RETRAIN' is
        enabled, to guard against obvious mistakes.

        :Parameters:
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        # Note that it also demolishes anything for repredicting,
        # which should be ok in most of the cases
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

            if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
                raise ValueError, \
                      "Retraining for changed params not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)
        # mark for future 'train()' items which are explicitely
        # mentioned as changed
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # To check if we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although wasn't " \
                              "labeled as such" % key

        # TODO: parameters of classifiers... for now there is explicit
        # 'forbidance' above

        # Below check should be superseeded by check above, thus never occur.
        # remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whenever previousely was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)
783 784
    def repredict(self, data, **kwargs):
        """Helper to avoid check if data was changed actually changed

        Useful if classifier was (re)trained but with the same data
        (so just parameters were changed), so that it could be
        repredicted easily (on the same data as before) without
        recomputing for instance train/test kernel matrix. Should be
        used with caution and always compared to the results on not
        'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        if len(kwargs)>0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True


        # check if we are attempted to perform on the same data
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                # so it wasn't told to be invalid
                #if not chd[key]:# and not ichd.get(key, False):
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although wasn't " \
                          "labeled as such" % key

        # Should be superseded by above
        # remove in future???
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whenever previously was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)
839 840 841 # TODO: callback into retrainable parameter 842 #retrainable = property(fget=_getRetrainable, fset=_setRetrainable, 843 # doc="Specifies either classifier should be retrainable") 844