Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/scilpy/cli/scil_tracking_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@
* Forward tracking: For GPU tracking, the `--forward_only` flag can be used
to disable backward tracking. This option isn't available for CPU
tracking.
* Random number generator seed (RNG): CPU and GPU use different RNG implementations,<
so the same `--seed` is reproducible within a backend but does not guarantee
identical streamlines across CPU vs GPU tracking.
* Random number generator seed (RNG): CPU and GPU use different RNG
implementations, so the same `--seed` is reproducible within a
backend but does not guarantee identical streamlines
across CPU vs GPU tracking.

All the input nifti files must be in isotropic resolution.

Expand Down Expand Up @@ -71,12 +72,13 @@
from dipy.tracking.stopping_criterion import BinaryStoppingCriterion
from dipy.tracking.tracker import eudx_tracking
from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_sphere_arg, add_verbose_arg,
from scilpy.io.utils import (add_verbose_arg,
assert_headers_compatible, assert_inputs_exist,
assert_outputs_exist, parse_sh_basis_arg,
verify_compression_th, load_matrix_in_any_format)
from scilpy.tracking.tracker import GPUTracker
from scilpy.tracking.utils import (add_mandatory_options_tracking,
add_tracking_sh_options,
add_out_options, add_seeding_options,
add_tracking_options,
add_tracking_ptt_options,
Expand All @@ -100,6 +102,7 @@ def _build_arg_parser():
# Options that are the same in this script and scil_tracking_local_dev:
add_mandatory_options_tracking(p)
track_g = add_tracking_options(p)
add_tracking_sh_options(p)
add_seeding_options(p)

# Other options, only available in this script:
Expand All @@ -110,11 +113,7 @@ def _build_arg_parser():
track_g.add_argument('--algo', default='prob',
choices=['det', 'prob', 'ptt', 'eudx'],
help='Algorithm to use. [%(default)s]')
add_sphere_arg(track_g, symmetric_only=False)
track_g.add_argument('--sub_sphere',
type=int, default=0,
help='Subdivides each face of the sphere into 4^s new'
' faces. [%(default)s]')

add_tracking_ptt_options(p)
gpu_g = p.add_argument_group('GPU options')
gpu_g.add_argument('--use_gpu', action='store_true',
Expand All @@ -131,7 +130,6 @@ def _build_arg_parser():
' [{}]'.format(DEFAULT_BATCH_SIZE))

out_g = add_out_options(p)

out_g.add_argument('--seed', type=int,
help='Random number generator seed.')

Expand Down
212 changes: 139 additions & 73 deletions src/scilpy/cli/scil_tracking_local_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,22 @@
from nibabel.streamlines import detect_format, TrkFile

from scilpy.io.image import assert_same_resolution
from scilpy.io.utils import (add_processes_arg, add_sphere_arg,
from scilpy.io.utils import (add_processes_arg,
add_verbose_arg,
assert_inputs_exist, assert_outputs_exist,
parse_sh_basis_arg, verify_compression_th,
load_matrix_in_any_format)
from scilpy.io.tensor import convert_tensor_to_dipy_format
from scilpy.image.volume_space_management import DataVolume
from scilpy.tracking.propagator import ODFPropagator
from scilpy.tracking.propagator import ODFPropagator, TensorPropagator
from scilpy.tracking.rap import RAPContinue, RAPSwitch
from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser
from scilpy.tracking.tracker import Tracker
from scilpy.tracking.utils import (add_mandatory_options_tracking,
add_out_options, add_seeding_options,
add_tracking_options,
add_tracking_tensor_options,
add_tracking_sh_options,
get_theta,
verify_streamline_length_options,
verify_seed_options)
Expand All @@ -98,44 +101,39 @@ def _build_arg_parser():
formatter_class=argparse.RawTextHelpFormatter,
epilog=version_string)

# Options common to both scripts
add_mandatory_options_tracking(p, fodf_optional=True)
# Input data options
data_g = p.add_argument_group('Input data options')
data_group = data_g.add_mutually_exclusive_group(required=True)
data_group.add_argument('--in_odf',
help='Path to the ODF SH coefficient file (.nii.gz).\n'
'Use this for ODF-based tracking.')
data_group.add_argument('--in_tensor',
help='Path to the DTI tensor file (.nii.gz).\n'
'Use this for tensor-based tracking.')
data_group.add_argument('--rap_params',
help='Path to the JSON file containing RAP policies.\n'
'Use this for RAP with method "switch".\n'
'Expected format:\n'
'{\n'
' "methods": {\n'
' "1": {"propagator": "ODF", "filename": str,\n'
' "sh_basis": str, "algo": str,\n'
' "theta": float, "step_size": float},\n'
' "2": {"propagator": "Tensor", "filename": str,\n'
' "tensor_format": str, "tensor_interp": str,\n'
' "algo": str, "theta": float,\n'
' "step_size": float, "std": float}\n'
' }\n'
'}')

# Options common to ODF and tensor-based tracking
add_mandatory_options_tracking(p, fodf_mandatory=False)
track_g = add_tracking_options(p)
add_seeding_options(p)

# Options only for here.
# Options only for all models.
track_g.add_argument('--algo', default='prob', choices=['det', 'prob'],
help='Algorithm to use. [%(default)s]')
add_sphere_arg(track_g, symmetric_only=False)
track_g.add_argument('--sub_sphere',
type=int, default=0,
help='Subdivides each face of the sphere into 4^s new'
' faces. [%(default)s]')
track_g.add_argument('--sfthres_init', metavar='sf_th', type=float,
default=0.5, dest='sf_threshold_init',
help="Spherical function relative threshold value "
"for the \ninitial direction. [%(default)s]")
track_g.add_argument('--rk_order', metavar="K", type=int, default=1,
choices=[1, 2, 4],
help="The order of the Runge-Kutta integration used "
"for the step function.\n"
"For more information, refer to the note in the"
" script description. [%(default)s]")
track_g.add_argument('--max_invalid_nb_points', metavar='MAX', type=float,
default=0,
help="Maximum number of steps without valid "
"direction, \nex: if threshold on ODF or max "
"angles are reached.\n"
"Default: 0, i.e. do not add points following "
"an invalid direction.")
track_g.add_argument('--forward_only', action='store_true',
help="If set, tracks in one direction only (forward) "
"given the \ninitial seed. The direction is "
"randomly drawn from the ODF.")
track_g.add_argument('--sh_interp', default='trilinear',
choices=['nearest', 'trilinear'],
help="Spherical harmonic interpolation: "
"nearest-neighbor \nor trilinear. [%(default)s]")
track_g.add_argument('--mask_interp', default='nearest',
choices=['nearest', 'trilinear'],
help="Mask interpolation: nearest-neighbor or "
Expand All @@ -152,37 +150,56 @@ def _build_arg_parser():
help="By default, each seed position is used only once. This option\n"
"allows for tracking from the exact same seed n_repeats_per_seed"
"\ntimes. [%(default)s]")
track_g.add_argument('--max_invalid_nb_points', metavar='MAX', type=float,
default=0,
help="Maximum number of steps without valid "
"direction, \nex: if threshold on ODF or max "
"angles are reached.\n"
"Default: 0, i.e. do not add points following "
"an invalid direction.")
track_g.add_argument('--forward_only', action='store_true',
help="If set, tracks in one direction only (forward) "
"given the \ninitial seed. The direction is "
"randomly drawn from the ODF/Tensor.")

sh_options = add_tracking_sh_options(p)
sh_options.add_argument('--sh_interp', default='trilinear',
choices=['nearest', 'trilinear'],
help="Spherical harmonic interpolation: "
"nearest-neighbor \nor trilinear. [%(default)s]")
add_tracking_tensor_options(p)

r_g = p.add_argument_group('Random seeding options')
r_g.add_argument('--rng_seed', type=int, default=0,
help='Initial value for the random number generator. '
'[%(default)s]')
r_g.add_argument('--skip', type=int, default=0,
help="Skip the first N random number. \n"
"Useful if you want to create new streamlines to "
"add to \na previously created tractogram with a "
"fixed --rng_seed.\nEx: If tractogram_1 was created "
"with -nt 1,000,000, \nyou can create tractogram_2 "
"with \n--skip 1,000,000.")
rap_g = p.add_argument_group('Region-Adaptive Propagation options')
rap_mode = rap_g.add_mutually_exclusive_group()
rap_mode.add_argument('--rap_mask', default=None,
help='Region-Adaptive Propagation mask (.nii.gz).\n'
'Region-Adaptive Propagation tractography will start within '
'this mask.')
'Region-Adaptive Propagation tractography will start within '
'this mask.')
rap_mode.add_argument('--rap_labels', default=None,
help='Region-Adaptive Propagation label volume (.nii.gz) .\n'
'Voxel values are integer labels (0=background, 1..N=regions) .\n'
'Used with --rap_method switch to select policies per label.')
'Voxel values are integer labels (0=background, 1..N=regions) .\n'
'Used with --rap_method switch to select policies per label.')
rap_g.add_argument('--rap_method', default='None',
choices=['None', 'continue', 'switch'],
help="Region-Adaptive Propagation tractography method.\n"
"'continue': continues tracking with same params,\n"
"'switch': switches tracking params inside RAP mask.\n"
" [%(default)s]")
"'continue': continues tracking with same params,\n"
"'switch': switches tracking params inside RAP mask.\n"
" [%(default)s]")
rap_g.add_argument('--rap_save_entry_exit', default=None,
help='Save RAP entry/exit coordinates as a binary mask.\n'
'Provide output filename (.nii.gz).')
'Provide output filename (.nii.gz).')

r_g = p.add_argument_group('Random seeding options')
r_g.add_argument('--rng_seed', type=int, default=None,
help='Initial value for the random number generator. '
'[%(default)s]')
r_g.add_argument('--skip', type=int, default=0,
help="Skip the first N random number. \n"
"Useful if you want to create new streamlines to "
"add to \na previously created tractogram with a "
"fixed --rng_seed.\nEx: If tractogram_1 was created "
"with -nt 1,000,000, \nyou can create tractogram_2 "
"with \n--skip 1,000,000.")

m_g = p.add_argument_group('Memory options')
add_processes_arg(m_g)
Expand All @@ -203,15 +220,19 @@ def main():
parser.error('Invalid output streamline file format (must be trk or ' +
'tck): {0}'.format(args.out_tractogram))

if args.rap_params:
inputs = [args.in_seed, args.in_mask]
models = []
if args.in_odf:
models = [args.in_odf]
elif args.in_tensor:
models = [args.in_tensor]
elif args.rap_params:
with open(args.rap_params, 'r') as f:
rap_params = json.load(f)
filenames = [cfg['filename'] for cfg in rap_params.get('methods', {}).values()
if 'filename' in cfg]
assert_inputs_exist(parser, filenames)
models = [cfg['filename'] for cfg in rap_params.get('methods', {}).values()
if 'filename' in cfg]

inputs = [args.in_seed, args.in_mask]
assert_inputs_exist(parser, inputs, optional=args.in_odf)
assert_inputs_exist(parser, inputs + models)
assert_outputs_exist(parser, args, args.out_tractogram)

verify_streamline_length_options(parser, args)
Expand Down Expand Up @@ -244,8 +265,8 @@ def main():

max_nbr_pts = int(args.max_length / args.step_size)
min_nbr_pts = max(int(args.min_length / args.step_size), 1)
if args.in_odf:
assert_same_resolution([args.in_mask, args.in_odf, args.in_seed])

assert_same_resolution(inputs + models)

# Choosing our space and origin for this tracking
# If save_seeds, space and origin must be vox, center. Choosing those
Expand Down Expand Up @@ -301,26 +322,38 @@ def main():
odf_sh_res = odf_sh_img.header.get_zooms()[:3]
dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp)

logging.info("Instantiating propagator.")
# Converting step size to vox space
# We only support iso vox for now but allow slightly different vox
# 1e-3.
assert np.allclose(np.mean(odf_sh_res[:3]),
odf_sh_res, atol=1e-03)
logging.info("Instantiating ODF propagator.")
voxel_size = odf_sh_img.header.get_zooms()[0]
vox_step_size = args.step_size / voxel_size

# Using space and origin in the propagator: vox and center, like
# in dipy.
sh_basis, is_legacy = parse_sh_basis_arg(args)

propagator = ODFPropagator(
dataset, vox_step_size, args.rk_order, args.algo, sh_basis,
args.sf_threshold, args.sf_threshold_init, theta, args.sphere,
sub_sphere=args.sub_sphere,
space=our_space, origin=our_origin, is_legacy=is_legacy)
propagators = {args.in_odf: propagator}

elif args.in_tensor: # args.in_tensor
logging.info("Loading DTI tensor data.")
tensor_img = nib.load(args.in_tensor)
tensor_data = tensor_img.get_fdata(caching='unchanged', dtype=float)

# Convert tensor to dipy format if needed
if args.tensor_format != 'dipy':
logging.info(f"Converting tensor from {args.tensor_format} to dipy format.")
tensor_data = convert_tensor_to_dipy_format(tensor_data, args.tensor_format)

logging.info(f"Tensor data shape: {tensor_data.shape}")
tensor_res = tensor_img.header.get_zooms()[:3]
dataset = DataVolume(tensor_data, tensor_res, args.tensor_interp)

logging.info("Instantiating tensor propagator.")
voxel_size = tensor_img.header.get_zooms()[0]
vox_step_size = args.step_size / voxel_size

propagator = TensorPropagator(
dataset, vox_step_size, args.rk_order, args.algo,
theta, std=args.std, space=our_space, origin=our_origin)
elif args.rap_method == "switch":
propagator = None
propagators = {}
Expand Down Expand Up @@ -354,10 +387,43 @@ def main():
args.sf_threshold_init, theta, args.sphere,
sub_sphere=args.sub_sphere, space=our_space,
origin=our_origin, is_legacy=is_legacy)
elif cfg.get('propagator').lower() == 'tensor':
filename = cfg['filename']
tensor_format = cfg.get('tensor_format', args.tensor_format)
tensor_interp = cfg.get('tensor_interp', args.tensor_interp)
dataset_key = (filename, tensor_format, tensor_interp)

# Load data if needed
if dataset_key not in loaded_datasets:
tensor_img = nib.load(filename)
tensor_data = tensor_img.get_fdata(caching='unchanged',
dtype=float)

if tensor_format != 'dipy':
tensor_data = convert_tensor_to_dipy_format(
tensor_data, tensor_format)

tensor_res = tensor_img.header.get_zooms()[:3]
loaded_datasets[dataset_key] = DataVolume(
tensor_data, tensor_res, tensor_interp)

# Get params from rap_policies file
algo = cfg.get('algo', args.algo)
theta = gm.math.radians(get_theta(cfg.get('theta', args.theta),
algo))
std = cfg.get('std', args.std)
voxel_size = loaded_datasets[dataset_key].voxres[0]
vox_step_size = cfg.get('step_size', args.step_size) / voxel_size

# Build propagator from rap_policies file
propagators[label] = TensorPropagator(
loaded_datasets[dataset_key], vox_step_size,
args.rk_order, algo, theta, std=std,
space=our_space, origin=our_origin)
else:
raise ValueError(
f"Unknown propagator type '{cfg.get('propagator')}"
f"for label {label}. Supported types: 'ODF")
f"for label {label}. Supported types: 'ODF', 'Tensor'.")
del loaded_datasets

if not propagators:
Expand Down
Loading
Loading