Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/scilpy/cli/scil_frf_ssst.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import nibabel as nib
import numpy as np

from scilpy.gradients.bvec_bval_tools import check_b0_threshold
from scilpy.gradients.bvec_bval_tools import (check_b0_threshold,
check_shells_frf)
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg,
add_precision_arg,
Expand Down Expand Up @@ -111,6 +112,8 @@ def main():
b0_thr=args.b0_threshold,
skip_b0_check=args.skip_b0_check)

check_shells_frf(bvals, args.b0_threshold)

Comment thread
arnaudbore marked this conversation as resolved.
mask = get_data_as_mask(nib.load(args.mask),
dtype=bool) if args.mask else None
mask_wm = get_data_as_mask(nib.load(args.mask_wm),
Expand Down
30 changes: 29 additions & 1 deletion src/scilpy/gradients/bvec_bval_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,34 @@ def identify_shells(bvals, tol=40.0, round_centroids=False, sort=False):
return centroids, shell_indices


def check_shells_frf(bvals, b0_threshold):
"""
Check if the shells are too far apart, which might cause problems for
FRF estimation.

Parameters
----------
bvals : np.ndarray
b-values.
b0_threshold : float
Threshold for b0.
"""
shells_centroids, _ = identify_shells(bvals, b0_threshold,
round_centroids=True)
shells_centroids = list(sorted(
shells_centroids[shells_centroids > b0_threshold]))
min_non_b0_shell = np.min(shells_centroids) \
if len(shells_centroids) > 0 else 0
max_non_b0_delta = np.ediff1d(shells_centroids)[0] \
if len(shells_centroids) > 1 else 0
if max_non_b0_delta >= min_non_b0_shell:
logging.warning(
'Your shells seem to be very far apart (max delta: {}, '
'min non-b0 shell: {}). This might cause problems for the '
'estimation of the FRF. Consider using scil_frf_msmt.py.'
.format(max_non_b0_delta, min_non_b0_shell))


def str_to_axis_index(axis):
"""
Convert x y z axis string to 0 1 2 axis index
Expand Down Expand Up @@ -257,7 +285,7 @@ def find_flip_swap_from_order(order):
elif next_axis in [-1, -2, -3]:
axes_to_flip.append(abs(next_axis) - 1)
swapped_order.append(abs(next_axis) - 1)
return(axes_to_flip, swapped_order)
return (axes_to_flip, swapped_order)


def flip_gradient_axis(bvecs, axes, sampling_type):
Expand Down
24 changes: 14 additions & 10 deletions src/scilpy/gradients/tests/test_bvec_bval_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from scilpy.gradients.bvec_bval_tools import (
check_b0_threshold, identify_shells, is_normalized_bvecs,
check_b0_threshold, check_shells_frf, identify_shells, is_normalized_bvecs,
flip_gradient_axis, find_flip_swap_from_order, normalize_bvecs,
round_bvals_to_shell, str_to_axis_index, swap_gradient_axis)

Expand Down Expand Up @@ -125,12 +125,16 @@ def test_round_bvals_to_shell():
success = False
assert not success

# 3. Verify that doesn't work with shell missing: no data on shell 1000.
bvals = np.asarray([0, 10])
shells = [0, 1000]
success = True
try:
_ = round_bvals_to_shell(bvals, shells, tol=tolerance)
except ValueError:
success = False
assert not success

def test_check_shells_frf():
# Test case where shells are close enough
bvals = np.asarray([0, 0, 1000, 1000, 2000, 2000])
check_shells_frf(bvals, b0_threshold=20)

# Test case where shells are too far apart
bvals = np.asarray([0, 0, 1000, 1000, 3000, 3000])
check_shells_frf(bvals, b0_threshold=20)

# Test case with no non-b0 shells
bvals = np.asarray([0, 0, 10, 10])
check_shells_frf(bvals, b0_threshold=20)
39 changes: 39 additions & 0 deletions src/scilpy/reconst/tests/test_is_data_peaks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
import os
import nibabel as nib
import numpy as np
from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.reconst.utils import is_data_peaks


def test_is_data_peaks_with_real_data():
fetch_data(get_testing_files_dict(), keys=['processing.zip'])

processing_dir = os.path.join(SCILPY_HOME, 'processing')

# 1. Test with SH data (fODF)
sh_path = os.path.join(processing_dir, 'fodf_descoteaux07.nii.gz')
sh_data = nib.load(sh_path).get_fdata()
assert is_data_peaks(sh_data) is False, "Should identify SH data as False"

# 2. Test with Peaks data
peaks_path = os.path.join(processing_dir, 'peaks.nii.gz')
peaks_data = nib.load(peaks_path).get_fdata()
assert is_data_peaks(
peaks_data) is True, "Should identify Peaks data as True"


def test_is_data_peaks_with_edge_cases():
# 3D data (e.g. 1 directions)
peaks_3d = np.random.rand(10, 10, 10, 3)
assert is_data_peaks(peaks_3d) is True

# SH data with order 4 (15 coefficients) but all zeros
sh_zeros = np.zeros((10, 10, 10, 15))
assert is_data_peaks(sh_zeros) is False

# Data that is clearly peaks (multiple of 3, many zeros)
peaks_many_zeros = np.zeros((10, 10, 10, 9))
peaks_many_zeros[5, 5, 5, :3] = [1, 0, 0]
assert is_data_peaks(peaks_many_zeros) is True
69 changes: 68 additions & 1 deletion src/scilpy/reconst/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def get_maximas(data, sphere, b_matrix, threshold, absolute_threshold,
spherical_func = np.dot(data, b_matrix.T)
spherical_func[np.nonzero(spherical_func < absolute_threshold)] = 0.
return peak_directions(
spherical_func, sphere, threshold, min_separation_angle)
spherical_func, sphere,
relative_peak_threshold=threshold,
min_separation_angle=min_separation_angle)


def get_sphere_neighbours(sphere, max_angle):
Expand All @@ -57,3 +59,68 @@ def get_sphere_neighbours(sphere, max_angle):
np.outer(zs, zs))
neighbours = scalar_prods >= np.cos(max_angle)
return neighbours


def is_data_peaks(img_data):
Comment thread
arnaudbore marked this conversation as resolved.
"""
Heuristic to find out if the input are peaks or fodf.
fodf are always around 0.15 and peaks around 0.75.
Peaks have more zero values than fodf. The first value of fodf is
usually the highest.

Parameters
----------
img_data : np.ndarray
4D image data where the last dimension contains directional info.

Returns
-------
is_peaks : bool
True if data is likely peaks, False if likely fODF (SH).
"""
last_dim = img_data.shape[-1]
if last_dim == 3:
return True

# Sum of absolute values to detect non-zero voxels correctly
non_zeros_mask = np.any(np.abs(img_data) > 0, axis=-1)
if not np.count_nonzero(non_zeros_mask):
return False

try:
order, full = get_sh_order_and_fullness(last_dim)
# Symmetric SH must be even order
if not full and order % 2 != 0:
return False
except ValueError:
# If not a valid SH number of coefficients, and not 3,
# it might be something else, but if it's a multiple of 3
# it's likely Peaks.
if last_dim % 3 == 0:
return True
return False

data_nz = img_data[non_zeros_mask]

# If all triplets have the same norm, it is likely peaks, otherwise SH.
if last_dim % 3 == 0:
norm = np.linalg.norm(data_nz.reshape(-1, 3), axis=-1)
if np.all(np.isclose(norm, norm[0])):
return True

# If the max is in the first triplet but not at index 0, it's likely Peaks.
# Smoothed SH almost always has max at index 0
argmax_indices = np.argmax(np.abs(data_nz), axis=-1)
if last_dim % 3 == 0 and \
np.mean(np.logical_or(argmax_indices == 1,
argmax_indices == 2)) > 0.1:
return True

# Exact zeros. SH almost never has exact zeros in real data.
# Peaks often have exact zeros for unused lobes
zero_ratio = np.mean(data_nz == 0)
if zero_ratio > 0.05:
return True

# Default to SH
return False
9 changes: 2 additions & 7 deletions src/scilpy/tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,8 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis,
# Theta depends on user choice and algorithm
theta = get_theta(theta, algo)

# Heuristic to find out if the input are peaks or fodf
# fodf are always around 0.15 and peaks around 0.75
# Peaks have more zero values than fodf. The first value of fodf is
# usually the highest.
non_zeros_count = np.count_nonzero(np.sum(img_data, axis=-1))
non_first_val_count = np.count_nonzero(np.argmax(img_data, axis=-1))
is_peaks = non_first_val_count / non_zeros_count > 0.5
from scilpy.reconst.utils import is_data_peaks
is_peaks = is_data_peaks(img_data)

if algo in ['det', 'prob', 'ptt']:
if is_peaks:
Expand Down
Loading