1
2
3
4
5
6
7
8
9 """Cross-validate a classifier on a dataset"""
10
11 __docformat__ = 'restructuredtext'
12
13 from mvpa.support.copy import deepcopy
14
15 from mvpa.measures.base import DatasetMeasure
16 from mvpa.datasets.splitters import NoneSplitter
17 from mvpa.base import warning
18 from mvpa.misc.state import StateVariable, Harvestable
19 from mvpa.misc.transformers import GrandMean
20
21 if __debug__:
22 from mvpa.base import debug
23
24
26 """Classifier cross-validation.
27
28 This class provides a simple interface to cross-validate a classifier
29 on datasets generated by a splitter from a single source dataset.
30
31 Arbitrary performance/error values can be computed by specifying an error
32 function (used to compute an error value for each cross-validation fold)
33 and a combiner function that aggregates all computed error values across
34 cross-validation folds.
35 """
36
37 results = StateVariable(enabled=False, doc=
38 """Store individual results in the state""")
39 splits = StateVariable(enabled=False, doc=
40 """Store the actual splits of the data. Can be memory expensive""")
41 transerrors = StateVariable(enabled=False, doc=
42 """Store copies of transerrors at each step. If enabled -
43 operates on clones of transerror, but for the last split original
44 transerror is used""")
45 confusion = StateVariable(enabled=False, doc=
46 """Store total confusion matrix (if available)""")
47 training_confusion = StateVariable(enabled=False, doc=
48 """Store total training confusion matrix (if available)""")
49 samples_error = StateVariable(enabled=False,
50 doc="Per sample errors.")
51
52
53 - def __init__(self,
54 transerror,
55 splitter=None,
56 combiner='mean',
57 expose_testdataset=False,
58 harvest_attribs=None,
59 copy_attribs='copy',
60 **kwargs):
61 """
62 :Parameters:
63 transerror: TransferError instance
64 Provides the classifier used for cross-validation.
65 splitter: Splitter | None
66 Used to split the dataset for cross-validation folds. By
67 convention the first dataset in the tuple returned by the
68 splitter is used to train the provided classifier. If the
69 first element is 'None' no training is performed. The second
70 dataset is used to generate predictions with the (trained)
71 classifier. If `None` (default) an instance of
72 :class:`~mvpa.datasets.splitters.NoneSplitter` is used.
73 combiner: Functor | 'mean'
74 Used to aggregate the error values of all cross-validation
75 folds. If 'mean' (default) the grand mean of the transfer
76 errors is computed.
77 expose_testdataset: bool
78 In the proper pipeline, classifier must not know anything
79 about testing data, but in some cases it might lead only
80 to marginal harm, thus migth wanted to be enabled (provide
81 testdataset for RFE to determine stopping point).
82 harvest_attribs: list of basestr
83 What attributes of call to store and return within
84 harvested state variable
85 copy_attribs: None | basestr
86 Force copying values of attributes on harvesting
87 **kwargs:
88 All additional arguments are passed to the
89 :class:`~mvpa.measures.base.DatasetMeasure` base class.
90 """
91 DatasetMeasure.__init__(self, **kwargs)
92 Harvestable.__init__(self, harvest_attribs, copy_attribs)
93
94 if splitter is None:
95 self.__splitter = NoneSplitter()
96 else:
97 self.__splitter = splitter
98
99 if combiner == 'mean':
100 self.__combiner = GrandMean
101 else:
102 self.__combiner = combiner
103
104 self.__transerror = transerror
105 self.__expose_testdataset = expose_testdataset
106
107
108
109
110
111
112
113
114
115
116
117
118
119 - def _call(self, dataset):
120 """Perform cross-validation on a dataset.
121
122 'dataset' is passed to the splitter instance and serves as the source
123 dataset to generate split for the single cross-validation folds.
124 """
125
126 results = []
127 self.states.splits = []
128
129
130 states = self.states
131 clf = self.__transerror.clf
132 expose_testdataset = self.__expose_testdataset
133
134
135 terr_enable = []
136 for state_var in ['confusion', 'training_confusion', 'samples_error']:
137 if states.isEnabled(state_var):
138 terr_enable += [state_var]
139
140
141 summaryClass = clf._summaryClass
142 clf_hastestdataset = hasattr(clf, 'testdataset')
143
144 self.states.confusion = summaryClass()
145 self.states.training_confusion = summaryClass()
146 self.states.transerrors = []
147 self.states.samples_error = dict([(id, []) for id in dataset.origids])
148
149
150
151 if len(terr_enable):
152 self.__transerror.states._changeTemporarily(
153 enable_states=terr_enable)
154
155
156
157 if states.isEnabled("transerrors"):
158 self.__transerror.untrain()
159
160
161 for split in self.__splitter(dataset):
162
163
164 if states.isEnabled("splits"):
165 self.states.splits.append(split)
166
167 if states.isEnabled("transerrors"):
168
169
170 lastsplit = None
171 for ds in split:
172 if ds is not None:
173 lastsplit = ds._dsattr['lastsplit']
174 break
175 if lastsplit:
176
177
178 transerror = self.__transerror
179 else:
180
181 transerror = deepcopy(self.__transerror)
182 else:
183 transerror = self.__transerror
184
185
186 if clf_hastestdataset and expose_testdataset:
187 transerror.clf.testdataset = split[1]
188
189
190 result = transerror(split[1], split[0])
191
192
193 if clf_hastestdataset and expose_testdataset:
194 transerror.clf.testdataset = None
195
196
197 self._harvest(locals())
198
199
200
201 if states.isEnabled("transerrors"):
202 self.states.transerrors.append(transerror)
203
204
205
206 if states.isEnabled("samples_error"):
207 for k, v in \
208 transerror.states.samples_error.iteritems():
209 self.states.samples_error[k].append(v)
210
211
212 for state_var in ['confusion', 'training_confusion']:
213 if states.isEnabled(state_var):
214 states[state_var].value.__iadd__(
215 transerror.states[state_var].value)
216
217 if __debug__:
218 debug("CROSSC", "Split #%d: result %s" \
219 % (len(results), `result`))
220 results.append(result)
221
222
223 self.__transerror = transerror
224
225
226 if len(terr_enable):
227 self.__transerror.states._resetEnabledTemporarily()
228
229 self.states.results = results
230 """Store state variable if it is enabled"""
231
232
233 try:
234 if states.isEnabled("confusion"):
235 states.confusion.labels_map = dataset.labels_map
236 if states.isEnabled("training_confusion"):
237 states.training_confusion.labels_map = dataset.labels_map
238 except:
239 pass
240
241 return self.__combiner(results)
242
243
    # Read-only accessors for the configured collaborators; the lambdas
    # close over the name-mangled private attributes set in __init__.
    splitter = property(fget=lambda self:self.__splitter,
                        doc="Access to the Splitter instance.")
    transerror = property(fget=lambda self:self.__transerror,
                          doc="Access to the TransferError instance.")
    combiner = property(fget=lambda self:self.__combiner,
                        doc="Access to the configured combiner.")
250