1   
  2   
  3   
  4   
  5   
  6   
  7   
  8   
  9  """Unit tests for PyMVPA Regressions""" 
 10   
 11  from mvpa.base import externals 
 12  from mvpa.support.copy import deepcopy 
 13   
 14  from mvpa.datasets import Dataset 
 15  from mvpa.mappers.mask import MaskMapper 
 16  from mvpa.datasets.splitters import NFoldSplitter, OddEvenSplitter 
 17   
 18  from mvpa.misc.errorfx import RMSErrorFx, RelativeRMSErrorFx, \ 
 19       CorrErrorFx, CorrErrorPFx 
 20   
 21  from mvpa.clfs.meta import SplitClassifier 
 22  from mvpa.clfs.transerror import TransferError 
 23  from mvpa.misc.exceptions import UnknownStateError 
 24   
 25  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 26   
 27  from tests_warehouse import * 
 28  from tests_warehouse_clfs import * 
 31   
 32      @sweepargs(ml=clfswh['regression']+regrswh[:]) 
 34          """Test If binary regression-based  classifiers have proper tag 
 35          """ 
 36          self.failUnless(('binary' in ml._clf_internals) != ml.regression, 
 37              msg="Inconsistent markin with binary and regression features" 
 38                  " detected in %s having %s" % (ml, `ml._clf_internals`)) 
  39   
 40      @sweepargs(regr=regrswh['regression']) 
 42          """Simple tests on regressions 
 43          """ 
 44          ds = datasets['chirp_linear'] 
 45   
 46          cve = CrossValidatedTransferError( 
 47              TransferError(regr, CorrErrorFx()), 
 48              splitter=NFoldSplitter(), 
 49              enable_states=['training_confusion', 'confusion']) 
 50          corr = cve(ds) 
 51   
 52          self.failUnless(corr == cve.confusion.stats['CCe']) 
 53   
 54          splitregr = SplitClassifier(regr, 
 55                                      splitter=OddEvenSplitter(), 
 56                                      enable_states=['training_confusion', 'confusion']) 
 57          splitregr.train(ds) 
 58          split_corr = splitregr.confusion.stats['CCe'] 
 59          split_corr_tr = splitregr.training_confusion.stats['CCe'] 
 60   
 61          for confusion, error in ((cve.confusion, corr), 
 62                                   (splitregr.confusion, split_corr), 
 63                                   (splitregr.training_confusion, split_corr_tr), 
 64                                   ): 
 65               
 66               
 67              for conf in confusion.summaries: 
 68                  stats = conf.stats 
 69                  self.failUnless(stats['CCe'] < 0.5) 
 70                  self.failUnlessEqual(stats['CCe'], stats['Summary CCe']) 
 71   
 72              s0 = confusion.asstring(short=True) 
 73              s1 = confusion.asstring(short=False) 
 74   
 75              for s in [s0, s1]: 
 76                  self.failUnless(len(s) > 10, 
 77                                  msg="We should get some string representation " 
 78                                  "of regression summary. Got %s" % s) 
 79   
 80              self.failUnless(error < 0.2, 
 81                              msg="Regressions should perform well on a simple " 
 82                              "dataset. Got correlation error of %s " % error) 
 83   
 84               
 85               
 86               
 87               
 88               
 89               
 90              self.failUnless(confusion.stats['CCe'] < 0.5) 
 91   
 92          split_predictions = splitregr.predict(ds.samples)  
  93   
 94           
 95           
 96           
 97           
 98   
    @sweepargs(clf=clfswh['regression'])
    # NOTE(review): the test method this decorator belongs to is missing from
    # this listing. As listed, no `def` follows it (only blank lines and
    # another decorator), so Python would apply it to whatever function is
    # defined next. Restore the missing method before running.


117      @sweepargs(regr=regrswh['regression', 'has_sensitivity']) 
119          """Inspired by a snippet leading to segfault from Daniel Kimberg 
120   
121          lead to segfaults due to inappropriate access of SVs thinking 
122          that it is a classification problem (libsvm keeps SVs at None 
123          for those, although reports nr_class to be 2. 
124          """ 
125          myds = Dataset(samples=N.random.normal(size=(10,5)), 
126                         labels=N.random.normal(size=10)) 
127          sa = regr.getSensitivityAnalyzer() 
128          try: 
129              res = sa(myds) 
130          except Exception, e: 
131              self.fail('Failed to obtain a sensitivity due to %r' % (e,)) 
132          self.failUnless(res.shape == (myds.nfeatures,)) 
  133           
138   
139   
if __name__ == '__main__':
    # Script entry point. The only action taken here is importing `runner`;
    # presumably its import-time side effects execute the tests -- confirm.
    import runner
142