Package mvpa :: Package misc :: Package plot :: Module base
[hide private]
[frames] | no frames]

Source Code for Module mvpa.misc.plot.base

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Misc. plotting helpers.""" 
 10   
 11  __docformat__ = 'restructuredtext' 
 12   
 13  import pylab as P 
 14  import numpy as N 
 15   
 16  from mvpa.datasets.splitters import NFoldSplitter 
 17  from mvpa.clfs.distance import squared_euclidean_distance 
 18   
 19   
 20   
def plotErrLine(data, x=None, errtype='ste', curves=None, linestyle='--',
                fmt='o', perc_sigchg=False, baseline=None):
    """Make a line plot with errorbars on the data points.

    :Parameters:
      data: sequence of sequences
        First axis separates samples and second axis will appear as
        x-axis in the plot.
      x: sequence
        Value to be used as 'x-values' corresponding to the elements of
        the 2nd axis of `data`. If `None`, a sequence of ascending integers
        will be generated.
      errtype: 'ste' | 'std'
        Type of error value to be computed per datapoint.
          'ste': standard error of the mean
          'std': standard deviation
      curves: None | list of tuple(x, y)
        Each tuple represents an additional curve, with x and y coordinates of
        each point on the curve. The curves are plotted as given (they are
        not converted into percent signal change).
      linestyle: str
        matplotlib linestyle argument. Applied to either the additional
        curves or the line connecting the datapoints. Set to 'None' to
        disable the line completely.
      fmt: str
        matplotlib plot style argument to be applied to the data points
        and errorbars.
      perc_sigchg: bool
        If `True` the plot will show percent signal changes relative to a
        baseline.
      baseline: float | None
        Baseline used for converting values into percent signal changes.
        If `None` and `perc_sigchg` is `True`, the absolute of the mean of the
        first feature (i.e. [:,0]) will be used as a baseline.

    :Raises:
      ValueError
        If `errtype` is unknown, or if the length of `x` does not match the
        2nd axis of `data`. Both checks happen before anything is plotted.

    :Example:

      Make a dataset with 30 noisy repetitions of a full sine wave period,
      sampled at 20 points.

      >>> x = N.linspace(0, N.pi * 2, 20)
      >>> data = N.vstack([N.sin(x)] * 30)
      >>> data += N.random.normal(size=data.shape)

      Now, plot mean data points with error bars, plus a high-res
      version of the original sine wave.

      >>> x = N.linspace(0, N.pi * 2, 200)
      >>> plotErrLine(data, curves=[(x, N.sin(x))])
      >>> #P.show()
    """
    # Validate arguments up front, so that an invalid call does not leave a
    # half-drawn figure (e.g. with the extra curves already plotted).
    if errtype not in ('ste', 'std'):
        raise ValueError("Unknown error type '%s'" % errtype)

    data = N.asanyarray(data)

    if len(data.shape) < 2:
        data = N.atleast_2d(data)

    # compute mean signal course
    md = data.mean(axis=0)

    # compute matching datapoint locations on x-axis
    if x is None:
        x = N.arange(len(md))
    elif len(x) != len(md):
        raise ValueError("The length of `x` (%i) has to match the 2nd "
                         "axis of the data array (%i)" % (len(x), len(md)))

    if perc_sigchg:
        if baseline is None:
            baseline = N.abs(md[0])
        md /= baseline
        md -= 1.0
        md *= 100.0
        # Not in-place, to keep the original data intact. No offset needs to
        # be subtracted, because only the spread of `data` is used below.
        data = data / baseline
        data *= 100.0

    # plot high-res curves if present
    if curves is not None:
        for xc, yc in curves:
            P.plot(xc, yc, linestyle=linestyle)

        # the linestyle was spent on the curves -- no line between data points
        linestyle = 'None'

    # compute error per datapoint
    err = data.std(axis=0)
    if errtype == 'ste':
        err = err / N.sqrt(len(data))

    # plot datapoints with error bars
    P.errorbar(x, md, err, fmt=fmt, linestyle=linestyle)
119 120
def plotFeatureHist(dataset, xlim=None, noticks=True, perchunk=False,
                    **kwargs):
    """Plot a histogram of the feature values for every label.

    :Parameters:
      dataset: Dataset
        Source of the samples. Its labels (and, with `perchunk`, its
        chunks) determine how the samples are grouped.
      xlim: None | 2-tuple
        Common x-axis limits applied to every histogram.
      noticks: boolean
        If True, all axis ticks are removed. This saves space in large
        plots.
      perchunk: boolean
        If True, one histogram is drawn for every label/chunk combination,
        laid out as a grid (labels x chunks). Otherwise a single row with
        one histogram per label is drawn.
      **kwargs:
        Passed unchanged to matplotlib's hist().
    """
    label_splitter = NFoldSplitter(1, attr='labels')
    chunk_splitter = NFoldSplitter(1, attr='chunks')

    nlabels = len(dataset.uniquelabels)
    nchunks = len(dataset.uniquechunks)

    def _draw(values):
        # draw one histogram into the currently active subplot
        P.hist(values, **kwargs)

        if xlim is not None:
            P.xlim(xlim)

        if noticks:
            P.yticks([])
            P.xticks([])

    # 1-based subplot index, counting up across all drawn panels
    panel = 1

    # only the second element of each split is used
    for irow, (_, label_ds) in enumerate(label_splitter(dataset)):
        if not perchunk:
            P.subplot(1, nlabels, panel)
            _draw(label_ds.samples)
            P.title('L:' + str(label_ds.uniquelabels[0]))
            panel += 1
            continue

        for icol, (_, chunk_ds) in enumerate(chunk_splitter(label_ds)):
            P.subplot(nlabels, nchunks, panel)
            _draw(chunk_ds.samples.ravel())

            # annotate the grid edges: chunk on the top row, label in the
            # first column
            if irow == 0:
                P.title('C:' + str(chunk_ds.uniquechunks[0]))
            if icol == 0:
                P.ylabel('L:' + str(chunk_ds.uniquelabels[0]))

            panel += 1
def plotSamplesDistance(dataset, sortbyattr=None):
    """Plot the euclidean distances between all samples of a dataset.

    The matrix of pairwise distances is shown with `imshow()`, and a
    colorbar is added.

    :Parameters:
      dataset: Dataset
        Providing the samples.
      sortbyattr: None | str
        If None, the samples distances will be in the same order as their
        appearance in the dataset. Alternatively, the name of a samples
        attribute can be given, which will then be used to sort/group the
        samples, e.g. to investigate the similarity of samples by label or
        by chunks. The dataset has to provide a matching `unique<attr>`
        attribute and an `idsby<attr>()` method.
    """
    if sortbyattr is None:
        samples = dataset.samples
    else:
        # plain getattr() instead of calling the __getattribute__ dunder
        # directly; the lookup method is fetched once, not per value
        idsby = getattr(dataset, 'idsby' + sortbyattr)
        # concatenate the sample ids group by group, in the order of the
        # unique attribute values
        slicer = []
        for value in getattr(dataset, 'unique' + sortbyattr):
            slicer.extend(idsby(value).tolist())
        samples = dataset.samples[slicer]

    ed = N.sqrt(squared_euclidean_distance(samples))

    P.imshow(ed)
    P.colorbar()
205 206
def plotBars(data, labels=None, title=None, ylim=None, ylabel=None,
             width=0.2, offset=0.2, color='0.6', distance=1.0,
             yerr='ste', **kwargs):
    """Make bar plots with automatically computed error bars.

    Candlestick plots (multiple interleaved barplots) can be done by
    calling this function multiple times with an appropriately modified
    `offset` argument.

    :Parameters:
      data: array (nbars x nobservations) | other sequence type
        Source data for the barplot. Error measure is computed along the
        second axis.
      labels: list | None
        If not None, a label from this list is placed on each bar.
      title: str
        An optional title of the barplot.
      ylim: 2-tuple
        Y-axis range.
      ylabel: str
        An optional label for the y-axis.
      width: float
        Width of a bar. The value should be in a reasonable relation to
        `distance`.
      offset: float
        Constant offset of all bars along the x-axis. Can be used to create
        candlestick plots.
      color: matplotlib color spec
        Color of the bars.
      distance: float
        Distance of two adjacent bars.
      yerr: 'ste' | 'std' | None | sequence
        Type of error for the errorbars. 'ste' (standard error) and 'std'
        (standard deviation) are computed from each bar's observations.
        Any other value, e.g. `None` (no errorbars) or a sequence of
        precomputed errors, is passed unchanged to matplotlib's `bar()`.
      **kwargs:
        Any additional arguments are passed to matplotlib's `bar()` function.

    :Returns:
      Whatever matplotlib's `bar()` returns.
    """
    # determine location of bars
    xlocations = (N.arange(len(data)) * distance) + offset

    # Only strings name an error measure. Comparing e.g. a numpy array of
    # precomputed errors with 'ste' yields an array, whose truth value is
    # ambiguous -- hence the explicit type check before comparing.
    if isinstance(yerr, str):
        if yerr == 'ste':
            yerr = [N.std(d) / N.sqrt(len(d)) for d in data]
        elif yerr == 'std':
            yerr = [N.std(d) for d in data]
    # anything else is passed on to bar() unchanged

    # plot bars
    plot = P.bar(xlocations,
                 [N.mean(d) for d in data],
                 yerr=yerr,
                 width=width,
                 color=color,
                 ecolor='black',
                 **kwargs)

    if ylim is not None:
        P.ylim(*(ylim))
    if title:
        P.title(title)

    # explicit None/len checks: `if labels:` is ambiguous for numpy arrays
    if labels is not None and len(labels):
        # 2.0 avoids integer division (Python 2) for an integer `width`
        P.xticks(xlocations + width / 2.0, labels)

    if ylabel:
        P.ylabel(ylabel)

    # leave some space after last bar (nothing to do without any bar)
    if len(xlocations):
        P.xlim(0, xlocations[-1] + width + offset)

    return plot
278 279
def inverseCmap(cmap_name):
    """Create a new colormap that is the reverse of the named colormap.

    Only colormaps whose segment data is available in `matplotlib._cm` as
    `_<cmap_name>_data` are supported. That data must map each channel to
    a sequence of ``(x, y0, y1)`` anchor points.

    :Parameters:
      cmap_name: str
        Name of the matplotlib colormap, e.g. 'jet'.

    :Returns:
      A `LinearSegmentedColormap` named '<cmap_name>_rev'.

    :Raises:
      ValueError
        If no segment data can be found for `cmap_name`.
    """
    import matplotlib._cm as _cm
    import matplotlib as mpl
    from matplotlib.colors import LinearSegmentedColormap

    # look the data up by attribute name instead of eval()-ing a string
    # built from the argument
    try:
        cmap_data = getattr(_cm, '_%s_data' % cmap_name)
    except AttributeError:
        raise ValueError("Cannot obtain data for the colormap %s" % cmap_name)

    # Reversing a segment map mirrors every anchor position (x -> 1 - x),
    # reverses the anchor order, and swaps the left/right values (y0, y1)
    # of each anchor, so that discontinuities stay on the correct side.
    # Merely pairing the original positions with reversed values is only
    # correct for maps with symmetrically placed anchors.
    new_data = dict((channel, [(1.0 - x, y1, y0)
                               for x, y0, y1 in reversed(anchors)])
                    for channel, anchors in cmap_data.items())

    # LUTSIZE no longer exists in newer matplotlib versions; fall back to
    # the configured lookup table size
    lutsize = getattr(_cm, 'LUTSIZE', None)
    if lutsize is None:
        lutsize = mpl.rcParams['image.lut']

    return LinearSegmentedColormap('%s_rev' % cmap_name, new_data, lutsize)
295 296
def plotDatasetChunks(ds, clf_labels=None):
    """Quick plot to see the chunk structure of a dataset with 2 features.

    Each sample is drawn as the text of its chunk at its 2D position,
    colored according to its label.

    :Parameters:
      ds: Dataset
        Dataset with exactly two features.
      clf_labels: None | sequence
        Predicted labels, one per sample of `ds`, in dataset order. If
        provided, incorrectly labeled samples are additionally marked with
        an 'x'. Silently ignored if its length does not match the number
        of samples.

    :Raises:
      ValueError
        If `ds` does not have exactly 2 features.
    """
    if ds.nfeatures != 2:
        raise ValueError("Can plot only in 2D, ie for datasets with 2 features")
    if P.matplotlib.get_backend() == 'TkAgg':
        P.ioff()
    if clf_labels is not None and len(clf_labels) != ds.nsamples:
        clf_labels = None
    colors = ('b', 'g', 'r', 'c', 'm', 'y', 'k', 'w')
    # cycle through the colors, so that more labels than colors do not
    # lead to a KeyError below
    labels_map = dict((label, colors[i % len(colors)])
                      for i, label in enumerate(ds.uniquelabels))
    for chunk in ds.uniquechunks:
        chunk_text = str(chunk)
        # NOTE(review): assumes where() returns the dataset-wide sample ids
        # in the same order in which ds[ids] yields the samples
        ids = ds.where(chunks=chunk)
        ds_chunk = ds[ids]
        for i in range(ds_chunk.nsamples):
            s = ds_chunk.samples[i]
            label = ds_chunk.labels[i]
            # `clf_labels` is indexed like the whole dataset, so map the
            # position within the chunk back to the dataset-wide sample id
            if clf_labels is not None and clf_labels[ids[i]] != label:
                P.plot([s[0]], [s[1]], 'x' + labels_map[label])
            P.text(s[0], s[1], chunk_text, color=labels_map[label],
                   horizontalalignment='center',
                   verticalalignment='center',
                   )
    # axis() expects (xmin, xmax, ymin, ymax). Pad each data range by 10%
    # on both sides, so that the outermost samples are never clipped.
    dss = ds.samples
    xmin, ymin = N.min(dss, axis=0)
    xmax, ymax = N.max(dss, axis=0)
    xpad = 0.1 * (xmax - xmin)
    ypad = 0.1 * (ymax - ymin)
    P.axis((xmin - xpad, xmax + xpad, ymin - ypad, ymax + ypad))
    P.draw()
    P.ion()
334