CerebNet.utils.meters

class CerebNet.utils.meters.Meter(cfg, mode, global_step, total_iter=None, total_epoch=None, class_names=None, device=None, writer=None)

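Meter for tracking and logging training or validation metrics: losses, mean Dice score, confusion matrix and learning rate.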

Methods

log_epoch(cur_epoch)

Log mean Dice score and confusion matrix at the end of each epoch.

log_iter(cur_iter, cur_epoch)

Log training or validation progress at each iteration.

log_lr(lr[, step])

Log learning rate at each step.

prediction_visualize(cur_iter, cur_epoch, ...)

Visualize prediction results for current iteration and epoch.

reset()

Reset the accumulated statistics.

update_stats(pred, labels[, loss_dict])

Update the accumulated statistics with a batch of predictions and labels.

write_summary(loss_dict)

Write loss values to the summary writer.

log_epoch(cur_epoch)

Log mean Dice score and confusion matrix at the end of each epoch.

Parameters:
cur_epoch : int

Current epoch number.

Returns:
dice_score : float

The mean Dice score for the non-background classes.
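The mean Dice score can be derived directly from an accumulated confusion matrix. The sketch below is not CerebNet's implementation; it assumes rows index ground-truth classes, columns index predicted classes, and class 0 is background. The function name `mean_dice_from_confusion` is hypothetical.

```python
import numpy as np


def mean_dice_from_confusion(conf_matrix: np.ndarray) -> float:
    """Mean Dice over non-background classes from a (C, C) confusion matrix.

    Assumptions (not from the source): rows = ground truth, columns = prediction,
    class index 0 = background.
    """
    conf = conf_matrix.astype(np.float64)
    tp = np.diag(conf)              # voxels correctly assigned to each class
    fp = conf.sum(axis=0) - tp      # predicted as class c, labelled otherwise
    fn = conf.sum(axis=1) - tp      # labelled class c, predicted otherwise
    denom = 2 * tp + fp + fn
    # Dice = 2TP / (2TP + FP + FN); a class absent from both maps gives 0/0 -> NaN
    with np.errstate(invalid="ignore", divide="ignore"):
        dice = 2 * tp / denom
    # Drop background (index 0) and ignore NaN entries for absent classes
    return float(np.nanmean(dice[1:]))
```

Classes that appear in neither the ground truth nor the prediction produce 0/0 and are skipped by `np.nanmean` instead of being counted as 0; this is one of several reasonable conventions.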

log_iter(cur_iter, cur_epoch)

Log training or validation progress at each iteration.

Parameters:
cur_iter : int

The current iteration number.

cur_epoch : int

The current epoch number.

log_lr(lr, step=None)

Log learning rate at each step.

Parameters:
lr : list

Learning rate at the current step, given as a list whose first element is the learning rate to log.

step : int, optional

Current step number. If not provided, the global iteration number is used.
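The documented fallback for `step` can be sketched as follows. This assumes `writer` follows the `add_scalar(tag, value, step)` convention of TensorBoard's `SummaryWriter`, which the source does not confirm; `RecordingWriter`, the free function `log_lr`, and the tag `"train/lr"` are all hypothetical.

```python
from typing import Optional, Sequence


class RecordingWriter:
    """Hypothetical stand-in for a summary writer with add_scalar(tag, value, step)."""

    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, value, step):
        self.scalars.append((tag, value, step))


def log_lr(writer, lr: Sequence[float], global_iter: int, step: Optional[int] = None) -> None:
    """Log lr[0], using the global iteration number when no step is given."""
    if step is None:
        step = global_iter  # documented fallback: the global iteration number
    writer.add_scalar("train/lr", lr[0], step)  # tag name is an assumption
```

For example, `log_lr(writer, [0.01], global_iter=120)` records the value at step 120, while passing `step=7` overrides it.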

prediction_visualize(cur_iter, cur_epoch, img_batch, label_batch, pred_batch)

Visualize prediction results for current iteration and epoch.

Parameters:
cur_iter : int

Current iteration number.

cur_epoch : int

Current epoch number.

img_batch : torch.Tensor

Input image batch.

label_batch : torch.Tensor

Ground truth label batch.

pred_batch : torch.Tensor

Predicted label batch.
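A minimal way to build a comparison image from these three batches is to place one sample's image, ground truth and prediction side by side. This is a hypothetical sketch, not CerebNet's visualization: it assumes 2D slices with shape (N, H, W) that have already been moved to the CPU and converted with `.cpu().numpy()`, and the function name `make_panel` is invented for illustration.

```python
import numpy as np


def make_panel(img_batch: np.ndarray, label_batch: np.ndarray,
               pred_batch: np.ndarray, index: int = 0) -> np.ndarray:
    """Concatenate image, ground truth and prediction of one sample horizontally.

    Assumes all three batches have shape (N, H, W).
    """
    img = img_batch[index].astype(np.float64)
    # Rescale intensities to [0, 1]; a constant image becomes all zeros
    rng = img.max() - img.min()
    img = (img - img.min()) / rng if rng > 0 else np.zeros_like(img)
    # Scale label IDs by a shared maximum so both label maps use the same scale
    max_label = max(int(label_batch[index].max()), int(pred_batch[index].max()), 1)
    label = label_batch[index] / max_label
    pred = pred_batch[index] / max_label
    return np.concatenate([img, label, pred], axis=1)  # shape (H, 3 * W)
```

The resulting array could then be written to an image logger or plotted; using one scale for both label maps keeps identical label IDs visually identical.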

reset()

Reset the accumulated statistics.

update_stats(pred, labels, loss_dict=None)

Update the accumulated statistics with a batch of predictions and labels.

Parameters:
pred : torch.Tensor

Predicted labels.

labels : torch.Tensor

Ground truth labels.

loss_dict : dict, optional

Dictionary containing loss values.
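One common way to accumulate segmentation statistics across batches is a running confusion matrix, which is also what `log_epoch` reports. The sketch below is hypothetical (the name `update_confusion` is not from CerebNet) and assumes integer label maps with IDs in [0, C), already converted from `torch.Tensor` with `.cpu().numpy()`.

```python
import numpy as np


def update_confusion(conf_matrix: np.ndarray, pred: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Add one batch to a running (C, C) confusion matrix in place.

    Rows index ground-truth classes, columns index predicted classes.
    """
    n_classes = conf_matrix.shape[0]
    gt = labels.ravel().astype(np.int64)
    pr = pred.ravel().astype(np.int64)
    # Encode each (gt, pred) pair as one flat index and count all pairs at once
    counts = np.bincount(gt * n_classes + pr, minlength=n_classes * n_classes)
    conf_matrix += counts.reshape(n_classes, n_classes)
    return conf_matrix
```

With this representation, a matching reset only needs to zero the matrix (`conf_matrix[...] = 0`), and the memory cost stays at C × C regardless of how many batches are seen.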

write_summary(loss_dict)

Write loss values to the summary writer.

Parameters:
loss_dict : dict

Dictionary containing loss values.

class CerebNet.utils.meters.TestMeter(classname_to_ids)

Meter for computing per-class segmentation metrics at test time.

Methods

metrics_per_class(pred, gt)

Compute metrics for each class in the predicted and ground truth segmentation maps.

metrics_per_class(pred, gt)

Compute metrics for each class in the predicted and ground truth segmentation maps.

Parameters:
pred : np.ndarray

Predicted segmentation map.

gt : np.ndarray

Ground truth segmentation map.

Returns:
metrics : dict

Dict containing metrics for each class.
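The exact metrics returned are not listed here. As an illustration only, the sketch below computes a per-class Dice score, assuming `classname_to_ids` maps each class name to one label ID or a list of IDs that are merged into a single binary mask. It is written as a free function that takes the mapping explicitly rather than reading it from a `TestMeter` instance, and the output layout `{name: {"dice": ...}}` is hypothetical.

```python
import numpy as np


def metrics_per_class(pred: np.ndarray, gt: np.ndarray, classname_to_ids: dict) -> dict:
    """Per-class Dice for each named class (illustrative sketch, not CerebNet's code).

    Assumption: each value of classname_to_ids is a label ID or a list of IDs.
    """
    metrics = {}
    for name, ids in classname_to_ids.items():
        ids = np.atleast_1d(ids)
        p = np.isin(pred, ids)  # binary mask of the class in the prediction
        g = np.isin(gt, ids)    # binary mask of the class in the ground truth
        inter = np.logical_and(p, g).sum()
        total = p.sum() + g.sum()
        # Dice = 2|P ∩ G| / (|P| + |G|); undefined (NaN) if the class is absent from both
        dice = 2.0 * inter / total if total > 0 else float("nan")
        metrics[name] = {"dice": float(dice)}
    return metrics
```

Merging several IDs into one mask lets a coarse structure be scored alongside its subregions; other metrics (for example, surface distances) would follow the same per-class loop.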