Package mvpa :: Package clfs :: Module ridge
[hide private]
[frames] | no frames]

Source Code for Module mvpa.clfs.ridge

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Ridge regression classifier.""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13   
 14  import numpy as N 
 15  from mvpa.base import externals 
 16   
 17  if externals.exists("scipy", raiseException=True): 
 18      from scipy.linalg import lstsq 
 19   
 20  from mvpa.clfs.base import Classifier 
 21   
class RidgeReg(Classifier):
    """Ridge regression `Classifier`.

    This ridge regression adds an intercept term so your labels do not
    have to be zero-centered.  The intercept itself is not penalized.
    """

    _clf_internals = ['ridge', 'regression', 'linear']

    def __init__(self, lm=None, **kwargs):
        """
        Initialize a ridge regression analysis.

        :Parameters:
          lm : float
            the penalty term lambda.
            (Defaults to .05*nFeatures)
            Note: the penalty is applied by appending the rows
            ``[lm * I | 0]`` (with target 0) to the least-squares system,
            so the minimized objective is
            ``||X w + c - y||**2 + lm**2 * ||w||**2`` -- i.e. the
            effective weight on ``||w||**2`` is ``lm**2``, and the
            intercept ``c`` is not penalized.
        """
        # init base class first
        Classifier.__init__(self, **kwargs)

        # pylint happiness
        self.w = None

        # It does not make sense to calculate a confusion matrix for a
        # ridge regression
        self.states.enable('training_confusion', False)

        # store the penalty as given; None means "derive it from the
        # number of features at training time" (resolved in _train)
        self.__lm = lm

        # store train method config
        self.__implementation = 'direct'

    def __repr__(self):
        """String summary of the object
        """
        if self.__lm is None:
            return """Ridge(lm=.05*nfeatures, enable_states=%s)""" % \
                   (str(self.states.enabled))
        else:
            return """Ridge(lm=%f, enable_states=%s)""" % \
                   (self.__lm, str(self.states.enabled))

    def _train(self, data):
        """Train the classifier using `data` (`Dataset`).

        Solves the augmented least-squares system::

            [ samples  1 ]       [ labels ]
            [ Lambda   0 ] w  ~  [   0    ]

        with ``Lambda = lm * I`` (``lm`` defaults to ``.05*nfeatures``).
        The resulting weights are stored in `self.w`; the last entry is
        the (unpenalized) intercept.

        :Raises:
          ValueError
            If the configured training implementation is unknown.
        """
        if self.__implementation != "direct":
            raise ValueError("Unknown implementation '%s'"
                             % self.__implementation)

        nfeatures = data.nfeatures

        # determine the lambda matrix
        if self.__lm is None:
            # Not specified, so calculate based on .05*nfeatures
            Lambda = .05*nfeatures*N.eye(nfeatures)
        else:
            # use the provided penalty
            Lambda = self.__lm*N.eye(nfeatures)

        # Upper rows: data fit with an extra column of ones for the
        # intercept.  Lower rows: penalty on the feature weights only
        # (zero column => intercept is not shrunk), with target 0.
        a = N.concatenate(
            (N.concatenate((data.samples, N.ones((data.nsamples, 1))), 1),
             N.concatenate((Lambda, N.zeros((nfeatures, 1))), 1)))
        b = N.concatenate((data.labels, N.zeros(nfeatures)))

        # perform the least squares regression and save the weights
        self.w = lstsq(a, b)[0]

    def _predict(self, data):
        """
        Predict the output for the provided data.

        A column of ones is appended to `data` so that the intercept
        (the last entry of `self.w`) is added to the linear prediction.
        """
        # predict using the trained weights
        return N.dot(N.concatenate((data, N.ones((len(data), 1))), 1),
                     self.w)
102