diff --git a/dmriprep/data/tests/dwi_b0.nii.gz b/dmriprep/data/tests/dwi_b0.nii.gz new file mode 100644 index 00000000..cfe4ad75 Binary files /dev/null and b/dmriprep/data/tests/dwi_b0.nii.gz differ diff --git a/dmriprep/interfaces/images.py b/dmriprep/interfaces/images.py index 1848c64d..a8fadb0e 100644 --- a/dmriprep/interfaces/images.py +++ b/dmriprep/interfaces/images.py @@ -33,7 +33,7 @@ traits, ) -from dmriprep.utils.images import extract_b0, median, rescale_b0 +from dmriprep.utils.images import extract_b0, rescale_b0, summarize_images LOGGER = logging.getLogger('nipype.interface') @@ -124,5 +124,5 @@ def _run_interface(self, runtime): self._results['out_b0s'], self._results['signal_drift'] = rescale_b0( self.inputs.in_file, self.inputs.mask_file, out_b0s ) - self._results['out_ref'] = median(self._results['out_b0s'], out_path=out_ref) + self._results['out_ref'] = summarize_images(self._results['out_b0s'], out_path=out_ref) return runtime diff --git a/dmriprep/utils/images.py b/dmriprep/utils/images.py index 6c2821f4..ca98e0f4 100644 --- a/dmriprep/utils/images.py +++ b/dmriprep/utils/images.py @@ -28,7 +28,32 @@ def extract_b0(in_file, b0_ixs, out_path=None): - """Extract the *b0* volumes from a DWI dataset.""" + """ + Extract the *b0* volumes from a DWI dataset. + + Parameters + ---------- + in_file : str + DWI NIfTI file. + b0_ixs : list + List of B0 indices in `in_file`. + out_path : str + Optionally specify an output path. + + Returns + ------- + out_path : str + 4D NIfTI file consisting of B0's. + + Examples + -------- + >>> os.chdir(tmpdir) + >>> b0_ixs = np.where(np.loadtxt(str(data_dir / 'bval')) <= 50)[0].tolist()[:2] + >>> in_file = str(data_dir / 'dwi.nii.gz') + >>> out_path = extract_b0(in_file, b0_ixs) + >>> assert os.path.isfile(out_path) + + """ if out_path is None: out_path = fname_presuffix(in_file, suffix='_b0') @@ -43,7 +68,33 @@ def extract_b0(in_file, b0_ixs, out_path=None): def rescale_b0(in_file, mask_file, out_path=None): - """Rescale the input volumes using the median signal intensity.""" + """ + Rescale the input volumes using the median signal intensity. + + Parameters + ---------- + in_file : str + A NIfTI file consisting of one or more B0's. + mask_file : str + A B0 mask NIFTI file. + out_path : str + Optionally specify an output path. + + Returns + ------- + out_path : str + A rescaled B0 NIFTI file. + + Examples + -------- + >>> os.chdir(tmpdir) + >>> mask_file = str(data_dir / 'dwi_mask.nii.gz') + >>> in_file = str(data_dir / 'dwi_b0.nii.gz') + >>> out_path, drifts = rescale_b0(in_file, mask_file) + >>> assert os.path.isfile(out_path) + + """ + if out_path is None: out_path = fname_presuffix(in_file, suffix='_rescaled', use_ext=True) @@ -65,8 +116,45 @@ def rescale_b0(in_file, mask_file, out_path=None): return out_path, signal_drift.tolist() -def median(in_file, out_path=None): - """Average a 4D dataset across the last dimension using median.""" +def summarize_images(in_file, method=np.median, dtype=None, out_path=None): + """ + Summarize a 4D dataset across the last dimension using a + callable method. + + Parameters + ---------- + in_file : str + A NIfTI file consisting of one or more 3D images. + method : callable + A numpy function such as `np.mean` or `np.median`. + dtype : str + Optioally specify a datatype (e.g. 'float32'). + out_path : str + Optionally specify an output path for `out_path`. + + Returns + ------- + out_path : str + A 3D NIFTI image file. + + Examples + -------- + >>> os.chdir(tmpdir) + >>> in_file = str(dipy_datadir / "HARDI193.nii.gz") + >>> # Median case + >>> out_path = summarize_images(in_file) + >>> assert os.path.isfile(out_path) + >>> # Mean case + >>> out_path = summarize_images(in_file, method=np.mean) + >>> assert os.path.isfile(out_path) + + """ + + if not callable(method): + raise ValueError('method must be callable') + + # TODO: Check that callable is applicable (i.e. contains `axis` arg). + # if method.__call__() if out_path is None: out_path = fname_presuffix(in_file, suffix='_b0ref', use_ext=True) @@ -77,8 +165,174 @@ def median(in_file, out_path=None): nb.squeeze_image(img).to_filename(out_path) return out_path - dtype = img.get_data_dtype() - median_data = np.median(img.get_fdata(), axis=-1) + summary_data = method(img.get_fdata(dtype=dtype), axis=-1) - nb.Nifti1Image(median_data.astype(dtype), img.affine, img.header).to_filename(out_path) + hdr = img.header.copy() + hdr.set_xyzt_units('mm') + if dtype is not None: + hdr.set_data_dtype(dtype) + else: + dtype = hdr.get_data_dtype() + nb.Nifti1Image(summary_data.astype(dtype), img.affine, hdr).to_filename(out_path) return out_path + + +def get_list_data(file_list, dtype=np.float32): + """ + Load 3D volumes from a list of file paths into a 4D array. + + Parameters + ---------- + file_list : str + A list of file paths to 3D NIFTI images. + + Returns + ------- + Nibabel image object + + Examples + -------- + >>> os.chdir(tmpdir) + >>> in_file = str(dipy_datadir / "HARDI193.nii.gz") + >>> out_files = save_4d_to_3d(in_file) + >>> assert len(out_files) == get_list_data(out_files).shape[-1] + """ + return nb.concat_images([nb.load(fname) for fname in file_list]).get_fdata(dtype=dtype) + + +def match_transforms(dwi_files, transforms, b0_ixs): + """ + Arrange the order of a list of transforms. + + This is a helper function for :abbr:`EMC (Eddy-currents and Motion Correction)`. + Sorts the input list of affine transforms to correspond with that of + each individual dwi volume file, accounting for the indices of :math:`b = 0` volumes. + + Parameters + ---------- + dwi_files : list + A list of file paths to 3D diffusion-weighted NIFTI volumes. + transforms : list + A list of ndarrays. + b0_ixs : list + List of B0 indices. + + Returns + ------- + nearest_affines : list + A list of affine file paths that correspond to each of the split + dwi volumes. + + Examples + -------- + >>> os.chdir(tmpdir) + >>> from dmriprep.utils.vectors import DiffusionGradientTable + >>> dwi_file = str(dipy_datadir / "HARDI193.nii.gz") + >>> check = DiffusionGradientTable( + ... dwi_file=dwi_file, + ... bvecs=str(dipy_datadir / "HARDI193.bvec"), + ... bvals=str(dipy_datadir / "HARDI193.bval")) + >>> check.generate_rasb() + >>> # Conform to the orientation of the image: + >>> affines = np.zeros((check.gradients.shape[0], 4, 4)) + >>> transforms = [] + >>> for ii, aff in enumerate(affines): + ... aff_file = f'aff_{ii}.npy' + ... np.save(aff_file, aff) + ... transforms.append(aff_file) + >>> dwi_files = save_4d_to_3d(dwi_file) + >>> b0_ixs = np.where((check.bvals) <= 50)[0].tolist()[:2] + >>> nearest_affines = match_transforms(dwi_files, transforms, b0_ixs) + >>> assert sum([os.path.isfile(i) for i in nearest_affines]) == len(nearest_affines) + >>> assert len(nearest_affines) == len(dwi_files) + """ + num_dwis = len(dwi_files) + num_transforms = len(transforms) + + if num_dwis == num_transforms: + return transforms + + # Do sanity checks + if not len(transforms) == len(b0_ixs): + raise Exception('number of transforms does not match number of b0 images') + + # Create a list of which emc affines go with each of the split images + nearest_affines = [] + for index in range(num_dwis): + nearest_b0_num = np.argmin(np.abs(index - np.array(b0_ixs))) + this_transform = transforms[nearest_b0_num] + nearest_affines.append(this_transform) + + return nearest_affines + + +def save_4d_to_3d(in_file): + """ + Split a 4D dataset along the last dimension into multiple 3D volumes. + + Parameters + ---------- + in_file : str + DWI NIfTI file. + + Returns + ------- + out_files : list + A list of file paths to 3d NIFTI images. + + Examples + -------- + >>> os.chdir(tmpdir) + >>> in_file = str(dipy_datadir / "HARDI193.nii.gz") + >>> out_files = save_4d_to_3d(in_file) + >>> assert len(out_files) == nb.load(in_file).shape[-1] + """ + filenii = nb.load(in_file) + if len(filenii.shape) != 4: + raise RuntimeError(f'Input image ({filenii}) is not 4D.') + + files_3d = nb.four_to_three(filenii) + out_files = [] + for i, file_3d in enumerate(files_3d): + out_file = fname_presuffix(in_file, suffix=f'_tmp_{i}') + file_3d.to_filename(out_file) + out_files.append(out_file) + del files_3d + return out_files + + +def save_3d_to_4d(in_files): + """ + Concatenate a list of 3D volumes into a 4D output. + + Parameters + ---------- + in_files : list + A list of file paths to 3D NIFTI images. + + Returns + ------- + out_file : str + A file path to a 4d NIFTI image of concatenated 3D volumes. + + Examples + -------- + >>> os.chdir(tmpdir) + >>> in_file = str(dipy_datadir / "HARDI193.nii.gz") + >>> threeD_files = save_4d_to_3d(in_file) + >>> out_file = save_3d_to_4d(threeD_files) + >>> assert len(threeD_files) == nb.load(out_file).shape[-1] + """ + # Remove one-sized extra dimensions + nii_list = [] + for _i, f in enumerate(in_files): + filenii = nb.load(f) + filenii = nb.squeeze_image(filenii) + if len(filenii.shape) != 3: + raise RuntimeError(f'Input image ({f}) is not 3D.') + else: + nii_list.append(filenii) + img_4d = nb.funcs.concat_images(nii_list) + out_file = fname_presuffix(in_files[0], suffix='_merged') + img_4d.to_filename(out_file) + return out_file