Package mvpa :: Package tests :: Module test_rfe
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_rfe

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA recursive feature elimination""" 
 10   
 11  from mvpa.datasets.splitters import NFoldSplitter 
 12  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 13  from mvpa.datasets.masked import MaskedDataset 
 14  from mvpa.measures.base import FeaturewiseDatasetMeasure 
 15  from mvpa.featsel.rfe import RFE 
 16  from mvpa.featsel.base import \ 
 17       SensitivityBasedFeatureSelection, \ 
 18       FeatureSelectionPipeline 
 19  from mvpa.featsel.helpers import \ 
 20       NBackHistoryStopCrit, FractionTailSelector, FixedErrorThresholdStopCrit, \ 
 21       MultiStopCrit, NStepsStopCrit, \ 
 22       FixedNElementTailSelector, BestDetector, RangeElementSelector 
 23   
 24  from mvpa.clfs.meta import FeatureSelectionClassifier, SplitClassifier 
 25  from mvpa.clfs.transerror import TransferError, ConfusionBasedError 
 26  from mvpa.clfs.stats import MCNullDist 
 27  from mvpa.misc.transformers import Absolute, FirstAxisMean 
 28   
 29  from mvpa.misc.state import UnknownStateError 
 30   
 31  from tests_warehouse import * 
 32  from tests_warehouse_clfs import * 
class SillySensitivityAnalyzer(FeaturewiseDatasetMeasure):
    """Simple one which just returns xrange[-N/2, N/2], where N is the
    number of features
    """

    def __init__(self, mult=1, **kwargs):
        """Initialize the measure.

        :Parameters:
          mult : int or float
            Scaling factor applied to the returned sensitivity values.
        """
        FeaturewiseDatasetMeasure.__init__(self, **kwargs)
        self.__mult = mult

    def __call__(self, dataset):
        """Return a deterministic pseudo-sensitivity for `dataset`.

        Produces ``mult * (arange(nfeatures) - nfeatures//2)``, i.e. a
        monotonically increasing ramp centered (approximately) at zero --
        no classifier is trained.  This gives tests a sensitivity map
        whose ranking is known a priori.
        """
        # NOTE: the original docstring here claimed a linear SVM was
        # trained -- that was a copy-paste leftover; only arithmetic on
        # the feature count happens.
        return self.__mult * (N.arange(dataset.nfeatures)
                              - int(dataset.nfeatures / 2))
class RFETests(unittest.TestCase):
    """Unit tests for recursive feature elimination and its helpers:
    stopping criteria, element selectors, sensitivity-based feature
    selection and feature-selection pipelines.
    """

    def getData(self):
        """Return the training dataset shared by these tests."""
        return datasets['uni2medium_train']

    def getDataT(self):
        """Return the testing dataset shared by these tests."""
        return datasets['uni2medium_test']


    def testBestDetector(self):
        """Test detection of the best value in an error history."""
        bd = BestDetector()

        # for empty history -- no best
        self.failUnless(bd([]) == False)
        # we got the best if we have just 1
        self.failUnless(bd([1]) == True)
        # we got the best if we have the last minimal
        self.failUnless(bd([1, 0.9, 0.8]) == True)

        # test for alternative func
        bd = BestDetector(func=max)
        self.failUnless(bd([0.8, 0.9, 1.0]) == True)
        self.failUnless(bd([0.8, 0.9, 1.0]+[0.9]*9) == False)
        self.failUnless(bd([0.8, 0.9, 1.0]+[0.9]*10) == False)

        # test to detect earliest and latest minimum
        bd = BestDetector(lastminimum=True)
        self.failUnless(bd([3, 2, 1, 1, 1, 2, 1]) == True)
        bd = BestDetector()
        self.failUnless(bd([3, 2, 1, 1, 1, 2, 1]) == False)


    def testNBackHistoryStopCrit(self):
        """Test stopping criterion"""
        stopcrit = NBackHistoryStopCrit()
        # for empty history -- no best but just go
        self.failUnless(stopcrit([]) == False)
        # should not stop if we got 10 more after minimal
        self.failUnless(stopcrit(
            [1, 0.9, 0.8]+[0.9]*(stopcrit.steps-1)) == False)
        # should stop if we got 10 more after minimal
        self.failUnless(stopcrit(
            [1, 0.9, 0.8]+[0.9]*stopcrit.steps) == True)

        # test for alternative func
        stopcrit = NBackHistoryStopCrit(BestDetector(func=max))
        self.failUnless(stopcrit([0.8, 0.9, 1.0]+[0.9]*9) == False)
        self.failUnless(stopcrit([0.8, 0.9, 1.0]+[0.9]*10) == True)

        # test to detect earliest and latest minimum
        stopcrit = NBackHistoryStopCrit(BestDetector(lastminimum=True))
        self.failUnless(stopcrit([3, 2, 1, 1, 1, 2, 1]) == False)
        stopcrit = NBackHistoryStopCrit(steps=4)
        self.failUnless(stopcrit([3, 2, 1, 1, 1, 2, 1]) == True)


    # NOTE(review): this `def` line was lost in the HTML extraction of the
    # original file (source line 105); restored from the orphaned body.
    def testFixedErrorThresholdStopCrit(self):
        """Test stopping criterion"""
        stopcrit = FixedErrorThresholdStopCrit(0.5)

        self.failUnless(stopcrit([]) == False)
        self.failUnless(stopcrit([0.8, 0.9, 0.5]) == False)
        self.failUnless(stopcrit([0.8, 0.9, 0.4]) == True)
        # only last error has to be below to stop
        self.failUnless(stopcrit([0.8, 0.4, 0.6]) == False)


    def testNStepsStopCrit(self):
        """Test stopping criterion"""
        stopcrit = NStepsStopCrit(2)

        self.failUnless(stopcrit([]) == False)
        self.failUnless(stopcrit([0.8, 0.9]) == True)
        self.failUnless(stopcrit([0.8]) == False)


    def testMultiStopCrit(self):
        """Test multiple stop criteria"""
        stopcrit = MultiStopCrit([FixedErrorThresholdStopCrit(0.5),
                                  NBackHistoryStopCrit(steps=4)])

        # default 'or' mode
        # nback triggers
        self.failUnless(stopcrit([1, 0.9, 0.8]+[0.9]*4) == True)
        # threshold triggers
        self.failUnless(stopcrit([1, 0.9, 0.2]) == True)

        # alternative 'and' mode
        stopcrit = MultiStopCrit([FixedErrorThresholdStopCrit(0.5),
                                  NBackHistoryStopCrit(steps=4)],
                                 mode='and')
        # nback triggers not
        self.failUnless(stopcrit([1, 0.9, 0.8]+[0.9]*4) == False)
        # threshold triggers not
        self.failUnless(stopcrit([1, 0.9, 0.2]) == False)
        # only both satisfy
        self.failUnless(stopcrit([1, 0.9, 0.4]+[0.4]*4) == True)


    def testFeatureSelector(self):
        """Test feature selector"""
        # remove 10% weekest
        selector = FractionTailSelector(0.1)
        data = N.array([3.5, 10, 7, 5, -0.4, 0, 0, 2, 10, 9])
        # == rank [4, 5, 6, 7, 0, 3, 2, 9, 1, 8]
        target10 = N.array([0, 1, 2, 3, 5, 6, 7, 8, 9])
        target30 = N.array([0, 1, 2, 3, 7, 8, 9])

        # 'ndiscarded' state is unknown before the first call
        self.failUnlessRaises(UnknownStateError,
                              selector.__getattribute__, 'ndiscarded')
        self.failUnless((selector(data) == target10).all())
        selector.felements = 0.30           # discard 30%
        self.failUnless(selector.felements == 0.3)
        self.failUnless((selector(data) == target30).all())
        self.failUnless(selector.ndiscarded == 3) # se 3 were discarded

        selector = FixedNElementTailSelector(1)
        #                0   1  2  3   4    5  6  7  8   9
        data = N.array([3.5, 10, 7, 5, -0.4, 0, 0, 2, 10, 9])
        self.failUnless((selector(data) == target10).all())

        selector.nelements = 3
        self.failUnless(selector.nelements == 3)
        self.failUnless((selector(data) == target30).all())
        self.failUnless(selector.ndiscarded == 3)

        # test range selector
        # simple range 'above'
        self.failUnless((RangeElementSelector(lower=0)(data) == \
                         N.array([0, 1, 2, 3, 7, 8, 9])).all())

        self.failUnless((RangeElementSelector(lower=0,
                                              inclusive=True)(data) == \
                         N.array([0, 1, 2, 3, 5, 6, 7, 8, 9])).all())

        self.failUnless((RangeElementSelector(lower=0, mode='discard',
                                              inclusive=True)(data) == \
                         N.array([4])).all())

        # simple range 'below'
        self.failUnless((RangeElementSelector(upper=2)(data) == \
                         N.array([4, 5, 6])).all())

        self.failUnless((RangeElementSelector(upper=2,
                                              inclusive=True)(data) == \
                         N.array([4, 5, 6, 7])).all())

        self.failUnless((RangeElementSelector(upper=2, mode='discard',
                                              inclusive=True)(data) == \
                         N.array([0, 1, 2, 3, 8, 9])).all())


        # ranges
        self.failUnless((RangeElementSelector(lower=2, upper=9)(data) == \
                         N.array([0, 2, 3])).all())

        self.failUnless((RangeElementSelector(lower=2, upper=9,
                                              inclusive=True)(data) == \
                         N.array([0, 2, 3, 7, 9])).all())

        # 'discard' of the inverted range must equal 'select' of the range
        self.failUnless((RangeElementSelector(upper=2, lower=9, mode='discard',
                                              inclusive=True)(data) ==
                         RangeElementSelector(lower=2, upper=9,
                                              inclusive=False)(data)).all())

        # non-0 elements -- should be equivalent to N.nonzero()[0]
        self.failUnless((RangeElementSelector()(data) == \
                         N.nonzero(data)[0]).all())


    # NOTE(review): this `def` line was lost in the HTML extraction of the
    # original file (source line 220); restored from decorator and body.
    @sweepargs(clf=clfswh['has_sensitivity', '!meta'])
    def testSensitivityBasedFeatureSelection(self, clf):
        """Test one-shot sensitivity-based feature selection."""
        # sensitivity analyser and transfer error quantifier use the SAME clf!
        sens_ana = clf.getSensitivityAnalyzer()

        # of features to remove
        Nremove = 2

        # because the clf is already trained when computing the sensitivity
        # map, prevent retraining for transfer error calculation
        # Use absolute of the svm weights as sensitivity
        fe = SensitivityBasedFeatureSelection(sens_ana,
                feature_selector=FixedNElementTailSelector(2),
                enable_states=["sensitivity", "selected_ids"])

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        sdata, stdata = fe(wdata, tdata)

        # fail if orig datasets are changed
        self.failUnless(wdata.nfeatures == wdata_nfeatures)
        self.failUnless(tdata.nfeatures == tdata_nfeatures)

        # silly check if nfeatures got a single one removed
        self.failUnlessEqual(wdata.nfeatures, sdata.nfeatures+Nremove,
            msg="We had to remove just a single feature")

        self.failUnlessEqual(tdata.nfeatures, stdata.nfeatures+Nremove,
            msg="We had to remove just a single feature in testing as well")

        self.failUnlessEqual(len(fe.sensitivity), wdata_nfeatures,
            msg="Sensitivity have to have # of features equal to original")

        self.failUnlessEqual(len(fe.selected_ids), sdata.nfeatures,
            msg="# of selected features must be equal the one in the "
                "result dataset")


    # NOTE(review): this `def` line was lost in the HTML extraction of the
    # original file (source line 260); restored from the orphaned body.
    def testFeatureSelectionPipeline(self):
        """Test chaining of several feature selections into a pipeline."""
        sens_ana = SillySensitivityAnalyzer()

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        # test silly one first ;-)
        self.failUnlessEqual(sens_ana(wdata)[0], -int(wdata_nfeatures/2))

        # OLD: first remove 25% == 6, and then 4, total removing 10
        # NOW: test should be independent of the numerical number of features
        feature_selections = [SensitivityBasedFeatureSelection(
                                  sens_ana,
                                  FractionTailSelector(0.25)),
                              SensitivityBasedFeatureSelection(
                                  sens_ana,
                                  FixedNElementTailSelector(4))
                              ]

        # create a FeatureSelection pipeline
        feat_sel_pipeline = FeatureSelectionPipeline(
            feature_selections=feature_selections,
            enable_states=['nfeatures', 'selected_ids'])

        sdata, stdata = feat_sel_pipeline(wdata, tdata)

        self.failUnlessEqual(len(feat_sel_pipeline.feature_selections),
                             len(feature_selections),
                             msg="Test the property feature_selections")

        desired_nfeatures = int(N.ceil(wdata_nfeatures*0.75))
        self.failUnlessEqual(feat_sel_pipeline.nfeatures,
                             [wdata_nfeatures, desired_nfeatures],
                             msg="Test if nfeatures get assigned properly."
                             " Got %s!=%s" % (feat_sel_pipeline.nfeatures,
                                              [wdata_nfeatures,
                                               desired_nfeatures]))

        self.failUnlessEqual(list(feat_sel_pipeline.selected_ids),
                             range(int(wdata_nfeatures*0.25)+4,
                                   wdata_nfeatures))


    # TODO: should later on work for any clfs_with_sens
    @sweepargs(clf=clfswh['has_sensitivity', '!meta'][:1])
    def testRFE(self, clf):
        """Test basic recursive feature elimination."""
        # sensitivity analyser and transfer error quantifier use the SAME clf!
        sens_ana = clf.getSensitivityAnalyzer()
        trans_error = TransferError(clf)
        # because the clf is already trained when computing the sensitivity
        # map, prevent retraining for transfer error calculation
        # Use absolute of the svm weights as sensitivity
        rfe = RFE(sens_ana,
                  trans_error,
                  feature_selector=FixedNElementTailSelector(1),
                  train_clf=False)

        wdata = self.getData()
        wdata_nfeatures = wdata.nfeatures
        tdata = self.getDataT()
        tdata_nfeatures = tdata.nfeatures

        sdata, stdata = rfe(wdata, tdata)

        # fail if orig datasets are changed
        self.failUnless(wdata.nfeatures == wdata_nfeatures)
        self.failUnless(tdata.nfeatures == tdata_nfeatures)

        # check that the features set with the least error is selected
        if len(rfe.errors):
            e = N.array(rfe.errors)
            self.failUnless(sdata.nfeatures == wdata_nfeatures - e.argmin())
        else:
            self.failUnless(sdata.nfeatures == wdata_nfeatures)

        # silly check if nfeatures is in decreasing order
        nfeatures = N.array(rfe.nfeatures).copy()
        nfeatures.sort()
        self.failUnless( (nfeatures[::-1] == rfe.nfeatures).all() )

        # check if history has elements for every step
        self.failUnless(set(rfe.history)
                        == set(range(len(N.array(rfe.errors)))))

        # Last (the largest number) can be present multiple times even
        # if we remove 1 feature at a time -- just need to stop well
        # in advance when we have more than 1 feature left ;)
        self.failUnless(rfe.nfeatures[-1]
                        == len(N.where(rfe.history
                                       == max(rfe.history))[0]))

    # XXX add a test where sensitivity analyser and transfer error do not
    # use the same classifier


    def testJamesProblem(self):
        """Regression test: cross-validation over an RFE-wrapped classifier."""
        percent = 80
        dataset = datasets['uni2small']
        rfesvm_split = LinearCSVMC()
        fs = \
            RFE(sensitivity_analyzer=rfesvm_split.getSensitivityAnalyzer(),
                transfer_error=TransferError(rfesvm_split),
                feature_selector=FractionTailSelector(
                    percent / 100.0,
                    mode='select', tail='upper'), update_sensitivity=True)

        clf = FeatureSelectionClassifier(
            clf = LinearCSVMC(),
            # on features selected via RFE
            feature_selection = fs)
        # update sensitivity at each step (since we're not using the
        # same CLF as sensitivity analyzer)
        clf.states.enable('feature_ids')

        cv = CrossValidatedTransferError(
            TransferError(clf),
            NFoldSplitter(cvtype=1),
            enable_states=['confusion'],
            expose_testdataset=True)
        #cv = SplitClassifier(clf)
        try:
            error = cv(dataset)
        # `as` form works on Python >= 2.6 as well as 3.x; the old
        # comma form was 2.x-only
        except Exception as e:
            self.fail('CrossValidation cannot handle classifier with RFE '
                      'feature selection. Got exception: %s' % e)
        self.failUnless(error < 0.2)


    # name-mangled (double underscore) so unittest does not pick it up;
    # kept as a disabled scenario
    def __testMatthiasQuestion(self):
        rfe_clf = LinearCSVMC(C=1)

        rfesvm_split = SplitClassifier(rfe_clf)
        clf = \
            FeatureSelectionClassifier(
            clf = LinearCSVMC(C=1),
            feature_selection = RFE(
                sensitivity_analyzer = rfesvm_split.getSensitivityAnalyzer(
                    combiner=FirstAxisMean,
                    transformer=N.abs),
                transfer_error=ConfusionBasedError(
                    rfesvm_split,
                    confusion_state="confusion"),
                stopping_criterion=FixedErrorThresholdStopCrit(0.20),
                feature_selector=FractionTailSelector(
                    0.2, mode='discard', tail='lower'),
                update_sensitivity=True))

        splitter = NFoldSplitter(cvtype=1)
        no_permutations = 1000

        cv = CrossValidatedTransferError(
            TransferError(clf),
            splitter,
            null_dist=MCNullDist(permutations=no_permutations,
                                 tail='left'),
            enable_states=['confusion'])
        error = cv(datasets['uni2small'])
        self.failUnless(error < 0.4)
        self.failUnless(cv.states.null_prob < 0.05)
def suite():
    """Build and return the unittest suite for this module."""
    rfe_suite = unittest.makeSuite(RFETests)
    return rfe_suite
if __name__ == '__main__':
    # delegate execution to the project-wide test runner
    import runner