diff --git a/pyproject.toml b/pyproject.toml index c70dba162..ed19efbdd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -206,6 +206,7 @@ scil_sh_fusion = "scilpy.cli.scil_sh_fusion:main" scil_sh_to_aodf = "scilpy.cli.scil_sh_to_aodf:main" scil_sh_to_rish = "scilpy.cli.scil_sh_to_rish:main" scil_sh_to_sf = "scilpy.cli.scil_sh_to_sf:main" +scil_fodf_threshold_by_amplitude = "scilpy.cli.scil_fodf_threshold_by_amplitude:main" scil_stats_group_comparison = "scilpy.cli.scil_stats_group_comparison:main" scil_surface_assign_custom_color = "scilpy.cli.scil_surface_assign_custom_color:main" scil_surface_assign_uniform_color = "scilpy.cli.scil_surface_assign_uniform_color:main" diff --git a/src/scilpy/cli/scil_fodf_threshold_by_amplitude.py b/src/scilpy/cli/scil_fodf_threshold_by_amplitude.py new file mode 100644 index 000000000..f5de7f32c --- /dev/null +++ b/src/scilpy/cli/scil_fodf_threshold_by_amplitude.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Compute a binary mask based on a global SF threshold. +The script masks voxels where the maximum SF amplitude is below +either a relative factor or an absolute threshold (or both). + +When fODFs are evaluated on a sphere (SF), the amplitude of the lobes +corresponds to the strength of the diffusion signal in those directions. +Thresholding these amplitudes is a common practice to filter out spurious +peaks arising from noise or the deconvolution process (e.g., ringing effects). + +The absolute threshold can be estimated from the mean/median maximum fODF +in the ventricles, computed with scil_fodf_max_in_ventricles. + +If both --relative and --absolute are provided, the final threshold is the +maximum of the two resulting values. +""" + +import argparse +import logging + +import nibabel as nib +import numpy as np + +from scilpy.io.utils import (add_sh_basis_args, add_sphere_arg, + add_verbose_arg, add_overwrite_arg, + assert_inputs_exist, assert_outputs_exist, + parse_sh_basis_arg) +from scilpy.tracking.utils import compute_sf_threshold_mask +from scilpy.version import version_string + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + epilog=version_string) + + p.add_argument('in_odf', + help='Input ODF file (SH or Peaks) (.nii.gz).') + p.add_argument('out_mask', + help='Output binary mask (.nii.gz).') + + thr_g = p.add_argument_group('Threshold options') + thr_g.add_argument('--relative', type=float, + help='Global SF threshold relative factor (0-1).') + thr_g.add_argument('--absolute', type=float, + help='Global SF absolute threshold.') + add_sh_basis_args(p) + add_sphere_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + if args.relative is None and args.absolute is None: + parser.error("At least one of --relative or --absolute must be " + "provided.") + + assert_inputs_exist(parser, args.in_odf) + assert_outputs_exist(parser, args, args.out_mask) + + sh_basis, is_legacy = parse_sh_basis_arg(args) + + logging.info("Loading ODF data.") + img = nib.load(args.in_odf) + data = img.get_fdata(dtype=np.float32) + + logging.info("Computing global SF threshold mask.") + mask, global_max, threshold = compute_sf_threshold_mask( + data, sphere_name=args.sphere, relative_factor=args.relative, + absolute_threshold=args.absolute, sh_basis=sh_basis, + is_legacy=is_legacy) + + logging.info("Global max SF amplitude: {:.4f}".format(global_max)) + if args.relative is not None and args.absolute is not None: + logging.info("Both relative and absolute thresholds used. " + "Final threshold: {:.4f}".format(threshold)) + elif args.relative is not None: + logging.info("Relative threshold: {:.4f} (Factor: {})" + .format(threshold, args.relative)) + else: + logging.info("Absolute threshold used: {:.4f}".format(args.absolute)) + + logging.info("Number of voxels in mask: {}".format(np.sum(mask))) + + # Save mask + mask_img = nib.Nifti1Image(mask.astype(np.uint8), img.affine, + img.header) + nib.save(mask_img, args.out_mask) + + +if __name__ == "__main__": + main() diff --git a/src/scilpy/cli/scil_tracking_local.py b/src/scilpy/cli/scil_tracking_local.py index bed9e5e7b..23ae42ae4 100755 --- a/src/scilpy/cli/scil_tracking_local.py +++ b/src/scilpy/cli/scil_tracking_local.py @@ -5,7 +5,7 @@ The tracking direction is chosen in the aperture cone defined by the previous tracking direction and the angular constraint. -WARNING: This script DOES NOT support asymetric FODF input (aFODF). +WARNING: This script DOES NOT support asymmetric FODF input (aFODF). Algo 'eudx': select the peak from the spherical function (SF) most closely aligned to the previous direction, and follow an average of it and the previous @@ -41,9 +41,10 @@ * Forward tracking: For GPU tracking, the `--forward_only` flag can be used to disable backward tracking. This option isn't available for CPU tracking. - * Random number generator seed (RNG): CPU and GPU use different RNG implementations,< - so the same `--seed` is reproducible within a backend but does not guarantee - identical streamlines across CPU vs GPU tracking. + * Random number generator seed (RNG): CPU and GPU use different RNG + implementations, so the same `--seed` is reproducible within a + backend but does not guarantee identical streamlines across CPU vs + GPU tracking. All the input nifti files must be in isotropic resolution. @@ -61,26 +62,27 @@ import logging from time import perf_counter -import nibabel as nib -import numpy as np -from nibabel.streamlines import TrkFile, detect_format - from dipy.data import get_sphere from dipy.tracking import utils as track_utils from dipy.tracking.local_tracking import LocalTracking from dipy.tracking.stopping_criterion import BinaryStoppingCriterion from dipy.tracking.tracker import eudx_tracking +import nibabel as nib +from nibabel.streamlines import TrkFile, detect_format +import numpy as np + from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_headers_compatible, assert_inputs_exist, - assert_outputs_exist, parse_sh_basis_arg, - verify_compression_th, load_matrix_in_any_format) + assert_outputs_exist, load_matrix_in_any_format, + parse_sh_basis_arg, verify_compression_th) from scilpy.tracking.tracker import GPUTracker from scilpy.tracking.utils import (add_mandatory_options_tracking, add_out_options, add_seeding_options, add_tracking_options, add_tracking_ptt_options, - get_direction_getter, get_theta, + get_direction_getter, + get_global_sf_threshold_mask, get_theta, save_tractogram, verify_seed_options, verify_streamline_length_options) from scilpy.version import version_string @@ -104,7 +106,7 @@ def _build_arg_parser(): # Other options, only available in this script: track_g.add_argument('--sh_to_pmf', action='store_true', - help='If set, map sherical harmonics to spherical ' + help='If set, map spherical harmonics to spherical ' 'function (pmf) before \ntracking (faster, ' 'requires more memory)') track_g.add_argument('--algo', default='prob', @@ -200,6 +202,17 @@ def main(): logging.debug("Loading masks and finding seeds.") mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool) + # ODF data for thresholding + odf_sh_data = odf_sh_img.get_fdata(dtype=np.float32) + + sh_basis, is_legacy = parse_sh_basis_arg(args) + + sf_mask = None + if args.global_sf_rel_thr is not None or \ + args.global_sf_abs_thr is not None: + sf_mask = get_global_sf_threshold_mask( + odf_sh_data, args, sh_basis, is_legacy) + if args.npv: nb_seeds = args.npv seed_per_vox = True @@ -235,11 +248,16 @@ def main(): random_seed=args.seed) total_nb_seeds = len(seeds) + combined_mask = mask_data + if sf_mask is not None: + combined_mask = np.logical_and(mask_data, sf_mask) + if not args.use_gpu: # LocalTracking.maxlen is actually the maximum length # per direction, we need to filter post-tracking. max_steps_per_direction = int(args.max_length / args.step_size) - stopping_criterion = BinaryStoppingCriterion(mask_data) + + stopping_criterion = BinaryStoppingCriterion(combined_mask) logging.info("Starting CPU local tracking.") if args.algo == 'eudx': @@ -248,7 +266,7 @@ def main(): stopping_criterion, np.eye(4), pam=get_direction_getter( - args.in_odf, args.algo, args.sphere, + odf_sh_data, args.algo, args.sphere, args.sub_sphere, args.theta, sh_basis, voxel_size, args.sf_threshold, args.sh_to_pmf, args.probe_length, args.probe_radius, @@ -264,7 +282,7 @@ def main(): else: streamlines_generator = LocalTracking( get_direction_getter( - args.in_odf, args.algo, args.sphere, + odf_sh_data, args.algo, args.sphere, args.sub_sphere, args.theta, sh_basis, voxel_size, args.sf_threshold, args.sh_to_pmf, args.probe_length, args.probe_radius, @@ -284,14 +302,13 @@ def main(): max_strl_len = int(2.0 * args.max_length / args.step_size) + 1 # data volume - odf_sh = odf_sh_img.get_fdata(dtype=np.float32) # GPU tracking needs the full sphere sphere = get_sphere(name=args.sphere).subdivide(n=args.sub_sphere) logging.info("Starting GPU local tracking.") streamlines_generator = GPUTracker( - odf_sh, mask_data, seeds, + odf_sh_data, combined_mask, seeds, vox_step_size, max_strl_len, theta=get_theta(args.theta, args.algo), sf_threshold=args.sf_threshold, diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index 6b36fcb63..ff984f6d6 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -32,7 +32,7 @@ works well for deterministic tracking. However, in the context of probabilistic tracking, the next tracking directions cannot be estimated in advance, because they are picked randomly from a distribution. It is - therefore recommanded to keep the rk_order to 1 for probabilistic + therefore recommended to keep the rk_order to 1 for probabilistic tracking. 2. As a rule of thumb, doubling the rk_order will double the computation time in the worst case. @@ -58,39 +58,35 @@ """ import argparse +import json import logging import time -import json import dipy.core.geometry as gm +from dipy.io.stateful_tractogram import Origin, Space, StatefulTractogram +from dipy.io.streamline import save_tractogram import nibabel as nib +from nibabel.streamlines import TrkFile, detect_format import numpy as np -from dipy.io.stateful_tractogram import StatefulTractogram, Space -from dipy.io.stateful_tractogram import Origin -from dipy.io.streamline import save_tractogram -from nibabel.streamlines import detect_format, TrkFile - -from scilpy.io.image import assert_same_resolution -from scilpy.io.utils import (add_processes_arg, add_sphere_arg, - add_verbose_arg, - assert_inputs_exist, assert_outputs_exist, - parse_sh_basis_arg, verify_compression_th, - load_matrix_in_any_format) +from scilpy.image.labels import get_data_as_labels from scilpy.image.volume_space_management import DataVolume +from scilpy.io.image import assert_same_resolution, get_data_as_mask +from scilpy.io.utils import (add_processes_arg, add_sphere_arg, + add_verbose_arg, assert_inputs_exist, + assert_outputs_exist, load_matrix_in_any_format, + parse_sh_basis_arg, verify_compression_th) from scilpy.tracking.propagator import ODFPropagator from scilpy.tracking.rap import RAPContinue, RAPSwitch -from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser +from scilpy.tracking.seed import CustomSeedsDispenser, SeedGenerator from scilpy.tracking.tracker import Tracker from scilpy.tracking.utils import (add_mandatory_options_tracking, add_out_options, add_seeding_options, add_tracking_options, - get_theta, - verify_streamline_length_options, - verify_seed_options) + get_global_sf_threshold_mask, get_theta, + verify_seed_options, + verify_streamline_length_options) from scilpy.version import version_string -from scilpy.image.labels import get_data_as_labels -from scilpy.io.image import get_data_as_mask def _build_arg_parser(): @@ -167,22 +163,24 @@ def _build_arg_parser(): rap_g = p.add_argument_group('Region-Adaptive Propagation options') rap_mode = rap_g.add_mutually_exclusive_group() rap_mode.add_argument('--rap_mask', default=None, - help='Region-Adaptive Propagation mask (.nii.gz).\n' - 'Region-Adaptive Propagation tractography will start within ' - 'this mask.') + help='Region-Adaptive Propagation mask ' + '(.nii.gz).\nRegion-Adaptive Propagation ' + 'tractography will start within this mask.') rap_mode.add_argument('--rap_labels', default=None, - help='Region-Adaptive Propagation label volume (.nii.gz) .\n' - 'Voxel values are integer labels (0=background, 1..N=regions) .\n' - 'Used with --rap_method switch to select policies per label.') + help='Region-Adaptive Propagation label volume ' + '(.nii.gz) .\nVoxel values are integer labels ' + '(0=background, 1..N=regions) .\nUsed with ' + '--rap_method switch to select policies per ' + 'label.') rap_g.add_argument('--rap_method', default='None', choices=['None', 'continue', 'switch'], - help="Region-Adaptive Propagation tractography method.\n" - "'continue': continues tracking with same params,\n" - "'switch': switches tracking params inside RAP mask.\n" - " [%(default)s]") + help="Region-Adaptive Propagation tractography " + "method.\n'continue': continues tracking with " + "same params,\n'switch': switches tracking " + "params inside RAP mask.\n [%(default)s]") rap_g.add_argument('--rap_save_entry_exit', default=None, - help='Save RAP entry/exit coordinates as a binary mask.\n' - 'Provide output filename (.nii.gz).') + help='Save RAP entry/exit coordinates as a binary ' + 'mask.\nProvide output filename (.nii.gz).') m_g = p.add_argument_group('Memory options') add_processes_arg(m_g) @@ -206,7 +204,8 @@ def main(): if args.rap_params: with open(args.rap_params, 'r') as f: rap_params = json.load(f) - filenames = [cfg['filename'] for cfg in rap_params.get('methods', {}).values() + filenames = [cfg['filename'] for cfg + in rap_params.get('methods', {}).values() if 'filename' in cfg] assert_inputs_exist(parser, filenames) @@ -218,7 +217,8 @@ def main(): verify_compression_th(args.compress_th) verify_seed_options(parser, args) - if (args.rap_mask is not None or args.rap_labels is not None) and args.rap_method == "None": + if (args.rap_mask is not None or args.rap_labels is not None) \ + and args.rap_method == "None": parser.error('No RAP method selected.') if args.rap_method == 'continue' and args.rap_mask is None: parser.error('RAP method "continue" requires --rap_mask.') @@ -231,6 +231,12 @@ def main(): 'RAP method "switch" requires --rap_params to be specified.') if args.rap_params is not None and args.rap_method != 'switch': parser.error('--rap_params can only be used with --rap_method switch.') + + if (args.global_sf_rel_thr is not None or + args.global_sf_abs_thr is not None) and not args.in_odf: + parser.error('Global SF thresholding requires a global ODF ' + '(--in_odf).') + tracts_format = detect_format(args.out_tractogram) if tracts_format is not TrkFile: logging.warning("You have selected option --save_seeds but you are " @@ -290,8 +296,6 @@ def main(): logging.info("Loading tracking mask.") mask_img = nib.load(args.in_mask) mask_data = mask_img.get_fdata(caching='unchanged', dtype=float) - mask_res = mask_img.header.get_zooms()[:3] - mask = DataVolume(mask_data, mask_res, args.mask_interp) # ------- INSTANTIATING PROPAGATOR ------- if args.in_odf: @@ -301,6 +305,15 @@ def main(): odf_sh_res = odf_sh_img.header.get_zooms()[:3] dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) + sh_basis, is_legacy = parse_sh_basis_arg(args) + + if args.global_sf_rel_thr is not None or \ + args.global_sf_abs_thr is not None: + sf_mask = get_global_sf_threshold_mask( + odf_sh_data, args, sh_basis, is_legacy) + + mask_data = np.logical_and(mask_data, sf_mask) + logging.info("Instantiating propagator.") # Converting step size to vox space # We only support iso vox for now but allow slightly different vox @@ -312,7 +325,6 @@ def main(): # Using space and origin in the propagator: vox and center, like # in dipy. - sh_basis, is_legacy = parse_sh_basis_arg(args) propagator = ODFPropagator( dataset, vox_step_size, args.rk_order, args.algo, sh_basis, @@ -334,7 +346,8 @@ def main(): odf_sh_img = nib.load(filename) odf_sh_res = odf_sh_img.header.get_zooms()[:3] voxel_size = odf_sh_img.header.get_zooms()[0] - vox_step_size = cfg.get('step_size', args.step_size) / voxel_size + vox_step_size = cfg.get('step_size', + args.step_size) / voxel_size loaded_datasets[filename] = DataVolume( odf_sh_img.get_fdata(caching='unchanged', dtype=float), odf_sh_res, args.sh_interp) @@ -344,7 +357,8 @@ def main(): sh_basis = ('descoteaux07' if 'descoteaux07' in sh_basis_name else 'tournier07') algo = cfg.get('algo', args.algo) - theta = gm.math.radians(get_theta(cfg.get('theta', args.theta), algo)) + theta = gm.math.radians(get_theta( + cfg.get('theta', args.theta), algo)) is_legacy = 'legacy' in sh_basis_name # Build propagator from rap_policies file @@ -368,6 +382,9 @@ def main(): if propagator is None and propagators: propagator = next(iter(propagators.values())) + mask_res = mask_img.header.get_zooms()[:3] + mask = DataVolume(mask_data, mask_res, args.mask_interp) + # ------- INSTANTIATING RAP OBJECT ------- if args.rap_mask: logging.info("Loading RAP mask.") diff --git a/src/scilpy/cli/scil_tracking_pft.py b/src/scilpy/cli/scil_tracking_pft.py index 1161c2ca8..1dedeb388 100755 --- a/src/scilpy/cli/scil_tracking_pft.py +++ b/src/scilpy/cli/scil_tracking_pft.py @@ -3,7 +3,7 @@ """ Local streamline HARDI tractography including Particle Filtering tracking. -WARNING: This script DOES NOT support asymetric FODF input (aFODF). +WARNING: This script DOES NOT support asymmetric FODF input (aFODF). The tracking is done inside partial volume estimation maps and uses the particle filtering tractography (PFT) algorithm. See @@ -20,14 +20,14 @@ from the SF. For streamline compression, a rule of thumb is to set it to 0.1mm for the -deterministic algorithm and 0.2mm for probabilitic algorithm. +deterministic algorithm and 0.2mm for probabilistic algorithm. All the input nifti files must be in isotropic resolution. ----------------------------------------------------------------------------- Reference: [1] Girard, G., Whittingstall K., Deriche, R., and Descoteaux, M. (2014). - Towards quantitative connectivity analysis: reducing tractographybiases. + Towards quantitative connectivity analysis: reducing tractography biases. Neuroimage. ----------------------------------------------------------------------------- """ @@ -36,25 +36,24 @@ import logging from dipy.data import get_sphere, HemiSphere -from dipy.direction import (ProbabilisticDirectionGetter, - DeterministicMaximumDirectionGetter) -from dipy.io.utils import (get_reference_info, - create_tractogram_header) +from dipy.direction import (DeterministicMaximumDirectionGetter, + ProbabilisticDirectionGetter) +from dipy.io.utils import (create_tractogram_header, get_reference_info) +from dipy.tracking import utils as track_utils from dipy.tracking.local_tracking import ParticleFilteringTracking from dipy.tracking.stopping_criterion import (ActStoppingCriterion, CmcStoppingCriterion) -from dipy.tracking import utils as track_utils -from dipy.tracking.streamlinespeed import length, compress_streamlines +from dipy.tracking.streamlinespeed import compress_streamlines, length import nibabel as nib from nibabel.streamlines import LazyTractogram import numpy as np from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, - add_verbose_arg, assert_inputs_exist, - assert_outputs_exist, parse_sh_basis_arg, - assert_headers_compatible, add_compression_arg, - verify_compression_th) +from scilpy.io.utils import (add_compression_arg, add_overwrite_arg, + add_sh_basis_args, add_sphere_arg, + add_verbose_arg, assert_headers_compatible, + assert_inputs_exist, assert_outputs_exist, + parse_sh_basis_arg, verify_compression_th) from scilpy.tracking.utils import get_theta from scilpy.version import version_string @@ -102,13 +101,27 @@ def _build_arg_parser(): 'criterion (CMC).') track_g.add_argument('--sfthres', dest='sf_threshold', type=float, default=0.1, - help='Spherical function relative threshold. ' - '[%(default)s]') + help='Spherical function relative threshold ' + 'within each voxel. [%(default)s]') + global_sf_g = track_g.add_mutually_exclusive_group() + global_sf_g.add_argument('--global_sf_rel_thr', metavar='FACTOR', + type=float, nargs='?', const=0.1, default=None, + help='Global SF relative threshold factor. ' + 'If set, masks voxels where\nmaximum SF ' + 'amplitude < FACTOR * global maximum SF ' + 'amplitude. \nIf used without a value, ' + 'default is [%(const)s].') + global_sf_g.add_argument('--global_sf_abs_thr', metavar='ABS_THR', + type=float, + help='Global SF absolute threshold. ' + 'If set, masks voxels where \n' + 'maximum SF amplitude < ABS_THR.') track_g.add_argument('--sfthres_init', dest='sf_threshold_init', type=float, default=0.5, help='Spherical function relative threshold value ' 'for the \ninitial direction. [%(default)s]') add_sh_basis_args(track_g) + add_sphere_arg(track_g, symmetric_only=False) seed_group = p.add_argument_group( 'Seeding options', @@ -193,14 +206,23 @@ def main(): parser.error( 'SH file is not isotropic. Tracking cannot be ran robustly.') - tracking_sphere = HemiSphere.from_sphere(get_sphere(name='repulsion724')) + fodf_sh_data = fodf_sh_img.get_fdata(dtype=np.float32) + + sh_basis, is_legacy = parse_sh_basis_arg(args) + + sf_mask = None + if args.global_sf_rel_thr is not None or \ + args.global_sf_abs_thr is not None: + from scilpy.tracking.utils import get_global_sf_threshold_mask + sf_mask = get_global_sf_threshold_mask( + fodf_sh_data, args, sh_basis, is_legacy) + + tracking_sphere = HemiSphere.from_sphere(get_sphere(name=args.sphere)) # Check if sphere is unit, since we couldn't find such check in Dipy. if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.): raise RuntimeError('Tracking sphere should be unit normed.') - sh_basis, is_legacy = parse_sh_basis_arg(args) - if args.algo == 'det': dgklass = DeterministicMaximumDirectionGetter else: @@ -213,7 +235,7 @@ def main(): # relative_peak_threshold is for initial directions filtering # min_separation_angle is the initial separation angle for peak extraction dg = dgklass.from_shcoeff( - fodf_sh_img.get_fdata(dtype=np.float32), + fodf_sh_data, max_angle=theta, sphere=tracking_sphere, basis_type=sh_basis, @@ -223,18 +245,26 @@ def main(): map_include_img = nib.load(args.in_map_include) map_exclude_img = nib.load(args.map_exclude_file) + + map_include_data = map_include_img.get_fdata(dtype=np.float32) + map_exclude_data = map_exclude_img.get_fdata(dtype=np.float32) + + if sf_mask is not None: + map_include_data[~sf_mask] = 0 + map_exclude_data[~sf_mask] = 1 + voxel_size = np.average(map_include_img.header['pixdim'][1:4]) if not args.act: tissue_classifier = CmcStoppingCriterion( - map_include_img.get_fdata(dtype=np.float32), - map_exclude_img.get_fdata(dtype=np.float32), + map_include_data, + map_exclude_data, step_size=args.step_size, average_voxel_size=voxel_size) else: tissue_classifier = ActStoppingCriterion( - map_include_img.get_fdata(dtype=np.float32), - map_exclude_img.get_fdata(dtype=np.float32)) + map_include_data, + map_exclude_data) if args.npv: nb_seeds = args.npv diff --git a/src/scilpy/cli/tests/test_fodf_threshold_by_amplitude.py b/src/scilpy/cli/tests/test_fodf_threshold_by_amplitude.py new file mode 100644 index 000000000..9090d47b4 --- /dev/null +++ b/src/scilpy/cli/tests/test_fodf_threshold_by_amplitude.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +import os +import tempfile + +import nibabel as nib +import numpy as np + +from scilpy.tests.arrays import fodf_3x3_order8_descoteaux07 + + +def test_help_option(script_runner): + ret = script_runner.run(['scil_fodf_threshold_by_amplitude', '--help']) + assert ret.success + + +def test_execution(script_runner): + with tempfile.TemporaryDirectory() as tmp_dir: + in_sh = os.path.join(tmp_dir, 'sh.nii.gz') + out_mask = os.path.join(tmp_dir, 'mask.nii.gz') + + # Create fake SH file + affine = np.eye(4) + img = nib.Nifti1Image(fodf_3x3_order8_descoteaux07.astype(np.float32), + affine) + nib.save(img, in_sh) + + # Run with relative threshold + ret = script_runner.run(['scil_fodf_threshold_by_amplitude', + in_sh, out_mask, '--relative', '0.5', + '--sh_basis', 'descoteaux07']) + assert ret.success + assert os.path.exists(out_mask) + + # Run with absolute threshold + ret = script_runner.run(['scil_fodf_threshold_by_amplitude', + in_sh, out_mask, '--absolute', '0.1', + '--sh_basis', 'descoteaux07', '-f']) + assert ret.success + assert os.path.exists(out_mask) + + # Run with both thresholds + ret = script_runner.run(['scil_fodf_threshold_by_amplitude', + in_sh, out_mask, '--relative', '0.5', + '--absolute', '0.1', + '--sh_basis', 'descoteaux07', '-f']) + + assert ret.success + assert os.path.exists(out_mask) diff --git a/src/scilpy/cli/tests/test_tracking_local.py b/src/scilpy/cli/tests/test_tracking_local.py index 2a06a9294..435113d0f 100644 --- a/src/scilpy/cli/tests/test_tracking_local.py +++ b/src/scilpy/cli/tests/test_tracking_local.py @@ -202,3 +202,23 @@ def test_execution_tracking_fodf_custom_seeds(script_runner, monkeypatch): '--compress', '0.1', '--sh_basis', 'descoteaux07', '--min_length', '20', '--max_length', '200']) assert ret.success + + +def test_execution_tracking_global_sf_threshold(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_fodf = os.path.join(SCILPY_HOME, 'tracking', 'fodf.nii.gz') + in_mask = os.path.join(SCILPY_HOME, 'tracking', 'seeding_mask.nii.gz') + + # Test with relative threshold + ret = script_runner.run(['scil_tracking_local', in_fodf, + in_mask, in_mask, 'global_sf_rel.trk', + '--nt', '10', '--global_sf_rel_thr', '0.1', + '--sh_basis', 'descoteaux07']) + assert ret.success + + # Test with absolute threshold + ret = script_runner.run(['scil_tracking_local', in_fodf, + in_mask, in_mask, 'global_sf_abs.trk', + '--nt', '10', '--global_sf_abs_thr', '0.05', + '--sh_basis', 'descoteaux07', '-f']) + assert ret.success diff --git a/src/scilpy/cli/tests/test_tracking_local_dev.py b/src/scilpy/cli/tests/test_tracking_local_dev.py index f3ffeca45..a3b9216c2 100644 --- a/src/scilpy/cli/tests/test_tracking_local_dev.py +++ b/src/scilpy/cli/tests/test_tracking_local_dev.py @@ -80,3 +80,16 @@ def test_execution_tracking_fodf_custom_seeds(script_runner, monkeypatch): '--sub_sphere', '2', '--rk_order', '4']) assert ret.success + + +def test_execution_tracking_global_sf_threshold(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_fodf = os.path.join(SCILPY_HOME, 'tracking', 'fodf.nii.gz') + in_mask = os.path.join(SCILPY_HOME, 'tracking', 'seeding_mask.nii.gz') + + # Test with relative threshold + ret = script_runner.run(['scil_tracking_local_dev', in_mask, in_mask, + 'local_prob_sf.trk', '--in_odf', in_fodf, + '--nt', '10', '--global_sf_rel_thr', '0.1', + '--sh_basis', 'descoteaux07']) + assert ret.success diff --git a/src/scilpy/reconst/tests/test_sf_threshold.py b/src/scilpy/reconst/tests/test_sf_threshold.py new file mode 100644 index 000000000..380b6f4f9 --- /dev/null +++ b/src/scilpy/reconst/tests/test_sf_threshold.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +import numpy as np +import pytest + +from scilpy.tracking.utils import (compute_max_sf_amplitude, + compute_sf_threshold_mask) + +from scilpy.tests.arrays import fodf_3x3_order8_descoteaux07 + + +def test_compute_max_sf_amplitude(): + # Test with SH data + max_sf = compute_max_sf_amplitude(fodf_3x3_order8_descoteaux07, + sh_basis='descoteaux07', + is_legacy=True) + assert max_sf.shape == (3, 3, 1) + assert np.all(max_sf >= 0) + + # Test with mask + mask = np.zeros((3, 3, 1), dtype=bool) + mask[1, 1, 0] = True + max_sf_masked = compute_max_sf_amplitude(fodf_3x3_order8_descoteaux07, + sh_basis='descoteaux07', + is_legacy=True, + mask=mask) + assert np.count_nonzero(max_sf_masked) == 1 + assert max_sf_masked[1, 1, 0] == max_sf[1, 1, 0] + + +def test_compute_sf_threshold_mask_sh(): + # Test relative threshold + mask, global_max, threshold = compute_sf_threshold_mask( + fodf_3x3_order8_descoteaux07, relative_factor=0.5, + sh_basis='descoteaux07', is_legacy=True, postprocess_mask=False) + + assert mask.shape == (3, 3, 1) + assert global_max == np.max(compute_max_sf_amplitude( + fodf_3x3_order8_descoteaux07, sh_basis='descoteaux07', is_legacy=True)) + assert threshold == 0.5 * global_max + assert np.all(mask == (compute_max_sf_amplitude( + fodf_3x3_order8_descoteaux07, sh_basis='descoteaux07', + is_legacy=True) >= threshold)) + + # Test absolute threshold + mask, global_max, threshold = compute_sf_threshold_mask( + fodf_3x3_order8_descoteaux07, absolute_threshold=0.1, + sh_basis='descoteaux07', is_legacy=True, postprocess_mask=False) + assert threshold == 0.1 + + +def test_compute_sf_threshold_mask_peaks(): + # Test with 4D peaks (3*npeaks) + peaks_4d = np.zeros((3, 3, 1, 6)) # 2 peaks + peaks_4d[1, 1, 0, :3] = [1, 0, 0] # norm 1 + peaks_4d[2, 2, 0, 3:] = [0, 0, 2] # norm 2 + + mask, global_max, threshold = compute_sf_threshold_mask( + peaks_4d, relative_factor=0.6, postprocess_mask=False) + + assert global_max == 2.0 + assert threshold == 1.2 + assert np.count_nonzero(mask) == 1 + assert mask[2, 2, 0] + + # Test with 5D peaks (npeaks, 3) + peaks_5d = np.zeros((3, 3, 1, 2, 3)) + peaks_5d[1, 1, 0, 0, :] = [1, 0, 0] + peaks_5d[2, 2, 0, 1, :] = [0, 0, 2] + + mask, global_max, threshold = compute_sf_threshold_mask( + peaks_5d, relative_factor=0.6, postprocess_mask=False) + assert global_max == 2.0 + assert np.count_nonzero(mask) == 1 + + +def test_compute_sf_threshold_mask_edge_cases(): + # Test relative_factor validation + with pytest.raises(ValueError): + compute_sf_threshold_mask(fodf_3x3_order8_descoteaux07, + relative_factor=1.5) + + # Test zero data + zero_data = np.zeros((3, 3, 1, 45)) + mask, global_max, threshold = compute_sf_threshold_mask( + zero_data, relative_factor=0.5, sh_basis='descoteaux07', + is_legacy=True) + assert global_max == 0 + assert not np.any(mask) + + # Test no params + with pytest.raises(ValueError): + compute_sf_threshold_mask(fodf_3x3_order8_descoteaux07) + + +def test_compute_sf_threshold_mask_postprocess(): + # Create a mask with two components + data = np.zeros((10, 10, 10, 6)) # 4D peaks + data[2:5, 2:5, 2:5, :3] = [1, 0, 0] # Large component + data[8, 8, 8, :3] = [1, 0, 0] # Small component + + mask, _, _ = compute_sf_threshold_mask( + data, relative_factor=0.5, postprocess_mask=True) + + # Only large component should remain + assert np.count_nonzero(mask) == 27 + assert not mask[8, 8, 8] + assert mask[3, 3, 3] diff --git a/src/scilpy/reconst/utils.py b/src/scilpy/reconst/utils.py index 9e5166a77..8b0af4946 100644 --- a/src/scilpy/reconst/utils.py +++ b/src/scilpy/reconst/utils.py @@ -1,7 +1,9 @@ # -*- coding: utf-8 -*- -from dipy.direction.peaks import peak_directions +import logging + import numpy as np +from dipy.direction.peaks import peak_directions def find_order_from_nb_coeff(data): @@ -106,6 +108,8 @@ def is_data_peaks(img_data): if last_dim % 3 == 0: norm = np.linalg.norm(data_nz.reshape(-1, 3), axis=-1) if np.all(np.isclose(norm, norm[0])): + logging.warning("All peaks have the same norm. They might be " + "already normalized.") return True # If the max is in the first triplet but not at index 0, it's likely Peaks. diff --git a/src/scilpy/tracking/utils.py b/src/scilpy/tracking/utils.py index e9c4da576..23a7a622b 100644 --- a/src/scilpy/tracking/utils.py +++ b/src/scilpy/tracking/utils.py @@ -2,13 +2,7 @@ import logging from typing import Iterable -import nibabel as nib -import numpy as np -from nibabel.streamlines import TrkFile -from nibabel.streamlines.tractogram import LazyTractogram, TractogramItem -from tqdm import tqdm - -from dipy.core.sphere import HemiSphere +from dipy.core.sphere import HemiSphere, Sphere from dipy.data import get_sphere from dipy.direction import (DeterministicMaximumDirectionGetter, ProbabilisticDirectionGetter, PTTDirectionGetter) @@ -16,9 +10,17 @@ from dipy.io.utils import create_tractogram_header, get_reference_info from dipy.reconst.shm import sh_to_sf_matrix from dipy.tracking.streamlinespeed import compress_streamlines, length +import nibabel as nib +from nibabel.streamlines import TrkFile +from nibabel.streamlines.tractogram import LazyTractogram, TractogramItem +import numpy as np +import scipy.ndimage as ndi +from tqdm import tqdm + from scilpy.io.utils import (add_compression_arg, add_overwrite_arg, add_sh_basis_args) -from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas +from scilpy.reconst.utils import (find_order_from_nb_coeff, get_maximas, + is_data_peaks) class TrackingDirection(list): @@ -45,26 +47,26 @@ def add_mandatory_options_tracking(p, fodf_optional=False): 'file (.nii.gz). Ex: ODF or fODF. \n' 'If not provided, fODF info must be \n' 'specified in rap_policies.json.') - odf_group.add_argument('--rap_params', default=None, - help='JSON file containing RAP parameters, \n' - 'mutually exclusive with --in_odf.\n' - 'Required for --rap_method switch.\n' - 'Expected format:\n' - '{\n' - ' "methods": {\n' - ' "1": {"propagator": "ODF", "filename": str,\n' - ' "sh_basis": str, "algo": str,\n' - ' "theta": float, "step_size": float},\n' - ' "2": {"propagator": "ODF", "filename": str,\n' - ' "sh_basis": str, "algo": str,\n' - ' "theta": float, "step_size": float}\n' - ' }\n' - '}') + odf_group.add_argument( + '--rap_params', default=None, + help='JSON file containing RAP parameters, mutually exclusive ' + 'with --in_odf.\nRequired for --rap_method switch.\n' + 'Expected format:\n' + '{\n' + ' "methods": {\n' + ' "1": {"propagator": "ODF", "filename": str,\n' + ' "sh_basis": str, "algo": str,\n' + ' "theta": float, "step_size": float},\n' + ' "2": {"propagator": "ODF", "filename": str,\n' + ' "sh_basis": str, "algo": str,\n' + ' "theta": float, "step_size": float}\n' + ' }\n' + '}') else: p.add_argument('in_odf', - help='File containing the orientation diffusion function \n' - 'as spherical harmonics file (.nii.gz). \n' - 'Ex: ODF or fODF.') + help='File containing the orientation diffusion ' + 'function \nas spherical harmonics file ' + '(.nii.gz). \nEx: ODF or fODF.') p.add_argument('in_seed', help='Seeding mask (.nii.gz).') p.add_argument('in_mask', @@ -99,8 +101,21 @@ def add_tracking_options(p): '["eudx"=60, "det"=45, "prob"=20, "ptt"=20]') track_g.add_argument('--sfthres', dest='sf_threshold', metavar='sf_th', type=float, default=0.1, - help='Spherical function relative threshold. ' - '[%(default)s]') + help='Spherical function relative threshold ' + 'within each voxel. [%(default)s]') + global_sf_g = track_g.add_mutually_exclusive_group() + global_sf_g.add_argument('--global_sf_rel_thr', metavar='FACTOR', + type=float, nargs='?', const=0.1, default=None, + help='Global SF relative threshold factor. ' + 'If set, masks voxels where\nmaximum SF ' + 'amplitude < FACTOR * global maximum SF ' + 'amplitude. \nIf used without a value, ' + 'default is [%(const)s].') + global_sf_g.add_argument('--global_sf_abs_thr', metavar='ABS_THR', + type=float, + help='Global SF absolute threshold. ' + 'If set, masks voxels where \n' + 'maximum SF amplitude < ABS_THR.') add_sh_basis_args(track_g) return track_g @@ -165,7 +180,7 @@ def add_out_options(p): """ out_g = p.add_argument_group('Output options') msg = ("\nA rule of thumb is to set it to 0.1mm for deterministic \n" - "streamlines and to 0.2mm for probabilitic streamlines.") + "streamlines and to 0.2mm for probabilistic streamlines.") add_compression_arg(out_g, additional_msg=msg) add_overwrite_arg(out_g) @@ -289,7 +304,7 @@ def tracks_generator_wrapper(): nib.streamlines.save(tractogram, out_tractogram, header=header) -def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, +def get_direction_getter(img_data, algo, sphere, sub_sphere, theta, sh_basis, voxel_size, sf_threshold, sh_to_pmf, probe_length, probe_radius, probe_quality, probe_count, support_exponent, is_legacy=True): @@ -297,8 +312,8 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, Parameters ---------- - in_img: str - Path to the input odf file. + img_data: np.ndarray + ODF data (SH or Peaks). algo: str Algorithm to use for tracking. Can be 'det', 'prob', 'ptt' or 'eudx'. sphere: str @@ -314,7 +329,7 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, sf_threshold: float Spherical function-amplitude threshold for tracking. sh_to_pmf: bool - Map sherical harmonics to spherical function (pmf) before tracking + Map spherical harmonics to spherical function (pmf) before tracking (faster, requires more memory). probe_length : float The length of the probes. Shorter probe_length @@ -343,15 +358,11 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, dg: dipy.direction.DirectionGetter The direction getter object. """ - img_data = nib.load(in_img).get_fdata(dtype=np.float32) - sphere = HemiSphere.from_sphere( get_sphere(name=sphere)).subdivide(n=sub_sphere) # Theta depends on user choice and algorithm theta = get_theta(theta, algo) - - from scilpy.reconst.utils import is_data_peaks is_peaks = is_data_peaks(img_data) if algo in ['det', 'prob', 'ptt']: @@ -413,10 +424,10 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, peak_values = np.zeros((img_shape_3d + (npeaks,))) peak_indices = np.full((img_shape_3d + (npeaks,)), -1, dtype='int') - b_matrix, _ = sh_to_sf_matrix(sphere, - sh_order_max=find_order_from_nb_coeff( - img_data), - basis_type=sh_basis, legacy=is_legacy) + b_matrix, _ = sh_to_sf_matrix( + sphere, + sh_order_max=find_order_from_nb_coeff(img_data), + basis_type=sh_basis, legacy=is_legacy) for idx in np.argwhere(np.sum(img_data, axis=-1)): idx = tuple(idx) @@ -469,3 +480,221 @@ def sample_distribution(dist, random_generator: np.random.Generator): return None return cdf.searchsorted(random_generator.random() * cdf[-1]) + + +def compute_max_sf_amplitude(data, sh_basis, is_legacy, + sphere_name='repulsion100', mask=None): + """ + Compute the maximum SF amplitude for each voxel. + Only computes SF for voxels where data is non-zero (or in mask) to save + RAM. + + This information can be used to compute a global threshold for SF + amplitude, which is often used to filter out spurious peaks in fODF. + + Parameters + ---------- + data : np.ndarray + ODF data (SH). + sh_basis : str + SH basis ('tournier07' or 'descoteaux07'). + is_legacy : bool + Whether the SH basis is legacy. + sphere_name : str or dipy.core.sphere.Sphere, optional + Sphere name for SF conversion or Sphere object. + mask : np.ndarray, optional + Binary mask. If provided, only voxels in mask are computed. + + Returns + ------- + max_sf : np.ndarray + Maximum SF amplitude per voxel. + """ + if mask is None: + mask = np.any(data, axis=-1) + + order = find_order_from_nb_coeff(data) + if isinstance(sphere_name, (Sphere,)): + sphere = sphere_name + else: + sphere = get_sphere(name=sphere_name) + + b_matrix, _ = sh_to_sf_matrix(sphere, sh_order_max=order, + basis_type=sh_basis, legacy=is_legacy) + + max_sf = np.zeros(data.shape[:-1], dtype=np.float32) + if np.any(mask): + # Vectorized SF computation for masked voxels + sf = np.dot(data[mask], b_matrix) + max_sf[mask] = np.max(sf, axis=-1) + + return max_sf + + +def compute_sf_threshold_mask(data, sphere_name='repulsion100', + relative_factor=None, + absolute_threshold=None, + sh_basis='descoteaux07', + is_legacy=True, postprocess_mask=True, + size_percentage=0.05): + """ + Compute a binary mask based on a global SF amplitude threshold. + + In SF obtained from fODF, the amplitude of the lobes corresponds to the + strength of the diffusion signal in those directions. Thresholding these + amplitudes is a common practice to filter out spurious peaks. + + Parameters + ---------- + data : np.ndarray + ODF data (SH or Peaks). + sphere_name : str or dipy.core.sphere.Sphere, optional + Sphere name for SF conversion or Sphere object. + relative_factor : float, optional + Factor between 0 and 1. Threshold is factor * global_max_sf. + absolute_threshold : float, optional + Absolute threshold on SF amplitude. + sh_basis : str, optional + SH basis ('tournier07' or 'descoteaux07'). + is_legacy : bool, optional + Whether the SH basis is legacy. + postprocess_mask : bool, optional + Whether to postprocess the mask to keep only the largest component. + size_percentage : float, optional + If postprocess_mask is True, percentage of the largest component size + under which a hole will be filled. + + Returns + ------- + mask : np.ndarray + Binary mask. + global_max : float + Global maximum SF amplitude. + threshold : float + Computed threshold value. + """ + if relative_factor is None and absolute_threshold is None: + raise ValueError("Either relative_factor or absolute_threshold " + "must be provided.") + + is_peaks = is_data_peaks(data) + if is_peaks: + if data.ndim == 5: + if data.shape[-1] != 3: + raise ValueError("5D peaks input must have 3 " + "as last dimension.") + peaks = data + elif data.ndim == 4: + npeaks = data.shape[-1] // 3 + peaks = data.reshape(data.shape[:3] + (npeaks, 3)) + else: + raise ValueError("Peaks input must be 4D or 5D.") + + norms = np.linalg.norm(peaks, axis=-1) + # maximum amplitude/norm across peaks + max_amp = np.max(norms, axis=-1) + + # Check for normalized peaks + nonzero_norms = norms[norms > 0] + if len(nonzero_norms) > 0 and \ + np.all(np.isclose(nonzero_norms, nonzero_norms[0])): + logging.warning("All peaks have the same norm. They might be " + "already normalized.") + else: + max_amp = compute_max_sf_amplitude(data, sh_basis, is_legacy, + sphere_name=sphere_name) + + global_max = np.max(max_amp) + + # Compute threshold. Use max if both are provided. + threshold = 0 + if absolute_threshold is not None: + threshold = absolute_threshold + if relative_factor is not None: + if relative_factor < 0 or relative_factor > 1: + raise ValueError("relative_factor must be between 0 and 1.") + threshold = max(threshold, relative_factor * global_max) + + if global_max == 0: + mask = np.zeros(max_amp.shape, dtype=bool) + else: + mask = max_amp >= threshold + + if postprocess_mask and np.any(mask): + # Postprocess to label all elements and count voxels for each label + labels = ndi.label(mask)[0] + label_counts = np.bincount(labels.ravel()) + + # Guard against empty label_counts[1:] + if len(label_counts) > 1: + # Find the largest connected component (excluding background) + # +1 to skip background + largest_label = np.argmax(label_counts[1:]) + 1 + largest_component_size = label_counts[largest_label] + + # Create a mask for the largest connected component + mask = labels == largest_label + inverted_mask = ~mask + + # Remove isolated voxels in the inverted mask (holes in main mask) + labels_inverted = ndi.label(inverted_mask)[0] + label_counts_inverted = np.bincount(labels_inverted.ravel()) + + # Fill holes smaller than X% of the largest component size + hole_threshold = size_percentage * largest_component_size + for label, count in enumerate(label_counts_inverted): + if label == 0: + continue # Skip background + if count < hole_threshold: + mask[labels_inverted == label] = True + + return mask, global_max, threshold + + +def get_global_sf_threshold_mask(data, args, sh_basis, is_legacy): + """ + Wrapper for compute_sf_threshold_mask to compute the global SF + threshold mask and log information. + + The global SF threshold can be set as a relative factor of the global + maximum SF amplitude, or as an absolute threshold. The relative factor is + often set between 0.1 and 0.2, but it can depend on the data and the + SH basis used. The absolute threshold can be estimated from the + mean/median maximum fODF in the ventricles, computed with + scil_fodf_max_in_ventricles. + + Note that this estimation is not perfect as it depends on the accuracy of + the ventricle mask and on the presence of noise/artifacts in the data. + + Parameters + ---------- + data : np.ndarray + ODF data (SH or Peaks). + args : argparse.Namespace + Arguments from the CLI. Must contain sphere, global_sf_rel_thr, + and global_sf_abs_thr. + sh_basis : str + SH basis. + is_legacy : bool + Whether the SH basis is legacy. + + Returns + ------- + sf_mask : np.ndarray + Binary mask. + """ + sf_mask, global_max, threshold = compute_sf_threshold_mask( + data, sphere_name=args.sphere, + relative_factor=args.global_sf_rel_thr, + absolute_threshold=args.global_sf_abs_thr, sh_basis=sh_basis, + is_legacy=is_legacy) + logging.info("Global SF threshold mask: Global Max SF amplitude: " + "{:.4f}".format(global_max)) + if args.global_sf_rel_thr is not None: + logging.info("Global SF threshold mask: Computed threshold: " + "{:.4f} (Factor: {})" + .format(threshold, args.global_sf_rel_thr)) + else: + logging.info("Global SF threshold mask: Absolute threshold: " + "{:.4f}".format(args.global_sf_abs_thr)) + return sf_mask