1
2
3
4
5
6
7
8
9 """Provide sensitivity measures for libsvm's SVM."""
10
11 __docformat__ = 'restructuredtext'
12
13 import numpy as N
14
15 from mvpa.base import warning
16 from mvpa.misc.state import StateVariable
17 from mvpa.misc.param import Parameter
18 from mvpa.measures.base import Sensitivity
19
20 if __debug__:
21 from mvpa.base import debug
22
"""`SensitivityAnalyzer` for the LIBSVM implementation of a linear SVM.

Derives feature weights of a trained linear LIBSVM model from the product
of its support-vector coefficients and its support vectors.
"""

# NOTE(review): presumably declares which attribute collections this class
# exposes (here `params`, which holds `split_weights`) -- confirm against the
# mvpa collections machinery.
_ATTRIBUTE_COLLECTIONS = ['params']

# Set by `_call` to the model's rho values (LIBSVM hyperplane offsets).
biases = StateVariable(enabled=True,
                       doc="Offsets of separating hyperplanes")

# When True, `_call` (binary models only) splits the weights by the sign of
# the SV coefficients, producing one weight vector per class label instead
# of a single summed vector.
split_weights = Parameter(False, allowedtype='bool',
                          doc="If binary classification either to sum SVs per each "
                          "class separately")
35
"""Initialize the analyzer with the classifier it shall use.

:Parameters:
  clf : LinearSVM
    Classifier to use. Only classifiers sub-classed from
    `LinearSVM` may be used (this is not validated here).
  **kwargs
    Forwarded unchanged to `Sensitivity.__init__`.
"""

Sensitivity.__init__(self, clf, **kwargs)
46
47
def _call(self, dataset, callables=None):
    """Compute linear-SVM weights from the trained LIBSVM model.

    The weights are the product of the support-vector coefficients and
    the support vectors (``SVCoef * SV``). As a side effect the model's
    rho values are stored in the `biases` state variable.

    :Parameters:
      dataset : Dataset
        Dataset the sensitivity is computed for. Only its
        ``uniquelabels`` are used, and only when `split_weights` is set.
      callables
        Unused here. The default is ``None`` rather than a shared
        mutable ``[]``. Presumably kept to match the caller's call
        signature -- confirm.

    :Returns:
      ``ndarray`` -- the transposed weight matrix. When `split_weights`
      is enabled, it has one column per entry of
      ``dataset.uniquelabels`` (in that order).

    :Raises:
      NotImplementedError
        If `split_weights` is requested for a model that is not a
        binary classifier.
    """
    clf = self.clf
    model = clf.model

    # Read the regression flag once and use it consistently below.
    # Previously the debug block used `clf.regression` while the rest
    # used `clf.params.regression`.
    regression = clf.params.regression
    if regression:
        # Regression models have no notion of classes.
        nr_class = None
    else:
        nr_class = model.nr_class

    if nr_class not in (None, 2):
        warning("You are estimating sensitivity for SVM %s trained on %d" %
                (str(clf), nr_class) +
                " classes. Make sure that it is what you intended to do")

    svcoef = N.matrix(model.getSVCoef())
    svs = N.matrix(model.getSV())
    rhos = N.asarray(model.getRho())

    self.biases = rhos

    if self.split_weights:
        if nr_class != 2:
            raise NotImplementedError(
                "Cannot compute per-class weights for"
                " non-binary classification task")

        svm_labels = model.getLabels()
        ds_labels = list(dataset.uniquelabels)
        senses = [None] * len(ds_labels)

        # Only the first row of svcoef is used (binary model). SVs with
        # positive coefficients are attributed to the first LIBSVM label.
        # SVs with negative coefficients, sign-flipped, go to the second.
        # Each result is stored at the position of that label within the
        # dataset's unique labels.
        for svm_label, mask, sign in ((svm_labels[0], (svcoef > 0).A[0], 1),
                                      (svm_labels[1], (svcoef < 0).A[0], -1)):
            contribution = svcoef[:, mask] * svs[mask, :]
            senses[ds_labels.index(svm_label)] = (sign * contribution).A[0]
        weights = N.array(senses)
    else:
        weights = svcoef * svs

    if __debug__ and 'SVM' in debug.active:
        if regression:
            nsvs = model.getTotalNSV()
            svm_type = clf._svm_impl
        else:
            nsvs = model.getNSV()
            svm_type = '%d-class SVM(%s)' % (nr_class, clf._svm_impl)
        debug('SVM',
              "Extracting weights for %s: #SVs=%s, " % (svm_type, nsvs) +
              " SVcoefshape=%s SVs.shape=%s Rhos=%s." %
              (svcoef.shape, svs.shape, rhos) +
              " Result: min=%f max=%f" % (N.min(weights), N.max(weights)))

    return N.asarray(weights.T)
114
# NOTE(review): presumably a flag read by the mvpa documentation machinery
# to merge inherited (base-class) docstrings into this class's docs -- confirm.
_customizeDocInherit = True
116