diff --git a/src/scilpy/cli/scil_tracking_local.py b/src/scilpy/cli/scil_tracking_local.py index bed9e5e7b..c1b07d1bf 100755 --- a/src/scilpy/cli/scil_tracking_local.py +++ b/src/scilpy/cli/scil_tracking_local.py @@ -41,9 +41,10 @@ * Forward tracking: For GPU tracking, the `--forward_only` flag can be used to disable backward tracking. This option isn't available for CPU tracking. - * Random number generator seed (RNG): CPU and GPU use different RNG implementations,< - so the same `--seed` is reproducible within a backend but does not guarantee - identical streamlines across CPU vs GPU tracking. + * Random number generator seed (RNG): CPU and GPU use different RNG + implementations, so the same `--seed` is reproducible within a + backend but does not guarantee identical streamlines + across CPU vs GPU tracking. All the input nifti files must be in isotropic resolution. @@ -71,12 +72,13 @@ from dipy.tracking.stopping_criterion import BinaryStoppingCriterion from dipy.tracking.tracker import eudx_tracking from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, +from scilpy.io.utils import (add_verbose_arg, assert_headers_compatible, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, verify_compression_th, load_matrix_in_any_format) from scilpy.tracking.tracker import GPUTracker from scilpy.tracking.utils import (add_mandatory_options_tracking, + add_tracking_sh_options, add_out_options, add_seeding_options, add_tracking_options, add_tracking_ptt_options, @@ -100,6 +102,7 @@ def _build_arg_parser(): # Options that are the same in this script and scil_tracking_local_dev: add_mandatory_options_tracking(p) track_g = add_tracking_options(p) + add_tracking_sh_options(p) add_seeding_options(p) # Other options, only available in this script: @@ -110,11 +113,7 @@ def _build_arg_parser(): track_g.add_argument('--algo', default='prob', choices=['det', 'prob', 'ptt', 'eudx'], help='Algorithm to use. [%(default)s]') - add_sphere_arg(track_g, symmetric_only=False) - track_g.add_argument('--sub_sphere', - type=int, default=0, - help='Subdivides each face of the sphere into 4^s new' - ' faces. [%(default)s]') + add_tracking_ptt_options(p) gpu_g = p.add_argument_group('GPU options') gpu_g.add_argument('--use_gpu', action='store_true', @@ -131,7 +130,6 @@ def _build_arg_parser(): ' [{}]'.format(DEFAULT_BATCH_SIZE)) out_g = add_out_options(p) - out_g.add_argument('--seed', type=int, help='Random number generator seed.') diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index 6b36fcb63..3fe99dc80 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -72,19 +72,22 @@ from nibabel.streamlines import detect_format, TrkFile from scilpy.io.image import assert_same_resolution -from scilpy.io.utils import (add_processes_arg, add_sphere_arg, +from scilpy.io.utils import (add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, verify_compression_th, load_matrix_in_any_format) +from scilpy.io.tensor import convert_tensor_to_dipy_format from scilpy.image.volume_space_management import DataVolume -from scilpy.tracking.propagator import ODFPropagator +from scilpy.tracking.propagator import ODFPropagator, TensorPropagator from scilpy.tracking.rap import RAPContinue, RAPSwitch from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser from scilpy.tracking.tracker import Tracker from scilpy.tracking.utils import (add_mandatory_options_tracking, add_out_options, add_seeding_options, add_tracking_options, + add_tracking_tensor_options, + add_tracking_sh_options, get_theta, verify_streamline_length_options, verify_seed_options) @@ -98,44 +101,39 @@ def _build_arg_parser(): formatter_class=argparse.RawTextHelpFormatter, epilog=version_string) - # Options common to both scripts - add_mandatory_options_tracking(p, fodf_optional=True) + # Input data options + data_g = p.add_argument_group('Input data options') + data_group = data_g.add_mutually_exclusive_group(required=True) + data_group.add_argument('--in_odf', + help='Path to the ODF SH coefficient file (.nii.gz).\n' + 'Use this for ODF-based tracking.') + data_group.add_argument('--in_tensor', + help='Path to the DTI tensor file (.nii.gz).\n' + 'Use this for tensor-based tracking.') + data_group.add_argument('--rap_params', + help='Path to the JSON file containing RAP policies.\n' + 'Use this for RAP with method "switch".\n' + 'Expected format:\n' + '{\n' + ' "methods": {\n' + ' "1": {"propagator": "ODF", "filename": str,\n' + ' "sh_basis": str, "algo": str,\n' + ' "theta": float, "step_size": float},\n' + ' "2": {"propagator": "Tensor", "filename": str,\n' + ' "tensor_format": str, "tensor_interp": str,\n' + ' "algo": str, "theta": float,\n' + ' "step_size": float, "std": float}\n' + ' }\n' + '}') + + # Options common to ODF and tensor-based tracking + add_mandatory_options_tracking(p, fodf_mandatory=False) track_g = add_tracking_options(p) add_seeding_options(p) - # Options only for here. + # Options only for all models. track_g.add_argument('--algo', default='prob', choices=['det', 'prob'], help='Algorithm to use. [%(default)s]') - add_sphere_arg(track_g, symmetric_only=False) - track_g.add_argument('--sub_sphere', - type=int, default=0, - help='Subdivides each face of the sphere into 4^s new' - ' faces. [%(default)s]') - track_g.add_argument('--sfthres_init', metavar='sf_th', type=float, - default=0.5, dest='sf_threshold_init', - help="Spherical function relative threshold value " - "for the \ninitial direction. [%(default)s]") - track_g.add_argument('--rk_order', metavar="K", type=int, default=1, - choices=[1, 2, 4], - help="The order of the Runge-Kutta integration used " - "for the step function.\n" - "For more information, refer to the note in the" - " script description. [%(default)s]") - track_g.add_argument('--max_invalid_nb_points', metavar='MAX', type=float, - default=0, - help="Maximum number of steps without valid " - "direction, \nex: if threshold on ODF or max " - "angles are reached.\n" - "Default: 0, i.e. do not add points following " - "an invalid direction.") - track_g.add_argument('--forward_only', action='store_true', - help="If set, tracks in one direction only (forward) " - "given the \ninitial seed. The direction is " - "randomly drawn from the ODF.") - track_g.add_argument('--sh_interp', default='trilinear', - choices=['nearest', 'trilinear'], - help="Spherical harmonic interpolation: " - "nearest-neighbor \nor trilinear. [%(default)s]") track_g.add_argument('--mask_interp', default='nearest', choices=['nearest', 'trilinear'], help="Mask interpolation: nearest-neighbor or " @@ -152,37 +150,56 @@ def _build_arg_parser(): help="By default, each seed position is used only once. This option\n" "allows for tracking from the exact same seed n_repeats_per_seed" "\ntimes. [%(default)s]") + track_g.add_argument('--max_invalid_nb_points', metavar='MAX', type=float, + default=0, + help="Maximum number of steps without valid " + "direction, \nex: if threshold on ODF or max " + "angles are reached.\n" + "Default: 0, i.e. do not add points following " + "an invalid direction.") + track_g.add_argument('--forward_only', action='store_true', + help="If set, tracks in one direction only (forward) " + "given the \ninitial seed. The direction is " + "randomly drawn from the ODF/Tensor.") + + sh_options = add_tracking_sh_options(p) + sh_options.add_argument('--sh_interp', default='trilinear', + choices=['nearest', 'trilinear'], + help="Spherical harmonic interpolation: " + "nearest-neighbor \nor trilinear. [%(default)s]") + add_tracking_tensor_options(p) - r_g = p.add_argument_group('Random seeding options') - r_g.add_argument('--rng_seed', type=int, default=0, - help='Initial value for the random number generator. ' - '[%(default)s]') - r_g.add_argument('--skip', type=int, default=0, - help="Skip the first N random number. \n" - "Useful if you want to create new streamlines to " - "add to \na previously created tractogram with a " - "fixed --rng_seed.\nEx: If tractogram_1 was created " - "with -nt 1,000,000, \nyou can create tractogram_2 " - "with \n--skip 1,000,000.") rap_g = p.add_argument_group('Region-Adaptive Propagation options') rap_mode = rap_g.add_mutually_exclusive_group() rap_mode.add_argument('--rap_mask', default=None, help='Region-Adaptive Propagation mask (.nii.gz).\n' - 'Region-Adaptive Propagation tractography will start within ' - 'this mask.') + 'Region-Adaptive Propagation tractography will start within ' + 'this mask.') rap_mode.add_argument('--rap_labels', default=None, help='Region-Adaptive Propagation label volume (.nii.gz) .\n' - 'Voxel values are integer labels (0=background, 1..N=regions) .\n' - 'Used with --rap_method switch to select policies per label.') + 'Voxel values are integer labels (0=background, 1..N=regions) .\n' + 'Used with --rap_method switch to select policies per label.') rap_g.add_argument('--rap_method', default='None', choices=['None', 'continue', 'switch'], help="Region-Adaptive Propagation tractography method.\n" - "'continue': continues tracking with same params,\n" - "'switch': switches tracking params inside RAP mask.\n" - " [%(default)s]") + "'continue': continues tracking with same params,\n" + "'switch': switches tracking params inside RAP mask.\n" + " [%(default)s]") rap_g.add_argument('--rap_save_entry_exit', default=None, help='Save RAP entry/exit coordinates as a binary mask.\n' - 'Provide output filename (.nii.gz).') + 'Provide output filename (.nii.gz).') + + r_g = p.add_argument_group('Random seeding options') + r_g.add_argument('--rng_seed', type=int, default=None, + help='Initial value for the random number generator. ' + '[%(default)s]') + r_g.add_argument('--skip', type=int, default=0, + help="Skip the first N random number. \n" + "Useful if you want to create new streamlines to " + "add to \na previously created tractogram with a " + "fixed --rng_seed.\nEx: If tractogram_1 was created " + "with -nt 1,000,000, \nyou can create tractogram_2 " + "with \n--skip 1,000,000.") m_g = p.add_argument_group('Memory options') add_processes_arg(m_g) @@ -203,15 +220,19 @@ def main(): parser.error('Invalid output streamline file format (must be trk or ' + 'tck): {0}'.format(args.out_tractogram)) - if args.rap_params: + inputs = [args.in_seed, args.in_mask] + models = [] + if args.in_odf: + models = [args.in_odf] + elif args.in_tensor: + models = [args.in_tensor] + elif args.rap_params: with open(args.rap_params, 'r') as f: rap_params = json.load(f) - filenames = [cfg['filename'] for cfg in rap_params.get('methods', {}).values() - if 'filename' in cfg] - assert_inputs_exist(parser, filenames) + models = [cfg['filename'] for cfg in rap_params.get('methods', {}).values() + if 'filename' in cfg] - inputs = [args.in_seed, args.in_mask] - assert_inputs_exist(parser, inputs, optional=args.in_odf) + assert_inputs_exist(parser, inputs + models) assert_outputs_exist(parser, args, args.out_tractogram) verify_streamline_length_options(parser, args) @@ -244,8 +265,8 @@ def main(): max_nbr_pts = int(args.max_length / args.step_size) min_nbr_pts = max(int(args.min_length / args.step_size), 1) - if args.in_odf: - assert_same_resolution([args.in_mask, args.in_odf, args.in_seed]) + + assert_same_resolution(inputs + models) # Choosing our space and origin for this tracking # If save_seeds, space and origin must be vox, center. Choosing those @@ -301,17 +322,9 @@ def main(): odf_sh_res = odf_sh_img.header.get_zooms()[:3] dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) - logging.info("Instantiating propagator.") - # Converting step size to vox space - # We only support iso vox for now but allow slightly different vox - # 1e-3. - assert np.allclose(np.mean(odf_sh_res[:3]), - odf_sh_res, atol=1e-03) + logging.info("Instantiating ODF propagator.") voxel_size = odf_sh_img.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size - - # Using space and origin in the propagator: vox and center, like - # in dipy. sh_basis, is_legacy = parse_sh_basis_arg(args) propagator = ODFPropagator( @@ -319,8 +332,28 @@ def main(): args.sf_threshold, args.sf_threshold_init, theta, args.sphere, sub_sphere=args.sub_sphere, space=our_space, origin=our_origin, is_legacy=is_legacy) - propagators = {args.in_odf: propagator} + elif args.in_tensor: # args.in_tensor + logging.info("Loading DTI tensor data.") + tensor_img = nib.load(args.in_tensor) + tensor_data = tensor_img.get_fdata(caching='unchanged', dtype=float) + + # Convert tensor to dipy format if needed + if args.tensor_format != 'dipy': + logging.info(f"Converting tensor from {args.tensor_format} to dipy format.") + tensor_data = convert_tensor_to_dipy_format(tensor_data, args.tensor_format) + + logging.info(f"Tensor data shape: {tensor_data.shape}") + tensor_res = tensor_img.header.get_zooms()[:3] + dataset = DataVolume(tensor_data, tensor_res, args.tensor_interp) + + logging.info("Instantiating tensor propagator.") + voxel_size = tensor_img.header.get_zooms()[0] + vox_step_size = args.step_size / voxel_size + + propagator = TensorPropagator( + dataset, vox_step_size, args.rk_order, args.algo, + theta, std=args.std, space=our_space, origin=our_origin) elif args.rap_method == "switch": propagator = None propagators = {} @@ -354,10 +387,43 @@ def main(): args.sf_threshold_init, theta, args.sphere, sub_sphere=args.sub_sphere, space=our_space, origin=our_origin, is_legacy=is_legacy) + elif cfg.get('propagator').lower() == 'tensor': + filename = cfg['filename'] + tensor_format = cfg.get('tensor_format', args.tensor_format) + tensor_interp = cfg.get('tensor_interp', args.tensor_interp) + dataset_key = (filename, tensor_format, tensor_interp) + + # Load data if needed + if dataset_key not in loaded_datasets: + tensor_img = nib.load(filename) + tensor_data = tensor_img.get_fdata(caching='unchanged', + dtype=float) + + if tensor_format != 'dipy': + tensor_data = convert_tensor_to_dipy_format( + tensor_data, tensor_format) + + tensor_res = tensor_img.header.get_zooms()[:3] + loaded_datasets[dataset_key] = DataVolume( + tensor_data, tensor_res, tensor_interp) + + # Get params from rap_policies file + algo = cfg.get('algo', args.algo) + theta = gm.math.radians(get_theta(cfg.get('theta', args.theta), + algo)) + std = cfg.get('std', args.std) + voxel_size = loaded_datasets[dataset_key].voxres[0] + vox_step_size = cfg.get('step_size', args.step_size) / voxel_size + + # Build propagator from rap_policies file + propagators[label] = TensorPropagator( + loaded_datasets[dataset_key], vox_step_size, + args.rk_order, algo, theta, std=std, + space=our_space, origin=our_origin) else: raise ValueError( f"Unknown propagator type '{cfg.get('propagator')}" - f"for label {label}. Supported types: 'ODF") + f"for label {label}. Supported types: 'ODF', 'Tensor'.") del loaded_datasets if not propagators: diff --git a/src/scilpy/image/volume_space_management.py b/src/scilpy/image/volume_space_management.py index 6f5b54429..6f815fb4f 100644 --- a/src/scilpy/image/volume_space_management.py +++ b/src/scilpy/image/volume_space_management.py @@ -19,6 +19,22 @@ class DataVolume(object): """ Class to access/interpolate data from nibabel object + + Tensor interpolation notes + -------------------------- + The tensor interpolation path supports a Log-Euclidean mode for 6-coefficient + DTI volumes. This follows the idea of interpolating symmetric positive-definite + matrices in the log-domain and then mapping the result back with the matrix + exponential. + + References + ~~~~~~~~~~ + Arsigny, V., Fillard, P., Pennec, X., Ayache, N. (2007). Geometric Means in a + Novel Vector Space Structure on Symmetric Positive-Definite Matrices. SIAM + Journal on Matrix Analysis and Applications, 29(1), 328-347. + + Pennec, X., Fillard, P., Ayache, N. (2006). A Riemannian Framework for Tensor + Computing. International Journal of Computer Vision, 66, 41-66. """ def __init__(self, data, voxres, interpolation=None, must_be_3d=False): @@ -30,18 +46,18 @@ def __init__(self, data, voxres, interpolation=None, must_be_3d=False): voxres: np.array(3,) The pixel resolution, ex, using img.header.get_zooms()[:3]. interpolation: str or None - The interpolation choice amongst "trilinear" or "nearest". If - None, functions getting a coordinate in mm instead of voxel + The interpolation choice amongst "trilinear" or "nearest" or "log_euclidean". + If None, functions getting a coordinate in mm instead of voxel coordinates are not available. must_be_3d: bool If True, dataset can't be 4D. """ self.interpolation = interpolation if self.interpolation: - if not (self.interpolation == 'trilinear' or - self.interpolation == 'nearest'): - raise Exception("Interpolation must be 'trilinear' or " - "'nearest'") + if self.interpolation not in ['trilinear', 'nearest', + 'log_euclidean']: + raise Exception("Interpolation must be 'trilinear', " + "'nearest' or 'log_euclidean'") self.data = data self.nb_coeffs = data.shape[-1] @@ -255,6 +271,8 @@ def _vox_to_value(self, x, y, z, origin): # They use round(point), not floor. This is the equivalent of # origin = 'center'. result = nearestneighbor_interpolate(self.data, coord) + elif self.interpolation == 'log_euclidean': + result = self._log_euclidean_interpolate4d(coord) else: # Trilinear # They do not say it explicitly but they verify if @@ -269,6 +287,107 @@ def _vox_to_value(self, x, y, z, origin): raise Exception("No interpolation method was given, cannot run " "this method..") + @staticmethod + def _lo_tri_to_tensor(lo_tri): + dxx, dxy, dyy, dxz, dyz, dzz = lo_tri + return np.array([[dxx, dxy, dxz], + [dxy, dyy, dyz], + [dxz, dyz, dzz]], dtype=np.float64) + + @staticmethod + def _tensor_to_lo_tri(tensor): + return np.array([tensor[0, 0], tensor[0, 1], tensor[1, 1], + tensor[0, 2], tensor[1, 2], tensor[2, 2]], + dtype=np.float64) + + @staticmethod + def _spd_logm(tensor, eps=1e-12): + eigvals, eigvecs = np.linalg.eigh(tensor) + eigvals = np.maximum(eigvals, eps) + return eigvecs @ np.diag(np.log(eigvals)) @ eigvecs.T + + @staticmethod + def _spd_expm(sym_tensor): + eigvals, eigvecs = np.linalg.eigh(sym_tensor) + return eigvecs @ np.diag(np.exp(eigvals)) @ eigvecs.T + + def _log_euclidean_interpolate4d(self, coord): + """ + Interpolate a 4D tensor volume using Log-Euclidean interpolation. + + The input volume must store diffusion tensors in lower-triangular + 6-coefficient form: [Dxx, Dxy, Dyy, Dxz, Dyz, Dzz]. Each tensor is + converted to a 3x3 SPD matrix, mapped to the log-domain, interpolated + with trilinear weights in that domain, and mapped back with the matrix + exponential. + + Parameters + ---------- + coord: ndarray shape (3,) + Coordinate in voxel space, already converted to center-origin + indexing. + + Returns + ------- + ndarray shape (6,) + Interpolated tensor coefficients in lower-triangular form. + + Notes + ----- + This implementation assumes SPD tensors. Eigenvalues are clamped to a + small positive epsilon before taking the logarithm to avoid numerical + issues on nearly-singular tensors. + """ + if self.data.shape[-1] != 6: + raise ValueError("log_euclidean interpolation requires 6 tensor " + "coefficients per voxel.") + + # Convert center-based coordinates ([-0.5, N-0.5)) to index space + # ([0, N)) for trilinear weights. + idx_coord = coord + 0.5 + + x0 = int(np.floor(idx_coord[0])) + y0 = int(np.floor(idx_coord[1])) + z0 = int(np.floor(idx_coord[2])) + + x0 = np.clip(x0, 0, self.dim[0] - 1) + y0 = np.clip(y0, 0, self.dim[1] - 1) + z0 = np.clip(z0, 0, self.dim[2] - 1) + + x1 = min(x0 + 1, self.dim[0] - 1) + y1 = min(y0 + 1, self.dim[1] - 1) + z1 = min(z0 + 1, self.dim[2] - 1) + + dx = idx_coord[0] - x0 + dy = idx_coord[1] - y0 + dz = idx_coord[2] - z0 + + if x1 == x0: + dx = 0.0 + if y1 == y0: + dy = 0.0 + if z1 == z0: + dz = 0.0 + + weighted_log_tensor = np.zeros((3, 3), dtype=np.float64) + corners = [(x0, y0, z0, (1.0 - dx) * (1.0 - dy) * (1.0 - dz)), + (x1, y0, z0, dx * (1.0 - dy) * (1.0 - dz)), + (x0, y1, z0, (1.0 - dx) * dy * (1.0 - dz)), + (x1, y1, z0, dx * dy * (1.0 - dz)), + (x0, y0, z1, (1.0 - dx) * (1.0 - dy) * dz), + (x1, y0, z1, dx * (1.0 - dy) * dz), + (x0, y1, z1, (1.0 - dx) * dy * dz), + (x1, y1, z1, dx * dy * dz)] + + for i, j, k, w in corners: + if w == 0.0: + continue + tensor = self._lo_tri_to_tensor(self.data[i, j, k]) + weighted_log_tensor += w * self._spd_logm(tensor) + + interp_tensor = self._spd_expm(weighted_log_tensor) + return self._tensor_to_lo_tri(interp_tensor) + def _is_vox_in_bound(self, x, y, z, origin): """ Test if voxel is in dataset range. diff --git a/src/scilpy/tracking/propagator.py b/src/scilpy/tracking/propagator.py index 32f7e6b5d..738afa499 100644 --- a/src/scilpy/tracking/propagator.py +++ b/src/scilpy/tracking/propagator.py @@ -7,6 +7,7 @@ from dipy.data import get_sphere from dipy.io.stateful_tractogram import Space, Origin from dipy.reconst.shm import sh_to_sf_matrix +from dipy.reconst.dti import eig_from_lo_tri from scilpy.reconst.utils import (get_sphere_neighbours, get_sh_order_and_fullness) @@ -692,3 +693,226 @@ def _get_possible_next_dirs(self, pos, v_in): valid_volumes = np.array(valid_volumes) return valid_dirs, valid_volumes + + +class TensorPropagator(AbstractPropagator): + """ + Propagator for DTI tensor tracking. Tracks along the principal + eigenvector of the diffusion tensor. + """ + def __init__(self, datavolume, step_size, rk_order, algo, theta, + std=None, space=Space('vox'), origin=Origin('center')): + """ + Parameters + ---------- + datavolume: scilpy.image.volume_space_management.DataVolume + Trackable DataVolume object containing tensor data in lower + triangular format (6 coefficients: Dxx, Dxy, Dyy, Dxz, Dyz, Dzz). + step_size: float + The step size for tracking. + rk_order: int + Order for the Runge Kutta integration. + algo: string + Type of algorithm. Choices are 'det' or 'prob' + theta: float + Maximum angle (radians) between two steps. + std: float or None + Standard deviation for the Gaussian noise added to the principal + eigenvector in probabilistic tracking. If None, a default value of + 0.1 is used. Ignored for deterministic tracking. + space: dipy Space + Space of the streamlines during tracking. Default: VOX. + origin: dipy Origin + Origin of the streamlines during tracking. Default: center. + """ + super().__init__(datavolume, step_size, rk_order, space, origin) + + if self.space == Space.RASMM: + raise NotImplementedError( + "This version of the propagator on tensors is not ready to work " + "in RASMM space.") + + self.algo = algo + if self.algo == "prob" and std is None: + self.std = 0.1 + elif self.algo == "det": + self.std = 0 + else: + self.std = std + + self.theta = theta + self.normalize_directions = True + self.line_rng_generator = None + + logging.debug(f"Algo: ${self.algo}") + logging.debug(f"Theta: ${self.theta}") + logging.debug(f"Std: ${self.std}") + + def reset_data(self, new_data=None): + return super().reset_data(new_data) + + def prepare_forward(self, seeding_pos, random_generator): + """Get initial direction from tensor at seeding position.""" + self.line_rng_generator = random_generator + + # Get tensor at seeding position + tensor_data = self.datavolume.get_value_at_coordinate( + *seeding_pos, space=self.space, origin=self.origin) + + if tensor_data is None: + logging.debug(f"Seed at {seeding_pos}: tensor_data is None") + return PropagationStatus.ERROR + + if np.all(tensor_data == 0): + logging.debug(f"Seed at {seeding_pos}: tensor_data is all zeros") + return PropagationStatus.ERROR + + # Get principal eigenvector + direction = self._get_direction_from_tensor(tensor_data) + + if direction is None: + logging.debug(f"Seed at {seeding_pos}: failed to extract direction from tensor") + return PropagationStatus.ERROR + + return TrackingDirection(direction) + + def prepare_backward(self, line, forward_dir): + """Flip direction for backward tracking.""" + # forward_dir is a TrackingDirection (which is a list) + return TrackingDirection(-np.array(forward_dir)) + + def finalize_streamline(self, last_pos, v_in): + return super().finalize_streamline(last_pos, v_in) + + def propagate(self, line, v_in): + """Propagate using Runge-Kutta integration.""" + return super().propagate(line, v_in) + + def _sample_next_direction(self, pos, v_in): + """Sample next direction from tensor.""" + tensor_data = self.datavolume.get_value_at_coordinate( + *pos, space=self.space, origin=self.origin) + + if tensor_data is None or np.all(tensor_data == 0): + return None + + # Get principal eigenvector and local eigenvalues. + direction, evals = self._get_direction_from_tensor( + tensor_data, return_evals=True) + + if direction is None: + return None + + # Check angle constraint + cosine = np.dot(v_in, direction) / (np.linalg.norm(v_in) * np.linalg.norm(direction)) + cosine = np.clip(cosine, -1, 1) + + # Flip if needed to maintain direction continuity + if cosine < 0: + direction = -direction + cosine = abs(cosine) + + if np.arccos(cosine) > self.theta: + return None + + # For probabilistic tracking, add some noise + if self.algo == 'prob' and self.line_rng_generator is not None: + # Scale angular perturbation by local anisotropy (FA). + # Rationale: principal direction is less stable in low-FA regions + # (crossing/isotropic tissue) and more stable in high-FA regions, + # so we use std_local = std_global * (1 - FA). + local_std = self._compute_local_std_from_evals(evals) + noise = self.line_rng_generator.normal(0, local_std, 3) + direction = direction + noise + direction = direction / np.linalg.norm(direction) + + return direction + + def _get_direction_from_tensor(self, tensor_data, return_evals=False): + """ + Extract principal eigenvector from tensor data. + + Parameters + ---------- + tensor_data : ndarray + Tensor coefficients in lower triangular format (6 values). + return_evals: bool + If True, also return sorted eigenvalues. + + Returns + ------- + direction : ndarray or None + Principal eigenvector (3D direction). + evals : ndarray or None + Sorted eigenvalues (descending), returned only when + return_evals=True. + """ + if len(tensor_data) != 6: + logging.warning(f"Expected 6 tensor coefficients, got {len(tensor_data)}") + return (None, None) if return_evals else None + + logging.debug(f"Tensor data: {tensor_data}") + + # Compute eigenvalues and eigenvectors + # eig_from_lo_tri returns a flat array of 12 values: + # [eval1, eval2, eval3, evec1_x, evec1_y, evec1_z, evec2_x, evec2_y, evec2_z, evec3_x, evec3_y, evec3_z] + try: + result = eig_from_lo_tri(tensor_data) + evals = result[:3] # First 3 values are eigenvalues + evecs = result[3:].reshape(3, 3) # Next 9 values are eigenvectors (3x3 matrix) + logging.debug(f"Eigenvalues: {evals}, Eigenvectors shape: {evecs.shape}") + except Exception as e: + logging.debug(f"Exception in eig_from_lo_tri: {e}") + return (None, None) if return_evals else None + + # Sort by eigenvalue magnitude (largest first) + order = np.argsort(evals)[::-1] + evals = evals[order] + evecs = evecs[:, order] + + # Principal eigenvector (associated with largest eigenvalue) + principal_evec = evecs[:, 0] + + logging.debug(f"Principal eigenvector: {principal_evec}, Sorted eigenvalues: {evals}") + + if return_evals: + return principal_evec, evals + return principal_evec + + def _compute_local_std_from_evals(self, evals): + """ + Compute voxel-wise noise std from local eigenvalues. + + Uses FA-derived scaling so that isotropic tensors get larger + perturbations and highly anisotropic tensors get smaller ones: + + std_local = std_global * (1 - FA) + + This is a pragmatic uncertainty proxy (not a full posterior model): + FA captures how strongly diffusion is oriented, so low FA implies less + directional confidence and therefore wider perturbation. + + References + ---------- + Basser, P. J., Mattiello, J., & LeBihan, D. (1994). MR diffusion + tensor spectroscopy and imaging. Biophysical Journal, 66(1), 259-267. + + Pierpaoli, C., & Basser, P. J. (1996). Toward a quantitative + assessment of diffusion anisotropy. Magnetic Resonance in Medicine, + 36(6), 893-906. + """ + if evals is None: + return self.std + + evals = np.asarray(evals, dtype=np.float64) + evals = np.maximum(evals, 0.0) + denom = np.sum(evals * evals) + + if denom <= 0.0: + fa = 0.0 + else: + mean_eval = np.mean(evals) + fa = np.sqrt(1.5 * np.sum((evals - mean_eval) ** 2) / denom) + fa = float(np.clip(fa, 0.0, 1.0)) + + return self.std * (1.0 - fa) diff --git a/src/scilpy/tracking/rap.py b/src/scilpy/tracking/rap.py index 981ba1576..59a8b6f2e 100644 --- a/src/scilpy/tracking/rap.py +++ b/src/scilpy/tracking/rap.py @@ -163,6 +163,25 @@ def rap_multistep_propagate(self, line, prev_direction): self.propagator = new_propagator logging.debug(f"RAP propagator switched to default label {self._propagators.keys()[0]}") + # Normalize previous direction representation when switching + # propagator families. + # + # ODF propagators rely on TrackingDirection.index to lookup angular + # neighborhoods on the discrete sphere (tracking_neighbours). Tensor + # propagators only need the Cartesian direction and can operate on a + # plain ndarray. + # + # Therefore: + # - Tensor -> ODF: wrap/quantize direction using prepare_backward so + # an index is available on the target ODF sphere. + # - ODF -> Tensor: drop the index and keep only Cartesian components + # to avoid carrying stale sphere metadata across models. + prev_direction_has_index = getattr(prev_direction, 'index', None) is not None + if hasattr(self.propagator, 'tracking_neighbours') and not prev_direction_has_index: + prev_direction = self.propagator.prepare_backward(line, prev_direction) + elif not hasattr(self.propagator, 'tracking_neighbours') and prev_direction_has_index: + prev_direction = np.asarray(prev_direction) + # Perform propagation with new parameters new_pos, new_dir, is_direction_valid = self.propagator.propagate( line, prev_direction) diff --git a/src/scilpy/tracking/tracker.py b/src/scilpy/tracking/tracker.py index c8db1f1c6..a3b4c8918 100644 --- a/src/scilpy/tracking/tracker.py +++ b/src/scilpy/tracking/tracker.py @@ -67,7 +67,8 @@ def __init__(self, propagator: AbstractPropagator, mask: DataVolume, Memory-mapping mode. One of {None, 'r+', 'c'}. This value is passed to np.load() when loading the raw tracking data from a subprocess. rng_seed: int - The random "seed" for the random generator. + The random "seed" for the random generator. If None, a random + uint32 seed is generated at runtime. track_forward_only: bool If true, only the forward direction is computed. skip: int @@ -101,7 +102,11 @@ def __init__(self, propagator: AbstractPropagator, mask: DataVolume, self.compression_th = compression_th self.save_seeds = save_seeds self.mmap_mode = mmap_mode - self.rng_seed = rng_seed + if rng_seed is None: + self.rng_seed = int(np.random.default_rng().integers( + 0, np.iinfo(np.uint32).max, dtype=np.uint32)) + else: + self.rng_seed = int(np.uint32(rng_seed)) self.track_forward_only = track_forward_only self.append_last_point = append_last_point self.skip = skip @@ -749,4 +754,4 @@ def _track(self): # output is yielded so that we can use LazyTractogram. # seed and strl with origin center (same as DIPY) - yield strl - 0.5, seed - 0.5 \ No newline at end of file + yield strl - 0.5, seed - 0.5 diff --git a/src/scilpy/tracking/utils.py b/src/scilpy/tracking/utils.py index 6a8f26a96..83b738eae 100644 --- a/src/scilpy/tracking/utils.py +++ b/src/scilpy/tracking/utils.py @@ -17,7 +17,9 @@ from dipy.reconst.shm import sh_to_sf_matrix from dipy.tracking.streamlinespeed import compress_streamlines, length from scilpy.io.utils import (add_compression_arg, add_overwrite_arg, - add_sh_basis_args) + add_sh_basis_args, add_sphere_arg) +from scilpy.io.tensor import (supported_tensor_formats, + tensor_format_description) from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas @@ -32,35 +34,12 @@ def __init__(self, cartesian, index=None): self.index = index -def add_mandatory_options_tracking(p, fodf_optional=False): +def add_mandatory_options_tracking(p, fodf_mandatory=True): """ Args that are required in both scil_tracking_local and scil_tracking_local_dev scripts. """ - if fodf_optional: - odf_group = p.add_mutually_exclusive_group() - odf_group.add_argument('--in_odf', default=None, - help='File containing the orientation \n' - 'diffusion function as spherical harmonics \n' - 'file (.nii.gz). Ex: ODF or fODF. \n' - 'If not provided, fODF info must be \n' - 'specified in rap_policies.json.') - odf_group.add_argument('--rap_params', default=None, - help='JSON file containing RAP parameters, \n' - 'mutually exclusive with --in_odf.\n' - 'Required for --rap_method switch.\n' - 'Expected format:\n' - '{\n' - ' "methods": {\n' - ' "1": {"propagator": "ODF", "filename": str,\n' - ' "sh_basis": str, "algo": str,\n' - ' "theta": float, "step_size": float},\n' - ' "2": {"propagator": "ODF", "filename": str,\n' - ' "sh_basis": str, "algo": str,\n' - ' "theta": float, "step_size": float}\n' - ' }\n' - '}') - else: + if fodf_mandatory: p.add_argument('in_odf', help='File containing the orientation diffusion function \n' 'as spherical harmonics file (.nii.gz). \n' @@ -97,15 +76,53 @@ def add_tracking_options(p): 'too big, streamline is \nstopped and the ' 'following point is NOT included.\n' '["eudx"=60, "det"=45, "prob"=20, "ptt"=20]') - track_g.add_argument('--sfthres', dest='sf_threshold', metavar='sf_th', - type=float, default=0.1, - help='Spherical function relative threshold. ' - '[%(default)s]') - add_sh_basis_args(track_g) - return track_g +def add_tracking_tensor_options(p): + tensor_options = p.add_argument_group('Tensor options') + tensor_options.add_argument('--tensor_format', type=str, default='dipy', + choices=supported_tensor_formats, + help="Format of the input tensor file.\n" + "Only used with --in_tensor. [%(default)s]\n" + + tensor_format_description) + tensor_options.add_argument('--tensor_interp', type=str, + default='log_euclidean', + choices=['nearest', 'trilinear', + 'log_euclidean'], + help="Tensor interpolation method. " + "Only used with --in_tensor. " + "[%(default)s]") + tensor_options.add_argument('--std', type=float, default=0.1, + help="Standard deviation of the noise added " + "to the direction for prob tensor-based tracking.") + + +def add_tracking_sh_options(p): + sh_options = p.add_argument_group('Spherical harmonics options') + add_sphere_arg(sh_options, symmetric_only=False) + sh_options.add_argument('--sub_sphere', + type=int, default=0, + help='Subdivides each face of the sphere into 4^s new' + ' faces. [%(default)s]') + sh_options.add_argument('--sfthres_init', metavar='sf_th', type=float, + default=0.5, dest='sf_threshold_init', + help="Spherical function relative threshold value " + "for the \ninitial direction. [%(default)s]") + sh_options.add_argument('--rk_order', metavar="K", type=int, default=1, + choices=[1, 2, 4], + help="The order of the Runge-Kutta integration used " + "for the step function.\n" + "For more information, refer to the note in the" + " script description. [%(default)s]") + sh_options.add_argument('--sfthres', dest='sf_threshold', metavar='sf_th', + type=float, default=0.1, + help='Spherical function relative threshold. ' + '[%(default)s]') + add_sh_basis_args(sh_options) + return sh_options + + def add_tracking_ptt_options(p): track_g = p.add_argument_group('PTT options') track_g.add_argument('--probe_length', dest='probe_length',