Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ scil_sh_fusion = "scilpy.cli.scil_sh_fusion:main"
scil_sh_to_aodf = "scilpy.cli.scil_sh_to_aodf:main"
scil_sh_to_rish = "scilpy.cli.scil_sh_to_rish:main"
scil_sh_to_sf = "scilpy.cli.scil_sh_to_sf:main"
scil_fodf_threshold_by_amplitude = "scilpy.cli.scil_fodf_threshold_by_amplitude:main"
scil_stats_group_comparison = "scilpy.cli.scil_stats_group_comparison:main"
scil_surface_assign_custom_color = "scilpy.cli.scil_surface_assign_custom_color:main"
scil_surface_assign_uniform_color = "scilpy.cli.scil_surface_assign_uniform_color:main"
Expand Down
101 changes: 101 additions & 0 deletions src/scilpy/cli/scil_fodf_threshold_by_amplitude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

Comment thread
arnaudbore marked this conversation as resolved.
"""
Compute a binary mask based on a global SF threshold.
The script masks voxels where the maximum SF amplitude is below
either a relative factor or an absolute threshold (or both).

When fODFs are evaluated on a sphere (SF), the amplitude of the lobes
corresponds to the strength of the diffusion signal in those directions.
Thresholding these amplitudes is a common practice to filter out spurious
peaks arising from noise or the deconvolution process (e.g., ringing effects).

The absolute threshold can be estimated from the mean/median maximum fODF
in the ventricles, computed with scil_fodf_max_in_ventricles.

If both --relative and --absolute are provided, the final threshold is the
maximum of the two resulting values.
"""

import argparse
import logging

import nibabel as nib
import numpy as np

from scilpy.io.utils import (add_sh_basis_args, add_sphere_arg,
add_verbose_arg, add_overwrite_arg,
assert_inputs_exist, assert_outputs_exist,
parse_sh_basis_arg)
from scilpy.tracking.utils import compute_sf_threshold_mask
from scilpy.version import version_string


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
epilog=version_string)

p.add_argument('in_odf',
help='Input ODF file (SH or Peaks) (.nii.gz).')
p.add_argument('out_mask',
help='Output binary mask (.nii.gz).')

thr_g = p.add_argument_group('Threshold options')
thr_g.add_argument('--relative', type=float,
help='Global SF threshold relative factor (0-1).')
thr_g.add_argument('--absolute', type=float,
help='Global SF absolute threshold.')
add_sh_basis_args(p)
add_sphere_arg(p)
add_overwrite_arg(p)
add_verbose_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

Comment thread
frheault marked this conversation as resolved.
if args.relative is None and args.absolute is None:
parser.error("At least one of --relative or --absolute must be "
"provided.")

assert_inputs_exist(parser, args.in_odf)
assert_outputs_exist(parser, args, args.out_mask)

sh_basis, is_legacy = parse_sh_basis_arg(args)

logging.info("Loading ODF data.")
img = nib.load(args.in_odf)
data = img.get_fdata(dtype=np.float32)

logging.info("Computing global SF threshold mask.")
mask, global_max, threshold = compute_sf_threshold_mask(
data, sphere_name=args.sphere, relative_factor=args.relative,
absolute_threshold=args.absolute, sh_basis=sh_basis,
is_legacy=is_legacy)

logging.info("Global max SF amplitude: {:.4f}".format(global_max))
if args.relative is not None and args.absolute is not None:
logging.info("Both relative and absolute thresholds used. "
"Final threshold: {:.4f}".format(threshold))
elif args.relative is not None:
logging.info("Relative threshold: {:.4f} (Factor: {})"
.format(threshold, args.relative))
else:
logging.info("Absolute threshold used: {:.4f}".format(args.absolute))

logging.info("Number of voxels in mask: {}".format(np.sum(mask)))

# Save mask
mask_img = nib.Nifti1Image(mask.astype(np.uint8), img.affine,
img.header)
nib.save(mask_img, args.out_mask)


if __name__ == "__main__":
main()
51 changes: 34 additions & 17 deletions src/scilpy/cli/scil_tracking_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
The tracking direction is chosen in the aperture cone defined by the
previous tracking direction and the angular constraint.

WARNING: This script DOES NOT support asymetric FODF input (aFODF).
WARNING: This script DOES NOT support asymmetric FODF input (aFODF).

Algo 'eudx': select the peak from the spherical function (SF) most closely
aligned to the previous direction, and follow an average of it and the previous
Expand Down Expand Up @@ -41,9 +41,10 @@
* Forward tracking: For GPU tracking, the `--forward_only` flag can be used
to disable backward tracking. This option isn't available for CPU
tracking.
* Random number generator seed (RNG): CPU and GPU use different RNG implementations,<
so the same `--seed` is reproducible within a backend but does not guarantee
identical streamlines across CPU vs GPU tracking.
* Random number generator seed (RNG): CPU and GPU use different RNG
implementations, so the same `--seed` is reproducible within a
backend but does not guarantee identical streamlines across CPU vs
GPU tracking.

All the input nifti files must be in isotropic resolution.

Expand All @@ -61,26 +62,27 @@
import logging
from time import perf_counter

import nibabel as nib
import numpy as np
from nibabel.streamlines import TrkFile, detect_format

from dipy.data import get_sphere
from dipy.tracking import utils as track_utils
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.stopping_criterion import BinaryStoppingCriterion
from dipy.tracking.tracker import eudx_tracking
import nibabel as nib
from nibabel.streamlines import TrkFile, detect_format
import numpy as np

from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_sphere_arg, add_verbose_arg,
assert_headers_compatible, assert_inputs_exist,
assert_outputs_exist, parse_sh_basis_arg,
verify_compression_th, load_matrix_in_any_format)
assert_outputs_exist, load_matrix_in_any_format,
parse_sh_basis_arg, verify_compression_th)
from scilpy.tracking.tracker import GPUTracker
from scilpy.tracking.utils import (add_mandatory_options_tracking,
add_out_options, add_seeding_options,
add_tracking_options,
add_tracking_ptt_options,
get_direction_getter, get_theta,
get_direction_getter,
get_global_sf_threshold_mask, get_theta,
save_tractogram, verify_seed_options,
verify_streamline_length_options)
from scilpy.version import version_string
Expand All @@ -104,7 +106,7 @@ def _build_arg_parser():

# Other options, only available in this script:
track_g.add_argument('--sh_to_pmf', action='store_true',
help='If set, map sherical harmonics to spherical '
help='If set, map spherical harmonics to spherical '
'function (pmf) before \ntracking (faster, '
'requires more memory)')
track_g.add_argument('--algo', default='prob',
Expand Down Expand Up @@ -200,6 +202,17 @@ def main():
logging.debug("Loading masks and finding seeds.")
mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool)

# ODF data for thresholding
odf_sh_data = odf_sh_img.get_fdata(dtype=np.float32)

sh_basis, is_legacy = parse_sh_basis_arg(args)

sf_mask = None
if args.global_sf_rel_thr is not None or \
args.global_sf_abs_thr is not None:
sf_mask = get_global_sf_threshold_mask(
odf_sh_data, args, sh_basis, is_legacy)

if args.npv:
nb_seeds = args.npv
seed_per_vox = True
Expand Down Expand Up @@ -235,11 +248,16 @@ def main():
random_seed=args.seed)
total_nb_seeds = len(seeds)

combined_mask = mask_data
if sf_mask is not None:
combined_mask = np.logical_and(mask_data, sf_mask)

if not args.use_gpu:
# LocalTracking.maxlen is actually the maximum length
# per direction, we need to filter post-tracking.
max_steps_per_direction = int(args.max_length / args.step_size)
stopping_criterion = BinaryStoppingCriterion(mask_data)

stopping_criterion = BinaryStoppingCriterion(combined_mask)

logging.info("Starting CPU local tracking.")
if args.algo == 'eudx':
Expand All @@ -248,7 +266,7 @@ def main():
stopping_criterion,
np.eye(4),
pam=get_direction_getter(
args.in_odf, args.algo, args.sphere,
odf_sh_data, args.algo, args.sphere,
args.sub_sphere, args.theta, sh_basis,
voxel_size, args.sf_threshold, args.sh_to_pmf,
args.probe_length, args.probe_radius,
Expand All @@ -264,7 +282,7 @@ def main():
else:
streamlines_generator = LocalTracking(
get_direction_getter(
args.in_odf, args.algo, args.sphere,
odf_sh_data, args.algo, args.sphere,
args.sub_sphere, args.theta, sh_basis,
voxel_size, args.sf_threshold, args.sh_to_pmf,
args.probe_length, args.probe_radius,
Expand All @@ -284,14 +302,13 @@ def main():
max_strl_len = int(2.0 * args.max_length / args.step_size) + 1

# data volume
odf_sh = odf_sh_img.get_fdata(dtype=np.float32)

# GPU tracking needs the full sphere
sphere = get_sphere(name=args.sphere).subdivide(n=args.sub_sphere)

logging.info("Starting GPU local tracking.")
streamlines_generator = GPUTracker(
odf_sh, mask_data, seeds,
odf_sh_data, combined_mask, seeds,
vox_step_size, max_strl_len,
theta=get_theta(args.theta, args.algo),
sf_threshold=args.sf_threshold,
Expand Down
Loading
Loading