CorpusCallosum.localization.inference

CorpusCallosum.localization.inference.get_transforms()

Get preprocessing transforms for inference.

Returns:
transforms.Compose

Composed transform pipeline including:

- Intensity scaling to [0, 1]
- Fixed-size cropping around AC-PC points
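The two preprocessing steps named above can be illustrated with plain NumPy. This is only a sketch of the general technique, not the pipeline that get_transforms() builds: the helpers `scale_intensity` and `crop_around_center`, the min-max scaling rule, the 64x64 crop size, and the clamping behavior at image borders are all assumptions made for illustration.

```python
import numpy as np


def scale_intensity(image):
    """Linearly rescale intensities to [0, 1] (hypothetical helper)."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:  # constant image: avoid division by zero
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


def crop_around_center(image, center_rc, size):
    """Crop a fixed (height, width) window around center_rc = (row, col).

    The window is shifted so that it stays inside the image (an assumed
    policy). Returns the crop and its offsets as (left, top).
    """
    height, width = size
    rows, cols = image.shape[:2]
    if height > rows or width > cols:
        raise ValueError("crop size exceeds image size")
    top = int(round(center_rc[0])) - height // 2
    left = int(round(center_rc[1])) - width // 2
    top = min(max(top, 0), rows - height)
    left = min(max(left, 0), cols - width)
    return image[top:top + height, left:left + width], (left, top)


# Synthetic 100x120 slice standing in for real image data.
slice_2d = np.arange(100 * 120, dtype=np.float32).reshape(100, 120)
scaled = scale_intensity(slice_2d)
crop, (left, top) = crop_around_center(scaled, center_rc=(50, 60), size=(64, 64))
```

Returning the offsets together with the crop keeps enough information to map any coordinate found inside the crop back to the full slice.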

CorpusCallosum.localization.inference.load_model(device)

Load trained numerical localization model from checkpoint.

Parameters:
device : torch.device

Device to load model to.

Returns:
DenseNet

Loaded and initialized model in evaluation mode.

CorpusCallosum.localization.inference.predict(model, image_volume, patch_center, device=None, transform=None)

Run inference on an image volume.

Parameters:
model : DenseNet

Trained model for inference.

image_volume : np.ndarray

Input volume as numpy array.

patch_center : np.ndarray

Initial center point estimate for cropping.

device : torch.device, optional

Device to run inference on, by default None.

transform : transforms.Transform, optional

Custom transform pipeline; defaults to the preconfigured transforms from get_transforms().

Returns:
pc_coord : np.ndarray

Predicted PC coordinates.

ac_coord : np.ndarray

Predicted AC coordinates.

crop_offsets : pair of ints

Crop offsets (left, top).
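Crop offsets of the form (left, top) are typically used to translate a point found inside a crop back into the coordinates of the uncropped image. The sketch below shows that arithmetic. It assumes the point is expressed relative to the crop and ordered as (x, y); this reference does not state whether the coordinates returned by predict are crop-relative or which axis order they use, and `to_full_image_coords` is a hypothetical helper.

```python
import numpy as np


def to_full_image_coords(point_xy, crop_offsets):
    """Shift a crop-relative (x, y) point by the crop's (left, top) offsets."""
    left, top = crop_offsets
    return np.asarray(point_xy, dtype=float) + np.array([left, top], dtype=float)


# Illustrative values only: a point at x=30.5, y=20.0 inside a crop whose
# top-left corner sits at column 28, row 18 of the full image.
ac_in_crop = np.array([30.5, 20.0])
ac_full = to_full_image_coords(ac_in_crop, crop_offsets=(28, 18))
```

If the coordinates are stored in [y, x] order instead, the offsets have to be applied as (top, left).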

CorpusCallosum.localization.inference.preprocess_volume(image_volume, center_pt, transform=None)

Preprocess a volume for inference.

Parameters:
image_volume : np.ndarray

Input image volume of shape (W, W, D) in RAS.

center_pt : np.ndarray

Center point coordinates, with shape (3,), used for cropping on the slice.

transform : transforms.Transform or None, optional

Custom transform pipeline, by default None. If None, uses default transforms from get_transforms().

Returns:
dict[str, torch.Tensor | tuple[int, …]]

Dictionary containing preprocessed image tensor.

CorpusCallosum.localization.inference.run_inference_on_slice(model, image_slab, center_pt, num_iterations=2, debug_output=None)

Run inference on a single slice to detect AC and PC points.

Parameters:
model : torch.nn.Module

Trained model for AC-PC detection.

image_slab : np.ndarray

3D slab of mid-slices, in RAS orientation, to run inference on.

center_pt : np.ndarray

Initial center point estimate for cropping.

num_iterations : int, default=2

Number of refinement iterations to run.

debug_output : str, optional

Path to save debug visualization.

Returns:
ac_coords : np.ndarray

Detected AC voxel coordinates, with shape (2,), in [y, x] order.

pc_coords : np.ndarray

Detected PC voxel coordinates, with shape (2,), in [y, x] order.
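The num_iterations parameter suggests a predict-then-recenter loop: each pass predicts the landmarks from a crop, and the next pass crops around an improved center estimate. The sketch below shows one such scheme under stated assumptions. This reference does not describe how the center is updated, so re-centering on the AC-PC midpoint is a guess, and `refine_landmarks`, `fake_predict`, `TRUE_AC`, and `TRUE_PC` are hypothetical stand-ins rather than part of the module.

```python
import numpy as np


def refine_landmarks(predict_fn, image, center, num_iterations=2):
    """Repeat prediction, re-centering on the midpoint of the two landmarks."""
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    center = np.asarray(center, dtype=float)
    for _ in range(num_iterations):
        ac, pc = predict_fn(image, center)
        center = (ac + pc) / 2.0  # assumed update rule: AC-PC midpoint
    return ac, pc


# Stand-in predictor: its error is half the distance between the crop
# center and the true AC-PC midpoint, so a better-centered crop yields a
# better prediction. Coordinates are in [y, x] order.
TRUE_AC = np.array([40.0, 70.0])
TRUE_PC = np.array([44.0, 50.0])


def fake_predict(image, center):
    true_mid = (TRUE_AC + TRUE_PC) / 2.0
    error = (true_mid - center) * 0.5
    return TRUE_AC - error, TRUE_PC - error


ac, pc = refine_landmarks(fake_predict, None, center=[30.0, 30.0], num_iterations=2)
```

With this stand-in, the second pass halves the residual error of the first, which illustrates why a small, fixed number of iterations can be enough when the initial center estimate is only approximate.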