Package mvpa :: Package algorithms :: Module cvtranserror
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.algorithms.cvtranserror

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Cross-validate a classifier on a dataset""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13  from mvpa.support.copy import deepcopy 
 14   
 15  from mvpa.measures.base import DatasetMeasure 
 16  from mvpa.datasets.splitters import NoneSplitter 
 17  from mvpa.base import warning 
 18  from mvpa.misc.state import StateVariable, Harvestable 
 19  from mvpa.misc.transformers import GrandMean 
 20   
 21  if __debug__: 
 22      from mvpa.base import debug 
 23   
 24   
class CrossValidatedTransferError(DatasetMeasure, Harvestable):
    """Classifier cross-validation.

    This class provides a simple interface to cross-validate a classifier
    on datasets generated by a splitter from a single source dataset.

    Arbitrary performance/error values can be computed by specifying an error
    function (used to compute an error value for each cross-validation fold)
    and a combiner function that aggregates all computed error values across
    cross-validation folds.
    """

    results = StateVariable(enabled=False, doc=
        """Store individual results in the state""")
    splits = StateVariable(enabled=False, doc=
        """Store the actual splits of the data. Can be memory expensive""")
    transerrors = StateVariable(enabled=False, doc=
        """Store copies of transerrors at each step. If enabled -
        operates on clones of transerror, but for the last split original
        transerror is used""")
    confusion = StateVariable(enabled=False, doc=
        """Store total confusion matrix (if available)""")
    training_confusion = StateVariable(enabled=False, doc=
        """Store total training confusion matrix (if available)""")
    samples_error = StateVariable(enabled=False,
                                  doc="Per sample errors.")


    def __init__(self,
                 transerror,
                 splitter=None,
                 combiner='mean',
                 expose_testdataset=False,
                 harvest_attribs=None,
                 copy_attribs='copy',
                 **kwargs):
        """
        :Parameters:
          transerror: TransferError instance
            Provides the classifier used for cross-validation.
          splitter: Splitter | None
            Used to split the dataset for cross-validation folds. By
            convention the first dataset in the tuple returned by the
            splitter is used to train the provided classifier. If the
            first element is 'None' no training is performed. The second
            dataset is used to generate predictions with the (trained)
            classifier. If `None` (default) an instance of
            :class:`~mvpa.datasets.splitters.NoneSplitter` is used.
          combiner: Functor | 'mean'
            Called with the list of per-fold results to aggregate the
            error values of all cross-validation folds. If 'mean'
            (default) the grand mean of the transfer errors is computed.
          expose_testdataset: bool
            In the proper pipeline, classifier must not know anything
            about testing data, but in some cases it might lead only
            to marginal harm, thus it might be wanted to be enabled
            (provide testdataset for RFE to determine stopping point).
            Only has an effect if the classifier has a `testdataset`
            attribute.
          harvest_attribs: list of basestring
            What attributes of call to store and return within
            harvested state variable
          copy_attribs: None | basestring
            Force copying values of attributes on harvesting
          **kwargs:
            All additional arguments are passed to the
            :class:`~mvpa.measures.base.DatasetMeasure` base class.
        """
        DatasetMeasure.__init__(self, **kwargs)
        Harvestable.__init__(self, harvest_attribs, copy_attribs)

        if splitter is None:
            self.__splitter = NoneSplitter()
        else:
            self.__splitter = splitter

        if combiner == 'mean':
            self.__combiner = GrandMean
        else:
            self.__combiner = combiner

        self.__transerror = transerror
        self.__expose_testdataset = expose_testdataset

    # TODO: put back in ASAP
    # NOTE(review): this draft references self.__clf, self.__errorfx and
    # indentDoc, none of which are set by __init__ nor imported in this
    # module -- adapt it (e.g. to self.__transerror) before re-enabling.
    # def __repr__(self):
    #     """String summary over the object
    #     """
    #     return """CrossValidatedTransferError /
    #  splitter: %s
    #  classifier: %s
    #  errorfx: %s
    #  combiner: %s""" % (indentDoc(self.__splitter), indentDoc(self.__clf),
    #                     indentDoc(self.__errorfx), indentDoc(self.__combiner))


    def _call(self, dataset):
        """Perform cross-validation on a dataset.

        'dataset' is passed to the splitter instance and serves as the source
        dataset to generate split for the single cross-validation folds.

        Returns the configured combiner applied to the list of per-split
        transfer error results. Temporarily enabled states of the child
        TransferError are restored, and an exposed testing dataset is unbound
        from the classifier, even if a split raises.
        """
        # local bindings
        # NOTE: self._harvest() below receives locals(); local names in this
        # method (transerror, split, result, ...) are presumably looked up by
        # name through harvest_attribs -- do not rename them.
        states = self.states
        clf = self.__transerror.clf
        expose_testdataset = self.__expose_testdataset

        # store the results of the splitprocessor
        results = []
        states.splits = []

        # what states to enable in terr
        terr_enable = [state_var
                       for state_var in ['confusion', 'training_confusion',
                                         'samples_error']
                       if states.isEnabled(state_var)]

        # charge states with initial values
        summaryClass = clf._summaryClass
        clf_hastestdataset = hasattr(clf, 'testdataset')

        states.confusion = summaryClass()
        states.training_confusion = summaryClass()
        states.transerrors = []
        # generator expression (not a list comprehension) so that no loop
        # variable leaks into locals() and shadows a builtin
        states.samples_error = dict((origid, []) for origid in dataset.origids)

        # enable requested states in child TransferError instance (restored
        # again in the finally clause below, even on failure)
        if terr_enable:
            self.__transerror.states._changeTemporarily(
                enable_states=terr_enable)

        # Bind upfront so the name exists even if the splitter yields no
        # splits at all (previously an UnboundLocalError after the loop).
        transerror = self.__transerror
        completed = False
        try:
            # We better ensure that underlying classifier is not trained if we
            # are going to deepcopy transerror
            if states.isEnabled("transerrors"):
                self.__transerror.untrain()

            # splitter
            for split in self.__splitter(dataset):
                if states.isEnabled("splits"):
                    states.splits.append(split)

                if states.isEnabled("transerrors"):
                    # copy first and then train, as some classifiers cannot
                    # be copied when already trained, e.g. SWIG'ed stuff
                    lastsplit = None
                    for ds in split:
                        if ds is not None:
                            lastsplit = ds._dsattr['lastsplit']
                            break
                    if lastsplit:
                        # only if we could deduce that it was last split
                        # use the 'mother' transerror
                        transerror = self.__transerror
                    else:
                        # otherwise -- deep copy
                        transerror = deepcopy(self.__transerror)
                else:
                    transerror = self.__transerror

                # assign testing dataset if given classifier can digest it
                if clf_hastestdataset and expose_testdataset:
                    transerror.clf.testdataset = split[1]

                try:
                    # run the beast; the classifier is only trained if the
                    # splitter provides something in the first element of the
                    # tuple -- this is the behavior of TransferError
                    result = transerror(split[1], split[0])
                finally:
                    # unbind the testdataset from the classifier, so testing
                    # data does not stay attached if the transfer failed
                    if clf_hastestdataset and expose_testdataset:
                        transerror.clf.testdataset = None

                # next line is important for 'self._harvest' call
                self._harvest(locals())

                # XXX Look below -- may be we should have not auto added .?
                #     then transerrors also could be deprecated
                if states.isEnabled("transerrors"):
                    states.transerrors.append(transerror)

                # XXX: could be merged with next for loop using a utility
                #      class that can add dict elements into a list
                if states.isEnabled("samples_error"):
                    for k, v in \
                            transerror.states.samples_error.iteritems():
                        states.samples_error[k].append(v)

                # pull in child states
                for state_var in ['confusion', 'training_confusion']:
                    if states.isEnabled(state_var):
                        states[state_var].value.__iadd__(
                            transerror.states[state_var].value)

                if __debug__:
                    debug("CROSSC", "Split #%d: result %s"
                          % (len(results), repr(result)))
                results.append(result)
            completed = True
        finally:
            if completed:
                # Since we could have operated with a copy -- bind the last
                # used one back (only after a fully successful run)
                self.__transerror = transerror
            # put states of child TransferError back into original config
            if terr_enable:
                self.__transerror.states._resetEnabledTemporarily()

        # Store state variable if it is enabled
        states.results = results

        # Provide those labels_map if appropriate. Best effort only: the
        # dataset presumably may lack a labels_map, which must not abort the
        # cross-validation.
        try:
            if states.isEnabled("confusion"):
                states.confusion.labels_map = dataset.labels_map
            if states.isEnabled("training_confusion"):
                states.training_confusion.labels_map = dataset.labels_map
        except Exception:
            pass

        return self.__combiner(results)


    splitter = property(fget=lambda self: self.__splitter,
                        doc="Access to the Splitter instance.")
    transerror = property(fget=lambda self: self.__transerror,
                          doc="Access to the TransferError instance.")
    combiner = property(fget=lambda self: self.__combiner,
                        doc="Access to the configured combiner.")
250