HypVINN.models.networks
- class HypVINN.models.networks.HypVINN(params, padded_size=256)
HypVINN model class, extending FastSurferCNNBase.
This class implements the HypVINN model. It provides methods for initializing the model, setting up its layers, and performing forward propagation.
Attributes
height
(int) The height of the output tensor.
width
(int) The width of the output tensor.
out_tensor_shape
(tuple) The shape of the output tensor.
interpolation_mode
(str) The interpolation mode to use when resizing the images. This can be ‘nearest’, ‘bilinear’, ‘bicubic’, or ‘area’.
crop_position
(str) The position to crop the images from. This can be ‘center’, ‘top_left’, ‘top_right’, ‘bottom_left’, or ‘bottom_right’.
m1_inp_block
(InputDenseBlock) The input block for the first modality.
m2_inp_block
(InputDenseBlock) The input block for the second modality.
mod_weights
(nn.Parameter) The weights for the two modalities.
normalize_weights
(nn.Softmax) A softmax function to normalize the modality weights.
outp_block
(OutputDenseBlock) The output block of the model.
interpol1
(Zoom2d) The first interpolation layer.
interpol2
(Zoom2d) The second interpolation layer.
classifier
(ClassifierBlock) The final classifier block of the model.
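The `mod_weights` and `normalize_weights` attributes above suggest how the two modality streams are fused: a learnable weight per modality, passed through a softmax so the weights sum to one. The following is an illustrative sketch of that pattern only, assuming dummy feature maps; it is not the actual HypVINN code.

```python
import torch
import torch.nn as nn

# Sketch of softmax-normalized modality weighting (assumed pattern, not
# the actual HypVINN implementation). One learnable weight per modality.
mod_weights = nn.Parameter(torch.zeros(2))
normalize_weights = nn.Softmax(dim=0)

m1_features = torch.randn(1, 8, 16, 16)  # dummy modality-1 feature map
m2_features = torch.randn(1, 8, 16, 16)  # dummy modality-2 feature map

w = normalize_weights(mod_weights)       # softmax of zeros -> [0.5, 0.5]
fused = w[0] * m1_features + w[1] * m2_features
```

Because the weights are a `nn.Parameter`, the relative importance of the two modalities can be learned during training rather than fixed by hand.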
Methods
forward(x, scale_factor, weight_factor, scale_factor_out=None)
Perform forward propagation through the model.
- forward(x, scale_factor, weight_factor, scale_factor_out=None)
Forward propagation method for the HypVINN model.
This method takes an input tensor, a scale factor, a weight factor, and an optional output scale factor. It performs forward propagation through the model, applying the input blocks, interpolation layers, output block, and classifier block. It also handles the weighting of the two modalities and the rescaling of the output.
- Parameters:
- x
torch.Tensor
The input tensor. It should have a shape of (batch_size, num_channels, height, width).
- scale_factor
torch.Tensor
The scale factor for the input images. It should have a shape of (batch_size, 2).
- weight_factor
torch.Tensor
The weight factor for the two modalities. It should have a shape of (batch_size, 2).
- scale_factor_out
torch.Tensor, optional
The scale factor for the output images. If not provided, it defaults to the scale factor of the input images.
- Returns:
- logits
torch.Tensor
The output logits from the classifier block. It has a shape of (batch_size, num_classes, height, width).
- Raises:
ValueError
If the interpolation mode or crop position is invalid.
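The input shapes above can be sketched as follows. The `batch_size` and `num_channels` values are placeholders for illustration; constructing a real model instance requires a full configuration, so the forward call itself is shown only as a comment.

```python
import torch

# Illustrative input tensors matching the forward() parameter shapes
# documented above (placeholder sizes, not the actual config values).
batch_size, num_channels, height, width = 1, 14, 256, 256
x = torch.randn(batch_size, num_channels, height, width)
scale_factor = torch.ones(batch_size, 2)          # per-axis input scaling
weight_factor = torch.full((batch_size, 2), 0.5)  # equal modality weighting

# With a constructed model, the call would look like:
# logits = model(x, scale_factor, weight_factor)
# logits.shape -> (batch_size, num_classes, height, width)
```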
- HypVINN.models.networks.build_model(cfg)
Build and return the requested model.
- Parameters:
- cfg
yacs.config.CfgNode
The configuration node containing the parameters for the model.
- Returns:
HypVINN
An instance of the requested model.
- Raises:
AssertionError
If the model specified in the configuration is not supported.
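The look-up-and-assert behavior described above can be sketched as a minimal dispatch function. The registry mapping and function name here are assumptions for illustration, not the actual FastSurfer code, which reads the model name from the yacs configuration node.

```python
# Sketch of the dispatch pattern build_model implements: resolve the
# requested model name, asserting it is supported (hypothetical names).
_SUPPORTED = {"HypVINN": dict}  # stand-in for the HypVINN class

def build_model_sketch(model_name):
    # Mirrors the documented AssertionError for unsupported models.
    assert model_name in _SUPPORTED, f"Model {model_name} not supported"
    return _SUPPORTED[model_name]()

model = build_model_sketch("HypVINN")
```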