FastSurferCNN.train

class FastSurferCNN.train.Trainer(cfg)

Trainer for the networks.

Methods

__init__(cfg)

Construct the Trainer object.

train(train_loader, optimizer, scheduler, ...)

Train the network on the given training data.

eval(val_loader, val_meter, epoch)

Evaluate the model and calculate stats.

run()

Transfer the model to the device(s), create a TensorBoard summary writer, and then run the training loop.

eval(val_loader, val_meter, epoch)

Evaluate the model and calculate stats.

Parameters:
val_loader : loader.DataLoader

Data loader for the validation data.

val_meter : Meter

Meter to keep track of the validation stats.

epoch : int

Epoch to evaluate.

Returns:
int, float, ndarray

Median mean intersection-over-union (IoU) value.
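For reference, a minimal stdlib-only sketch of per-class intersection over union (IoU) and its mean over the classes that actually occur, computed on flattened label sequences. The function names per_class_iou and mean_iou are hypothetical; how eval() accumulates intersections and unions across batches and reduces them to the reported median value is not shown here and may differ.

```python
import math
from typing import List, Sequence


def per_class_iou(pred: Sequence[int], target: Sequence[int], num_classes: int) -> List[float]:
    """Hypothetical helper: IoU_c = |pred==c AND target==c| / |pred==c OR target==c|."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(pred, target):
        if not (0 <= p < num_classes and 0 <= t < num_classes):
            raise ValueError("label out of range")
        if p == t:
            # Agreement: the voxel counts once in both intersection and union of class p.
            inter[p] += 1
            union[p] += 1
        else:
            # Disagreement: the voxel belongs to the union of both classes involved.
            union[p] += 1
            union[t] += 1
    # A class absent from both prediction and target has an undefined IoU (NaN).
    return [i / u if u else math.nan for i, u in zip(inter, union)]


def mean_iou(pred: Sequence[int], target: Sequence[int], num_classes: int) -> float:
    """Hypothetical helper: average IoU over classes with a defined IoU."""
    ious = [v for v in per_class_iou(pred, target, num_classes) if not math.isnan(v)]
    return sum(ious) / len(ious) if ious else math.nan
```

For pred = [0, 0, 1, 1] and target = [0, 1, 1, 1], class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12.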

run()

Transfer the model to the device(s), create a TensorBoard summary writer, and then run the training loop.
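The training loop described above can be sketched as a plain epoch loop, assuming the model is evaluated once after every training epoch. The function run_training_loop and the callables train_one_epoch, evaluate and log_scalar are hypothetical stand-ins for Trainer.train(), Trainer.eval() and the summary writer; in a PyTorch setup the device transfer would typically be model.to(device), and log_scalar could be the add_scalar method of torch.utils.tensorboard.SummaryWriter.

```python
from typing import Callable, List


def run_training_loop(
    num_epochs: int,
    train_one_epoch: Callable[[int], None],
    evaluate: Callable[[int], float],
    log_scalar: Callable[[str, float, int], None],
) -> List[float]:
    """Hypothetical skeleton of the run() flow: train, then evaluate, once per epoch."""
    history: List[float] = []
    for epoch in range(num_epochs):
        train_one_epoch(epoch)  # stands in for Trainer.train(..., epoch)
        score = evaluate(epoch)  # stands in for Trainer.eval(..., epoch)
        # Same (tag, value, step) argument order as SummaryWriter.add_scalar.
        log_scalar("val/mean_iou", score, epoch)
        history.append(score)
    return history
```

Passing the steps in as callables keeps the sketch independent of torch, so the control flow can be checked with simple fakes.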

train(train_loader, optimizer, scheduler, train_meter, epoch)

Train the network on the given training data.

Parameters:
train_loader : loader.DataLoader

Data loader for the training.

optimizer : torch.optim.Optimizer

Optimizer for the training.

scheduler : None, scheduler.StepLR, scheduler.CosineAnnealingWarmRestarts

LR scheduler for the training.

train_meter : Meter

Meter to keep track of the training stats.

epoch : int

Current epoch.
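The three options accepted for scheduler differ in how the learning rate evolves. As an illustration only, here is a stdlib sketch of the closed-form learning rate that torch.optim.lr_scheduler.StepLR and CosineAnnealingWarmRestarts yield after a given number of scheduler steps, assuming one step per epoch (whether train() steps the scheduler per epoch or per iteration is not stated here). None is taken to mean a constant learning rate. The function names, the kind strings and the step_size, gamma, t_0 and t_mult values are hypothetical.

```python
import math
from typing import Optional


def step_lr(base_lr: float, epoch: int, step_size: int, gamma: float = 0.1) -> float:
    """LR after `epoch` StepLR steps: multiplied by `gamma` every `step_size` steps."""
    return base_lr * gamma ** (epoch // step_size)


def cosine_warm_restarts_lr(
    base_lr: float, epoch: int, t_0: int, t_mult: int = 1, eta_min: float = 0.0
) -> float:
    """LR after `epoch` (integer) steps of cosine annealing with warm restarts."""
    if t_0 <= 0 or t_mult < 1:
        raise ValueError("t_0 must be positive and t_mult >= 1")
    # Locate the current cycle; cycle lengths are t_0, t_0 * t_mult, t_0 * t_mult**2, ...
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    # Within a cycle, anneal from base_lr towards eta_min along a half cosine.
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_i)) / 2


def lr_for_epoch(kind: Optional[str], base_lr: float, epoch: int) -> float:
    """Hypothetical dispatcher mirroring the None / StepLR / CosineAnnealingWarmRestarts choice."""
    if kind is None:
        return base_lr  # no scheduler: constant learning rate
    if kind == "StepLR":
        return step_lr(base_lr, epoch, step_size=5, gamma=0.1)  # illustrative values
    if kind == "CosineAnnealingWarmRestarts":
        return cosine_warm_restarts_lr(base_lr, epoch, t_0=10, t_mult=2)  # illustrative values
    raise ValueError(f"unknown scheduler kind: {kind!r}")
```

StepLR gives a staircase decay, while the warm-restart schedule periodically jumps back to base_lr; with t_mult=2 each cycle is twice as long as the previous one.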