1
2
3
4
5
6
7
8
9 """k-Nearest-Neighbour classifier."""
10
11 __docformat__ = 'restructuredtext'
12
13 import sys
14
# Compatibility flag with a historical name: it records whether the running
# interpreter is Python >= 2.5, i.e. whether the builtin max() accepts the
# ``key=`` argument.  Despite the name it has nothing to do with dict.has_key.
_dict_has_key = tuple(sys.version_info[:2]) >= (2, 5)
16
17 import numpy as N
18
19 from mvpa.base import warning
20 from mvpa.misc.support import indentDoc
21 from mvpa.clfs.base import Classifier
22 from mvpa.base.dochelpers import enhancedDocString
23 from mvpa.clfs.distance import squared_euclidean_distance
24
25 if __debug__:
26 from mvpa.base import debug
27
28
class kNN(Classifier):
    """
    k-Nearest-Neighbour classifier.

    This is a simple classifier that bases its decision on the distances
    between the training dataset samples and the test sample(s). Distances
    are computed using a customizable distance function. A certain number
    (`k`) of nearest neighbors is selected based on the smallest distances
    and the labels of these neighboring samples are fed into a voting
    function to determine the labels of the test sample.

    Training a kNN classifier is extremely quick, as no actual training
    is performed: the training dataset is simply stored in the
    classifier. All computations are done during classifier prediction.

    .. note::
        If enabled, kNN stores the votes per class in the 'values' state after
        calling predict().

    """

    # NOTE(review): this class was recovered from a line-numbered listing
    # that had lost every ``def`` line, apparently all comment text, and
    # three lines between __repr__ and _train (presumably a __str__ method --
    # the ``indentDoc`` import is otherwise unused).  The signatures below
    # were rebuilt from the method bodies; the __init__ defaults k=2 and
    # voting='weighted' follow upstream PyMVPA.  Confirm against version
    # control and restore __str__ if it existed.

    _clf_internals = ['knn', 'non-linear', 'binary', 'multiclass',
                      'notrain2predict' ]

    def __init__(self, k=2, dfx=squared_euclidean_distance,
                 voting='weighted', **kwargs):
        """
        :Parameters:
          k: unsigned integer
            Number of nearest neighbours to be used for voting.
          dfx: functor
            Function to compute the distances between training and test
            samples. It is called as ``dfx(training_samples, test_samples)``
            and must return one row per training sample and one column per
            test sample.
            Default: squared euclidean distance
          voting: str
            Voting method used to derive predictions from the nearest
            neighbors. Possible values are 'majority' (simple majority of
            classes determines vote) and 'weighted' (votes are weighted
            according to the relative frequencies of each class in the
            training data). The value is only validated at prediction time.
          **kwargs:
            Additional arguments are passed to the base class.
        """
        # init base class first
        Classifier.__init__(self, **kwargs)

        self.__k = k
        self.__dfx = dfx
        self.__voting = voting
        self.__data = None
        # State derived from the training data: filled by _train() and
        # (for the weights) lazily by getWeightedVote().
        self.__weights = None
        self.__votes_init = None

    def __repr__(self, prefixes=None):
        """Representation of the object.

        :Parameters:
          prefixes: list of str or None
            Additional representation fragments appended after kNN's own
            ``k``/``dfx``/``voting`` entries and passed on to the base class.
        """
        # None instead of a shared mutable [] default
        if prefixes is None:
            prefixes = []
        return super(kNN, self).__repr__(
            ["k=%d" % self.__k, "dfx=%s" % self.__dfx,
             "voting=%s" % repr(self.__voting)]
            + list(prefixes))

    def _train(self, data):
        """Train the classifier.

        For kNN it is degenerate -- just stores the data and resets the
        per-class bookkeeping derived from it.

        :Parameters:
          data: Dataset
            Training dataset; its ``samples``, ``labels``, ``uniquelabels``
            and ``nfeatures`` are used later during prediction.
        """
        self.__data = data
        if __debug__:
            # dtype kinds 'i' (signed) and 'u' (unsigned) are the integer
            # types -- distance arithmetic on them may overflow
            if data.samples.dtype.kind in ('i', 'u'):
                warning("kNN: input data is in integers. "
                        "Overflow on arithmetic operations might result in"
                        " errors. Please convert dataset's samples into"
                        " floating datatype if any error is reported.")

        # weights belong to the previous training data; recomputed lazily
        self.__weights = None

        # zeroed per-class vote counters, copied for every test sample
        self.__votes_init = dict.fromkeys(data.uniquelabels, 0)

    def _predict(self, data):
        """Predict the class labels for the provided data.

        Returns a list of class labels (one for each data sample).

        :Parameters:
          data: array-like
            Test samples, one per row. Must be two-dimensional with as many
            columns as the training data has features (only checked when
            ``__debug__`` is set).

        :Raises:
          ValueError
            On malformed `data` (debug mode only) or an unknown voting
            method.
        """
        # make sure we're talking about arrays
        data = N.asarray(data)

        # checks only in debug mode
        if __debug__:
            if not data.ndim == 2:
                raise ValueError("Data array must be two-dimensional.")

            if not data.shape[1] == self.__data.nfeatures:
                raise ValueError("Length of data samples (features) does "
                                 "not match the classifier.")

        # Resolve the voting function before the potentially expensive
        # distance computation, so that a bad setting fails fast.
        if self.__voting == 'majority':
            vfx = self.getMajorityVote
        elif self.__voting == 'weighted':
            vfx = self.getWeightedVote
        else:
            raise ValueError("kNN told to perform unknown voting '%s'."
                             % self.__voting)

        # dfx returns one row per training sample; after transposing, row i
        # holds the distances of test sample i to all training samples
        dists = self.__dfx(self.__data.samples, data).T

        # indices of the k closest training samples for every test sample
        knns = dists.argsort(axis=1)[:, :self.__k]

        # one (label, votes) pair per test sample
        results = [vfx(knn) for knn in knns]
        predicted = [r[0] for r in results]

        # store predictions and per-class votes in the classifier's states
        self.predictions = predicted
        self.values = [r[1] for r in results]

        return predicted

    def __countVotes(self, knn_ids):
        """Count how many of the given neighbours carry each training label.

        :Parameters:
          knn_ids: sequence of int
            Indices (into the training samples) of the nearest neighbours
            of a single test sample.

        :Returns:
          dict mapping every label of the training data to its number of
          occurrences among `knn_ids`.
        """
        labels = self.__data.labels
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[labels[nn]] += 1
        return votes

    def getMajorityVote(self, knn_ids):
        """Simple voting by choosing the majority of class neighbors.

        :Parameters:
          knn_ids: sequence of int
            Indices (into the training samples) of the nearest neighbours
            of a single test sample.

        :Returns:
          Tuple ``(label, votes)``: the winning label and the list of vote
          counts, ordered like the training data's ``uniquelabels``. Ties
          go to the label that comes first in ``uniquelabels`` (the same
          rule as `getWeightedVote`).
        """
        uniquelabels = self.__data.uniquelabels
        counts = self.__countVotes(knn_ids)
        votes = [counts[ul] for ul in uniquelabels]
        # argmax returns the first maximum -> deterministic tie-breaking
        return uniquelabels[N.asarray(votes).argmax()], votes

    def getWeightedVote(self, knn_ids):
        """Vote with classes weighted by the number of samples per class.

        Each class' vote count is multiplied by ``1 - p``, where ``p`` is the
        fraction of training samples carrying that label, which down-weights
        classes that are frequent in the training data.

        :Parameters:
          knn_ids: sequence of int
            Indices (into the training samples) of the nearest neighbours
            of a single test sample.

        :Returns:
          Tuple ``(label, votes)``: the winning label and the list of
          weighted votes, ordered like the training data's ``uniquelabels``.
          Ties go to the label that comes first in ``uniquelabels``.
        """
        uniquelabels = self.__data.uniquelabels

        # The weights depend only on the training data, so compute them once
        # per training (_train() resets them to None).
        if self.__weights is None:
            labels = self.__data.labels
            # float() is essential: with an int divisor ``count / Nlabels``
            # is integer division under Python 2 and every weight would
            # collapse to 1.0 (or 0.0), turning this into a majority vote.
            nlabels = float(len(labels))
            self.__weights = dict(
                (label, 1.0 - ((labels == label).sum() / nlabels))
                for label in uniquelabels)

        counts = self.__countVotes(knn_ids)
        votes = [self.__weights[ul] * counts[ul] for ul in uniquelabels]
        return uniquelabels[N.asarray(votes).argmax()], votes

    def untrain(self):
        """Reset trained state"""
        self.__data = None
        # also drop everything derived from the training data
        self.__weights = None
        self.__votes_init = None
        super(kNN, self).untrain()