FastSurferCNN.inference

class FastSurferCNN.inference.Inference(cfg, device, ckpt='', lut=None)

Model evaluation class to run inference using FastSurferCNN.

Attributes

permute_order

(Dict[str, Tuple[int, int, int, int]]) Permutation order for the axial, coronal, and sagittal planes.

device

(Optional[torch.device]) Device specification for distributed computation usage.

default_device

(torch.device) Default device specification for distributed computation usage.

cfg

(yacs.config.CfgNode) Configuration node.

model_parallel

(bool) Option to run the model in parallel.

model

(torch.nn.Module) Neural network model.

model_name

(str) Name of the model.

alpha

(Dict[str, float]) Alpha values for different planes.

post_prediction_mapping_hook

Hook for post prediction mapping.
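Each permute_order entry is a 4-tuple of axis indices. Assuming it is applied the way torch.Tensor.permute reorders dimensions (output axis i takes input axis order[i]), the shape bookkeeping can be sketched in plain Python. The PERMUTE_ORDER values, the (slices, classes, height, width) layout, and the permuted_shape helper are hypothetical illustrations, not taken from FastSurferCNN.

```python
from typing import Dict, Tuple

# Hypothetical permutation orders, for illustration only; they are not
# taken from Inference.permute_order.
PERMUTE_ORDER: Dict[str, Tuple[int, int, int, int]] = {
    "axial": (3, 0, 2, 1),
    "coronal": (2, 3, 0, 1),
    "sagittal": (0, 3, 1, 2),
}


def permuted_shape(shape: Tuple[int, ...], order: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the shape that tensor.permute(*order) would produce.

    Output axis i takes its size from input axis order[i].
    """
    if sorted(order) != list(range(len(shape))):
        raise ValueError("order must be a permutation of the axis indices")
    return tuple(shape[i] for i in order)


# Hypothetical 4D block of predictions: (slices, classes, height, width).
block_shape = (8, 4, 256, 256)
for plane, order in PERMUTE_ORDER.items():
    print(plane, permuted_shape(block_shape, order))
```

Validating that order is a true permutation catches typos (repeated or missing axes) before they turn into silent shape errors later in the pipeline.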

Methods

setup_model([cfg, device])

Set up the model.

set_cfg(cfg)

Set the configuration node.

to([device])

Move and/or cast the parameters and buffers.

load_checkpoint(ckpt)

Load the checkpoint and set device and model.

eval(init_pred, val_loader, *[, out_scale, out])

Perform prediction and aggregate views into pred_prob in place.

run(init_pred, img_filename, orig_data, ...)

Run the loaded model on the T1 image data in orig_data, using the scale factors in orig_zoom; img_filename is used for messages only.

eval(init_pred, val_loader, *, out_scale=None, out=None)

Perform prediction and aggregate views into pred_prob in place.

Parameters:
init_pred : torch.Tensor

Initial prediction.

val_loader : DataLoader

Validation loader.

out_scale : Optional

Output scale (Default value = None).

out : Optional[torch.Tensor]

Previous prediction tensor (Default value = None).

Returns:
torch.Tensor

Prediction probability tensor.
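eval aggregates per-view predictions into one probability volume in place, and the alpha attribute holds one weight per plane. The sketch below shows a weighted in-place accumulation of that kind using flat Python lists instead of tensors; the ALPHA values and the aggregate_view helper are hypothetical and not taken from FastSurferCNN.

```python
from typing import Dict, List

# Hypothetical per-plane weights (illustration only).
ALPHA: Dict[str, float] = {"axial": 0.4, "coronal": 0.4, "sagittal": 0.2}


def aggregate_view(pred_prob: List[float], view_prob: List[float], plane: str,
                   alpha: Dict[str, float] = ALPHA) -> List[float]:
    """Add alpha[plane] * view_prob to pred_prob element-wise, in place."""
    if len(pred_prob) != len(view_prob):
        raise ValueError("pred_prob and view_prob must have the same length")
    weight = alpha[plane]
    for i, p in enumerate(view_prob):
        pred_prob[i] += weight * p
    return pred_prob  # the same list object, returned for convenience


# Three views predicting class probabilities for one voxel (2 classes).
pred = [0.0, 0.0]
aggregate_view(pred, [0.9, 0.1], "axial")
aggregate_view(pred, [0.7, 0.3], "coronal")
aggregate_view(pred, [0.5, 0.5], "sagittal")
```

With torch tensors the loop corresponds to pred_prob.add_(view_prob, alpha=weight); torch's alpha keyword there is just a scalar multiplier and is unrelated to the attribute of the same name. Because the weights in this sketch sum to 1, the aggregated values remain a probability distribution.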

get_cfg()

Return the configuration.

Returns:
yacs.config.CfgNode

Configuration node.

get_device()

Return the device.

Returns:
torch.device

The device used for computation.

get_max_size()

Return the max size.

Returns:
int | tuple[int, int]

The maximum size, either a single value or a tuple (width, height).
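Because get_max_size can return either a single int or a (width, height) tuple, callers have to handle both forms. A small helper such as the hypothetical as_width_height below normalizes the result. Treating a single int as a square size is an assumption for illustration, not documented behavior.

```python
from typing import Tuple, Union


def as_width_height(max_size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Normalize a get_max_size()-style value to a (width, height) tuple.

    A single int is treated as a square size (assumed convention).
    """
    if isinstance(max_size, int):
        return (max_size, max_size)
    width, height = max_size
    return (int(width), int(height))


print(as_width_height(256))
print(as_width_height((320, 256)))
```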

get_model_height()

Return the model height.

Returns:
int

The height of the model.

get_model_width()

Return the model width.

Returns:
int

The width of the model.

get_modelname()

Return the model name.

Returns:
str

The name of the model.

get_num_classes()

Return the number of classes.

Returns:
int

The number of classes.

get_plane()

Return the plane.

Returns:
str

The plane used in the model.

load_checkpoint(ckpt)

Load the checkpoint and set device and model.

Parameters:
ckpt : Union[str, os.PathLike]

String or os.PathLike object containing the path to the checkpoint file.
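Since ckpt may be either a str or an os.PathLike object, one way to accept both is to pass it through os.fspath before use. The hypothetical resolve_checkpoint helper below sketches that step and fails early if the file is missing. The actual loading (for example with torch.load and a map_location for the target device) is deliberately left out, because the checkpoint layout that load_checkpoint expects is not documented here.

```python
import os
from pathlib import Path
from typing import Union


def resolve_checkpoint(ckpt: Union[str, os.PathLike]) -> Path:
    """Turn a str or os.PathLike checkpoint argument into a validated Path.

    Hypothetical helper: it only normalizes and checks the path; it does
    not load any weights.
    """
    # os.fspath accepts both str and os.PathLike and returns a str path.
    path = Path(os.fspath(ckpt)).expanduser()
    if not path.is_file():
        raise FileNotFoundError(f"checkpoint not found: {path}")
    return path
```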

run(init_pred, img_filename, orig_data, orig_zoom, out=None, out_res=None, batch_size=None)

Run the loaded model on the T1 image data in orig_data, using the scale factors in orig_zoom; img_filename is used for messages only.

Parameters:
init_pred : torch.Tensor

Initial prediction.

img_filename : str

Original image filename.

orig_data : npt.NDArray

Original image data.

orig_zoom : npt.NDArray

Original zoom (scale factors).

out : Optional[torch.Tensor]

Updated output tensor (Default value = None).

out_res : Optional[int]

Output resolution (Default value = None).

batch_size : Optional[int]

Batch size (Default value = None).

Returns:
torch.Tensor

Prediction probability tensor.
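run accepts an optional batch_size. The sketch below shows one generic way the slices of a single plane could be split into batches of that size. The slice_batches helper is hypothetical, and its fallback of one batch covering every slice when batch_size is None is an assumption for illustration, not FastSurferCNN's documented default.

```python
from typing import Iterator, Optional, Tuple


def slice_batches(num_slices: int, batch_size: Optional[int] = None) -> Iterator[Tuple[int, int]]:
    """Yield (start, stop) index ranges that cover range(num_slices) in batches.

    If batch_size is None, a single batch spanning all slices is yielded
    (assumed fallback for illustration).
    """
    if num_slices < 0:
        raise ValueError("num_slices must be non-negative")
    if batch_size is None:
        batch_size = max(num_slices, 1)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, num_slices, batch_size):
        # The last batch may be shorter than batch_size.
        yield start, min(start + batch_size, num_slices)


for start, stop in slice_batches(10, 4):
    print(start, stop)
```

Smaller batches lower peak memory on the device at the cost of more forward passes; the total number of slices processed is the same either way.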

set_cfg(cfg)

Set the configuration node.

Parameters:
cfg : yacs.config.CfgNode

Configuration node.

setup_model(cfg=None, device=None)

Set up the model.

Parameters:
cfg : yacs.config.CfgNode

Configuration node (Default value = None).

device : torch.device

Device specification for distributed computation usage (Default value = None).
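Both setup_model parameters default to None, and the class also keeps a default_device attribute. A common pattern for such signatures is "None means keep the stored value or fall back to the default". The hypothetical ConfigHolder class below sketches that pattern with plain strings in place of torch.device objects; whether Inference.setup_model behaves exactly this way is an assumption.

```python
from typing import Any, Optional


class ConfigHolder:
    """Hypothetical stand-in showing a None-means-fallback pattern."""

    def __init__(self, cfg: Any, default_device: str = "cpu") -> None:
        self.cfg = cfg
        self.default_device = default_device
        self.device: Optional[str] = None

    def setup_model(self, cfg: Any = None, device: Optional[str] = None) -> None:
        # Assumed behavior: a None cfg keeps the stored configuration,
        # and a None device falls back to default_device.
        if cfg is not None:
            self.cfg = cfg
        self.device = device if device is not None else self.default_device


holder = ConfigHolder({"plane": "axial"})
holder.setup_model()
print(holder.cfg, holder.device)
```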

to(device=None)

Move and/or cast the parameters and buffers.

Parameters:
device : Optional[torch.device]

The desired device of the parameters and buffers in this module (Default value = None).