1
2
3
4
5
6
7
8
"""Unit tests for PyMVPA nifti dataset"""

import os
import unittest
from tempfile import mkstemp, mktemp

import numpy as N

from mvpa import pymvpa_dataroot
from mvpa.datasets.nifti import *
from mvpa.misc.exceptions import *
from mvpa.misc.fsl import FslEV3
from mvpa.misc.support import Event
21
class NiftiDatasetTests(unittest.TestCase):
    """Tests of various Nifti-based datasets
    """

    def testNiftiDataset(self):
        """Basic testing of NiftiDataset
        """
        data = NiftiDataset(samples=os.path.join(pymvpa_dataroot, 'example4d'),
                            labels=[1, 2])
        self.failUnless(data.nfeatures == 294912)
        self.failUnless(data.nsamples == 2)

        # The metric's element size must match header pixdim entries 3, 2, 1
        # (taken in that reversed order by the [3:0:-1] slice).
        self.failUnless((data.mapper.metric.elementsize
                         == data.niftihdr['pixdim'][3:0:-1]).all())

        # Neighborhood sizes around voxel (1, 1, 1) for two radii.
        nb22 = len(list(data.mapper.getNeighborIn((1, 1, 1), 2.2)))
        nb20 = len(list(data.mapper.getNeighborIn((1, 1, 1), 2.0)))
        self.failUnless(nb22 == 7)
        self.failUnless(nb20 == 5)

        # Either value pair is accepted -- presumably depending on the time
        # unit stored in the header (TODO confirm).
        self.failUnless(data.dt in [2.0, 2000.0])
        self.failUnless(data.samplingrate in [5e-4, 5e-1])

        merged = data + data
        self.failUnless(merged.nfeatures == 294912)
        self.failUnless(merged.nsamples == 4)

        # Merging must leave every header field unchanged.  N.all (rather
        # than N.mean(...) == 1) also treats equal empty values as equal.
        for k in merged.niftihdr.keys():
            self.failUnless(N.all(merged.niftihdr[k] == data.niftihdr[k]))

        # The merged dataset must remain usable once the source is gone.
        del data
        self.failUnless(merged.samples[3, 120000] == merged.samples[1, 120000])

        # A single-voxel mask yields exactly one feature.
        mask = N.zeros((24, 96, 128), dtype='bool')
        mask[12, 20, 40] = True
        nddata = NiftiDataset(samples=os.path.join(pymvpa_dataroot, 'example4d'),
                              labels=[1, 2],
                              mask=mask)
        self.failUnless(nddata.nfeatures == 1)

        # Reverse mapping restores the full volume with the value placed only
        # at the masked voxel.
        rmap = nddata.mapReverse([44])
        self.failUnless(rmap.shape == (24, 96, 128))
        self.failUnless(N.sum(rmap) == 44)
        self.failUnless(rmap[12, 20, 40] == 44)

    def testNiftiMapper(self):
        """Basic testing of map2Nifti
        """
        data = NiftiDataset(samples=os.path.join(pymvpa_dataroot, 'example4d'),
                            labels=[1, 2])

        # A single feature vector maps back to one 3D volume.
        vol = data.map2Nifti(N.ones((294912,), dtype='int16'))
        self.failUnless(vol.data.shape == (24, 96, 128))
        self.failUnless((vol.data == 1).all())

        # A whole dataset maps back to a 4D image, one volume per sample.
        vol = data.map2Nifti(data)
        self.failUnless(vol.data.shape == (2, 24, 96, 128))

    def testERNiftiDataset(self):
        """Basic testing of ERNiftiDataset
        """
        # Construction without any arguments must fail.
        self.failUnlessRaises(DatasetError, ERNiftiDataset)

        tssrc = os.path.join(pymvpa_dataroot, 'bold')
        evsrc = os.path.join(pymvpa_dataroot, 'fslev3.txt')

        evs = FslEV3(evsrc).toEvents()

        # Using the events as returned by toEvents() must raise ValueError;
        # it succeeds below once every event has been given a label.
        self.failUnlessRaises(ValueError, ERNiftiDataset,
                              samples=tssrc, events=evs)

        for ev in evs:
            ev['label'] = 1

        ds = ERNiftiDataset(samples=tssrc, events=evs)

        self.failUnless(ds.nfeatures == 7201)
        self.failUnless(ds.nsamples == len(evs))

        # All but the last feature of each sample must equal the raveled
        # volumes covered by the corresponding event.
        origsamples = getNiftiFromAnySource(tssrc).data
        for i, ev in enumerate(evs):
            onset = ev['onset']
            self.failUnless((ds.samples[i][:-1]
                             == origsamples[onset:onset + ev['duration']].ravel()
                             ).all())

        ds = ERNiftiDataset(samples=tssrc, events=evs, evconv=True,
                            storeoffset=True)
        self.failUnless(ds.nsamples == len(evs))
        self.failUnless(ds.nfeatures == 3202)

        # Reverse mapping without arguments reproduces the source shape.
        nim = ds.map2Nifti()
        self.failUnless(nim.data.shape == origsamples.shape)

        nim = ds.map2Nifti(ds.samples[0])
        self.failUnless(nim.data.shape == (4, 1, 20, 40))

    def testERNiftiDatasetMapping(self):
        """Some mapping testing -- more tests is better
        """
        # Volume t holds values 24 * t + v for flat voxel index v.
        sample_size = (4, 3, 2)
        samples = N.arange(120).reshape((5,) + sample_size)
        # Dataset mask keeps only voxels with an odd flat index.
        dsmask = N.arange(24).reshape(sample_size) % 2
        ds = ERNiftiDataset(samples=NiftiImage(samples),
                            events=[Event(onset=0, duration=2, label=1,
                                          chunk=1, features=[1000, 1001]),
                                    Event(onset=1, duration=2, label=2,
                                          chunk=1, features=[2000, 2001])],
                            mask=dsmask)
        nfeatures = ds.mapper._mappers[1].getInSize()

        # Selection mask covers flat voxels 0, 1 and 7; only 1 and 7 survive
        # the odd-index dataset mask.
        mask = N.zeros(sample_size)
        mask[0, 0, 0] = mask[1, 0, 1] = mask[0, 0, 1] = 1

        ds_sel = ds.selectFeatures(
            ds.mapper.forward([mask, [1] * nfeatures]).nonzero()[0])

        self.failUnless((mask.reshape(24).nonzero()[0] == [0, 1, 7]).all())
        self.failUnless(ds_sel.samples.shape == (2, 6),
                        msg="We should have selected all samples, and 6 "
                            "features (2 voxels at 2 timepoints + 2 features). "
                            "Got %s" % (ds_sel.samples.shape,))
        self.failUnless((ds_sel.samples[:, -2:] ==
                         [[1000, 1001], [2000, 2001]]).all(),
                        msg="We should have selected additional features "
                            "correctly. Got %s" % (ds_sel.samples[:, -2:],))
        self.failUnless((ds_sel.samples[:, :-2] ==
                         [[1, 7, 25, 31],
                          [25, 31, 49, 55]]).all(),
                        msg="We should have selected original features "
                            "correctly. Got %s" % (ds_sel.samples[:, :-2],))

    def testNiftiDatasetFrom3D(self):
        """Test NiftiDataset based on 3D volume(s)
        """
        tssrc = os.path.join(pymvpa_dataroot, 'bold')
        masrc = os.path.join(pymvpa_dataroot, 'mask')

        # With enforce4D=False loading the 3D mask volume must fail; with
        # the default settings (next statement) it succeeds.
        self.failUnlessRaises(Exception, NiftiDataset,
                              masrc, mask=masrc, labels=1, enforce4D=False)

        ds = NiftiDataset(masrc, mask=masrc, labels=1)

        # Reverse mapping must reproduce the original volume.
        plain_data = NiftiImage(masrc).data
        self.failUnless(N.all(plain_data ==
                              ds.map2Nifti().data.reshape(plain_data.shape)))

        # Combining the mask volume with the 'bold' series must be rejected.
        self.failUnlessRaises(ValueError, NiftiDataset, (masrc, tssrc),
                              mask=masrc, labels=1)

        dsfull = NiftiDataset(tssrc, mask=masrc, labels=1)
        ds_selected = dsfull['samples', [3]]
        nifti_selected = ds_selected.map2Nifti()

        # A sequence of sources (file names and a NiftiImage) yields one
        # sample per source, labelled in order.
        labels = [123, 2, 123]
        ds2 = NiftiDataset((masrc, masrc, nifti_selected),
                           mask=masrc, labels=labels)
        self.failUnless(ds2.nsamples == 3)
        self.failUnless((ds2.samples[0] == ds2.samples[1]).all())
        self.failUnless((ds2.samples[2] == dsfull.samples[3]).all())
        self.failUnless((ds2.labels == labels).all())

    def testNiftiDatasetROIMaskNeighbors(self):
        """Test if we could request neighbors within spherical ROI whenever
        center is outside of the mask
        """
        # ROI: two 4-voxel segments along the last axis.
        mask_roi = N.zeros((24, 96, 128), dtype='bool')
        mask_roi[12, 20, 38:42] = True
        mask_roi[23, 20, 38:42] = True
        ds_full = NiftiDataset(samples=os.path.join(pymvpa_dataroot, 'example4d'),
                               labels=[1, 2])
        ds_roi = NiftiDataset(samples=os.path.join(pymvpa_dataroot, 'example4d'),
                              labels=[1, 2], mask=mask_roi)

        ids_roi = ds_roi.mapper.getNeighbors(ds_roi.mapper.getOutId((12, 20, 40)),
                                             radius=20)
        self.failUnless(len(ids_roi) == 4)

        # A center outside of the ROI has no output id.
        self.failUnlessRaises(ValueError, ds_roi.mapper.getOutId, (12, 20, 37))

        # Workaround for out-of-mask centers: ask the unmasked mapper for the
        # neighborhood and keep only voxels that exist in the ROI mapper.
        ids_out = []
        for id_in in ds_full.mapper.getNeighborIn((12, 20, 37), radius=20):
            try:
                ids_out.append(ds_roi.mapper.getOutId(id_in))
            except ValueError:
                pass
        self.failUnless(ids_out == ids_roi)

    def testNiftiDatasetScaled(self):
        """Test if loading scaled data works correctly

        Is relevant only for pynifti interface -- nibabel always does scaling
        """
        orig_filename = os.path.join(pymvpa_dataroot, 'mask.nii.gz')
        # mkstemp creates the file atomically (mktemp only returns a name and
        # is race-prone); the handle is closed since save() writes by name.
        fd, filename = mkstemp(suffix='mvpa.nii.gz', prefix='test_scl_nifti')
        os.close(fd)
        try:
            ni = NiftiImage(orig_filename)
            orig_value = ni.data[0, 3, 4]
            ni.data[0, 3, 4] = 5

            # Raw value 5 stored with slope 15 / intercept 100 reads back as
            # 5 * 15 + 100 == 175 when scaling is applied.
            hdr = ni.header
            hdr['scl_slope'] = 15
            hdr['scl_inter'] = 100
            ni.header = hdr
            ni.save(filename)

            ds_scaled = NiftiDataset(filename, labels=1)
            ds_scaled_multi = NiftiDataset([orig_filename, filename], labels=1)
            ds_raw = NiftiDataset(filename, labels=1, scale_data=False)

            fid = ds_scaled.mapper.getOutId([0, 3, 4])

            self.failUnlessEqual(ds_scaled.samples[0, fid], 175)
            self.failUnlessEqual(ds_scaled_multi.samples[0, fid], orig_value)
            self.failUnlessEqual(ds_scaled_multi.samples[1, fid], 175)

            self.failUnlessEqual(ds_raw.samples[0, fid], 5)

            # A single-file dataset keeps the stored scaling parameters.
            self.failUnlessEqual(ds_scaled.niftihdr['scl_slope'], 15.0)
            self.failUnlessEqual(ds_scaled.niftihdr['scl_inter'], 100)

            # The multi-file dataset reports identity scaling.
            self.failUnlessEqual(ds_scaled_multi.niftihdr['scl_slope'], 1.0)
            self.failUnlessEqual(ds_scaled_multi.niftihdr['scl_inter'], 0)

            self.failUnlessEqual(ds_raw.niftihdr['scl_slope'], 15.0)
            self.failUnlessEqual(ds_raw.niftihdr['scl_inter'], 100)

            # Mapping samples back yields an image with identity scaling ...
            ni0 = ds_scaled.map2Nifti(ds_raw.samples)
            self.failUnlessEqual(ni0.header['scl_slope'], 1.0)
            self.failUnlessEqual(ni0.header['scl_inter'], 0.)

            # ... without altering the source dataset's header.
            self.failUnlessEqual(ds_scaled.niftihdr['scl_slope'], 15.0)
            self.failUnlessEqual(ds_scaled.niftihdr['scl_inter'], 100)

            self.failUnlessEqual(ni0.data[0, 0, 3, 4], 5)
        finally:
            # Clean up even when an assertion above fails.
            os.remove(filename)
344


if __name__ == '__main__':
    # NOTE(review): importing the shared ``runner`` module is presumably what
    # executes this test module; confirm against runner.py.
    import runner
351