FastSurferCNN.utils.metrics
- class FastSurferCNN.utils.metrics.DiceScore(num_classes, device=None, output_transform=<function DiceScore.<lambda>>)
Accumulate the components of the Dice coefficient, i.e., the intersection and the union.
Attributes
op
(callable) A callable to update the accumulator. Its signature is (accumulator, output). For example, to compute an arithmetic mean value, op = lambda a, x: a + x.
output_transform
(callable, optional) A callable that transforms the Engine's process_function output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and want to compute the metric with respect to one of the outputs.
device
(str or torch.device, optional) Device specification for distributed computation. In most cases, it can be defined as “cuda:local_rank”, or as “cuda” if torch.cuda.set_device(local_rank) has already been called. By default, if a distributed process group is initialized and available, device is set to cuda.
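The accumulate-then-compute pattern described above can be sketched with NumPy. This is a minimal sketch under stated assumptions, not FastSurfer's implementation: the class name RunningDice and its reset/update/compute methods are hypothetical, inputs are assumed to be categorical label arrays with integer values in [0, num_classes), and torch tensors would first need converting with .cpu().numpy(). Because |A| + |B| = |A ∪ B| + |A ∩ B|, the Dice coefficient 2|A ∩ B| / (|A| + |B|) can be recovered from the accumulated intersection and union alone.

```python
import numpy as np


class RunningDice:
    """Hypothetical accumulator: sums per-class intersection and union over
    batches, then derives the Dice coefficient from those two totals."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        # Running per-class totals, one entry per label in [0, num_classes).
        self.intersection = np.zeros(self.num_classes, dtype=np.int64)
        self.union = np.zeros(self.num_classes, dtype=np.int64)

    def update(self, pred_cls, true_cls):
        # Flatten categorical label maps (any shape) to 1-D integer arrays.
        pred = np.asarray(pred_cls).ravel()
        true = np.asarray(true_cls).ravel()
        n = self.num_classes
        # |A ∩ B| per class: positions where prediction equals ground truth.
        inter = np.bincount(pred[pred == true], minlength=n)
        # |A ∪ B| = |A| + |B| - |A ∩ B| per class.
        union = np.bincount(pred, minlength=n) + np.bincount(true, minlength=n) - inter
        self.intersection += inter
        self.union += union

    def compute(self):
        # Dice = 2I / (|A| + |B|) = 2I / (U + I); classes never seen stay 0.
        denom = self.union + self.intersection
        dice = np.zeros(self.num_classes, dtype=float)
        np.divide(2 * self.intersection, denom, out=dice, where=denom > 0)
        return dice
```

Classes absent from both prediction and ground truth so far get a score of 0 in this sketch; whether to report 0, 1, or NaN for such classes is a design choice the source does not specify.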
- FastSurferCNN.utils.metrics.iou_score(pred_cls, true_cls, nclass=79)
Compute the intersection-over-union score.
Both inputs should be categorical (as opposed to one-hot).
- Parameters:
- pred_cls
torch.Tensor
Network prediction (categorical).
- true_cls
torch.Tensor
Ground truth (categorical).
- nclass
int
Number of classes (Default value = 79).
- Returns:
np.ndarray
An array containing the intersection for each class.
np.ndarray
An array containing the union for each class.
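To illustrate what the two returned arrays hold, the per-class intersection and union can be counted with np.bincount. The function name iou_counts is hypothetical and this is not FastSurfer's code; it assumes integer labels in [0, nclass) and NumPy arrays in place of torch tensors. The per-class IoU is then intersection / union wherever the union is non-zero.

```python
import numpy as np


def iou_counts(pred_cls, true_cls, nclass=79):
    """Hypothetical helper: per-class intersection and union of two
    categorical label maps (integer labels assumed to lie in [0, nclass))."""
    pred = np.asarray(pred_cls).ravel()
    true = np.asarray(true_cls).ravel()
    # Intersection: count, per class, the positions where both maps agree.
    intersection = np.bincount(pred[pred == true], minlength=nclass)
    # Union: |pred == c| + |true == c| - |both == c|.
    union = (np.bincount(pred, minlength=nclass)
             + np.bincount(true, minlength=nclass)
             - intersection)
    return intersection, union


inter, union = iou_counts(np.array([0, 1, 1, 2]), np.array([0, 1, 2, 2]), nclass=3)
# inter -> [1, 1, 1], union -> [1, 2, 2], so IoU per class is [1.0, 0.5, 0.5]
iou = np.divide(inter, union, out=np.zeros(3), where=union > 0)
```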
- FastSurferCNN.utils.metrics.precision_recall(pred_cls, true_cls, nclass=79)
Compute the per-class counts needed for recall, TP / (TP + FN), and precision, TP / (TP + FP).
- Parameters:
- pred_cls
torch.Tensor
Network prediction (categorical).
- true_cls
torch.Tensor
Ground truth (categorical).
- nclass
int
Number of classes (Default value = 79).
- Returns:
np.ndarray
An array containing the number of true positives for each class.
np.ndarray
An array containing the sum of true positives and false negatives for each class.
np.ndarray
An array containing the sum of true positives and false positives for each class.
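A minimal sketch of how the three returned count arrays relate to recall and precision, assuming NumPy arrays with integer labels in [0, nclass). The names precision_recall_counts and safe_ratio are hypothetical, and this is not FastSurfer's implementation. TP + FN is simply the number of ground-truth elements of each class, and TP + FP the number of predicted elements of each class.

```python
import numpy as np


def precision_recall_counts(pred_cls, true_cls, nclass=79):
    """Hypothetical helper returning per-class TP, TP + FN and TP + FP."""
    pred = np.asarray(pred_cls).ravel()
    true = np.asarray(true_cls).ravel()
    tp = np.bincount(pred[pred == true], minlength=nclass)
    tp_fn = np.bincount(true, minlength=nclass)  # every ground-truth element of class c
    tp_fp = np.bincount(pred, minlength=nclass)  # every predicted element of class c
    return tp, tp_fn, tp_fp


def safe_ratio(num, den):
    """Element-wise num / den, returning 0 where den == 0 (classes never seen)."""
    out = np.zeros(num.shape, dtype=float)
    np.divide(num, den, out=out, where=den > 0)
    return out


tp, tp_fn, tp_fp = precision_recall_counts(
    np.array([0, 1, 1, 2]), np.array([0, 1, 2, 2]), nclass=4)
recall = safe_ratio(tp, tp_fn)     # TP / (TP + FN)
precision = safe_ratio(tp, tp_fp)  # TP / (TP + FP)
```

Returning raw counts rather than ratios lets a caller sum them over many batches or volumes before dividing, which avoids averaging per-batch ratios. Reporting 0 for a class with no ground-truth or predicted elements is a choice made only in this sketch.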