Networks Module

Overview

The networks module contains neural network architectures used in neuroLIT.

Submodules

DiffusionUnet

class neurolit.networks.DiffusionUnet.DiffusionModelUNetVINN(spatial_dims: int, in_channels: int, out_channels: int, internal_size: Sequence[int], num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-06, resblock_updown: bool = False, num_head_channels: int | Sequence[int] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, interpolation_mode: str = 'trilinear', use_fp16_VINN: bool = False, is_vinn: bool = False)

Bases: Module

U-Net with timestep embedding and attention mechanisms for conditioning, based on Rombach et al., “High-Resolution Image Synthesis with Latent Diffusion Models” (https://arxiv.org/abs/2112.10752), and Pinaya et al., “Brain Imaging Generation with Latent Diffusion Models” (https://arxiv.org/abs/2209.07162).

Parameters:
  • spatial_dims (int) – number of spatial dimensions.

  • in_channels (int) – number of input channels.

  • out_channels (int) – number of output channels.

  • num_res_blocks (int | Sequence[int]) – number of residual blocks (see ResnetBlock) per level.

  • num_channels (Sequence[int]) – output channels of the block at each level.

  • attention_levels (Sequence[bool]) – one flag per level; True adds attention at that level.

  • norm_num_groups (int) – number of groups for the normalization.

  • norm_eps (float) – epsilon for the normalization.

  • resblock_updown (bool) – if True, use residual blocks for up/downsampling.

  • num_head_channels (int | Sequence[int]) – number of channels in each attention head.

  • with_conditioning (bool) – if True, add spatial transformers to perform conditioning.

  • transformer_num_layers (int) – number of layers of Transformer blocks to use.

  • cross_attention_dim (int, optional) – number of context dimensions to use.

  • num_class_embeds (int, optional) – if specified, the model is class-conditional with num_class_embeds classes.

  • upcast_attention (bool) – if True, upcast attention operations to full precision.

  • use_flash_attention (bool) – if True, use flash attention for a memory-efficient attention mechanism.
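The parameter descriptions imply several consistency constraints between arguments. The sketch below checks them in plain Python before a model is built. check_unet_config is a hypothetical helper, not part of neuroLIT, and whether DiffusionModelUNetVINN validates its arguments in exactly this way is an assumption (the docstring appears to follow MONAI's DiffusionModelUNet, which performs similar checks).

```python
from typing import Optional, Sequence, Tuple, Union


def check_unet_config(
    num_channels: Sequence[int] = (32, 64, 64, 64),
    attention_levels: Sequence[bool] = (False, False, True, True),
    num_res_blocks: Union[int, Sequence[int]] = (2, 2, 2, 2),
    num_head_channels: Union[int, Sequence[int]] = 8,
    norm_num_groups: int = 32,
    with_conditioning: bool = False,
    cross_attention_dim: Optional[int] = None,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Validate a U-Net configuration; return per-level res blocks and head channels."""
    n_levels = len(num_channels)
    # Conditioning needs a context dimension, and a context dimension needs conditioning.
    if with_conditioning and cross_attention_dim is None:
        raise ValueError("with_conditioning=True requires cross_attention_dim")
    if cross_attention_dim is not None and not with_conditioning:
        raise ValueError("cross_attention_dim is only used when with_conditioning=True")
    # GroupNorm splits every level's channels into norm_num_groups equal groups.
    if any(c % norm_num_groups != 0 for c in num_channels):
        raise ValueError("every entry of num_channels must be divisible by norm_num_groups")
    if len(attention_levels) != n_levels:
        raise ValueError("attention_levels needs one flag per level")
    # Broadcast scalar arguments to one value per level.
    if isinstance(num_res_blocks, int):
        num_res_blocks = (num_res_blocks,) * n_levels
    if isinstance(num_head_channels, int):
        num_head_channels = (num_head_channels,) * n_levels
    if len(num_res_blocks) != n_levels or len(num_head_channels) != n_levels:
        raise ValueError("num_res_blocks and num_head_channels need one value per level")
    # Levels with attention must split their channels evenly into heads.
    for c, use_attn, h in zip(num_channels, attention_levels, num_head_channels):
        if use_attn and c % h != 0:
            raise ValueError(f"{c} channels cannot be split into heads of {h} channels")
    return tuple(num_res_blocks), tuple(num_head_channels)
```

With the defaults, every level has 32 or 64 channels, which both 32 normalization groups and 8-channel attention heads divide evenly.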

forward(x: Tensor, scale_factors: float, timesteps: Tensor, context: Tensor | None = None, class_labels: Tensor | None = None, down_block_additional_residuals: tuple[Tensor] | None = None, mid_block_additional_residual: Tensor | None = None) → Tensor

Parameters:
  • x (torch.Tensor) – Input tensor (N, C, SpatialDims).

  • timesteps (torch.Tensor) – Timestep tensor (N,).

  • context (torch.Tensor, optional) – Context tensor (N, 1, ContextDim).

  • class_labels (torch.Tensor, optional) – Class labels tensor (N,).

  • down_block_additional_residuals (tuple[torch.Tensor], optional) – Additional residual tensors for down blocks (N, C, FeatureMapsDims).

  • mid_block_additional_residual (torch.Tensor, optional) – Additional residual tensor for mid block (N, C, FeatureMapsDims).
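The shapes listed for forward() can be checked before calling the model. This is a minimal sketch on plain shape tuples; check_forward_shapes is hypothetical, it follows the documented shapes literally, and it assumes that ContextDim must equal cross_attention_dim.

```python
from typing import Optional, Sequence


def check_forward_shapes(
    x_shape: Sequence[int],
    timesteps_shape: Sequence[int],
    spatial_dims: int,
    context_shape: Optional[Sequence[int]] = None,
    cross_attention_dim: Optional[int] = None,
    class_labels_shape: Optional[Sequence[int]] = None,
) -> int:
    """Check the documented forward() input shapes and return the batch size N."""
    if len(x_shape) != 2 + spatial_dims:
        raise ValueError(f"x must be (N, C, *spatial) with {spatial_dims} spatial dims")
    n = x_shape[0]
    if tuple(timesteps_shape) != (n,):
        raise ValueError("timesteps must have shape (N,)")
    if context_shape is not None:
        # Documented as (N, 1, ContextDim); ContextDim == cross_attention_dim is assumed.
        if len(context_shape) != 3 or context_shape[0] != n or context_shape[1] != 1:
            raise ValueError("context must have shape (N, 1, ContextDim)")
        if cross_attention_dim is not None and context_shape[2] != cross_attention_dim:
            raise ValueError("context's last dimension must equal cross_attention_dim")
    if class_labels_shape is not None and tuple(class_labels_shape) != (n,):
        raise ValueError("class_labels must have shape (N,)")
    return n
```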

U-Net architecture for diffusion-based inpainting.

Key Classes:

  • DiffusionModelUNetVINN: Main U-Net model

  • TimeEmbedding: Time embedding for diffusion steps

  • ResidualBlock: Residual block with time conditioning
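Timestep conditioning in latent-diffusion U-Nets usually starts from a sinusoidal embedding of the integer timestep, which a small MLP then projects to the block width. The sketch below computes that embedding for a single timestep in plain Python, using the cosine-then-sine layout of the guided-diffusion code base. Whether neuroLIT's time embedding uses this exact layout and max_period is an assumption, and timestep_embedding is a hypothetical name.

```python
import math
from typing import List


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Sinusoidal embedding of one timestep t into a vector of length dim."""
    half = dim // 2
    # Geometric ladder of frequencies from 1 down towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2 == 1:
        emb.append(0.0)  # zero-pad odd embedding sizes
    return emb
```

Low-index entries oscillate quickly with t and high-index entries slowly, so nearby timesteps get similar but distinguishable vectors.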

interpolation_layer

class neurolit.networks.interpolation_layer.Zoom2d(interpolation_mode: str = 'nearest', crop_position: str = 'top_left')

Bases: _ZoomNd

Perform a crop and interpolation on a four-dimensional tensor (N, C, H, W), leaving the batch and channel dimensions unchanged.

_N

Number of spatial dimensions (here 2).

_crop_position

Position to crop

_interpolate()

Crops, interpolates, and pads the tensor.

__init__(interpolation_mode: str = 'nearest', crop_position: str = 'top_left')

Construct Zoom2d object.

Parameters:
  • target_shape (_T.Optional[_T.Sequence[int]]) – Target tensor size after this module, excluding the batch and channel dimensions.

  • interpolation_mode (str) – interpolation mode as in torch.nn.functional.interpolate (default: ‘nearest’)

  • crop_position (str) – crop position to use: ‘top_left’, ‘bottom_left’, ‘top_right’, ‘bottom_right’, or ‘center’ (default: ‘top_left’)
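Each crop position names the corner (or the center) of the input that the crop window is anchored to. A plain-Python sketch of the resulting window offsets, assuming a window no larger than the input and rounding the center offset down; crop_offsets_2d is hypothetical and is not Zoom2d's internal code.

```python
from typing import Tuple

_VALID_2D = {"top_left", "bottom_left", "top_right", "bottom_right", "center"}


def crop_offsets_2d(in_hw: Tuple[int, int], crop_hw: Tuple[int, int], position: str = "top_left") -> Tuple[int, int]:
    """Return (row, col) of the crop window's upper-left corner inside an (H, W) input."""
    if position not in _VALID_2D:
        raise ValueError(f"unknown crop_position {position!r}")
    (h, w), (ch, cw) = in_hw, crop_hw
    if ch > h or cw > w:
        raise ValueError("crop window is larger than the input")
    if position == "center":
        return (h - ch) // 2, (w - cw) // 2  # odd slack rounds toward the top-left
    vertical, horizontal = position.split("_")
    row = 0 if vertical == "top" else h - ch
    col = 0 if horizontal == "left" else w - cw
    return row, col
```

For a 200×180 window in a 256×256 slice, ‘bottom_right’ starts the window at row 56, column 76.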

class neurolit.networks.interpolation_layer.Zoom3d(interpolation_mode: str = 'nearest', crop_position: str = 'front_top_left')

Bases: _ZoomNd

Perform a crop and interpolation on a five-dimensional tensor (N, C, D, H, W), leaving the batch and channel dimensions unchanged.

__init__(interpolation_mode: str = 'nearest', crop_position: str = 'front_top_left')

Construct Zoom3d object.

Parameters:
  • target_shape (_T.Optional[_T.Sequence[int]]) – Target tensor size after this module, excluding the batch and channel dimensions.

  • interpolation_mode (str) – interpolation mode as in torch.nn.functional.interpolate (default: ‘nearest’)

  • crop_position (str) – crop position to use: ‘front_top_left’, ‘back_top_left’, ‘front_bottom_left’, ‘back_bottom_left’, ‘front_top_right’, ‘back_top_right’, ‘front_bottom_right’, ‘back_bottom_right’, or ‘center’ (default: ‘front_top_left’)
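Zoom3d's position names combine one word per spatial axis. The sketch below parses them into window offsets, assuming front/back refer to the first spatial axis (D), top/bottom to H and left/right to W, with front, top and left at index 0. That axis mapping and the helper crop_offsets_3d are assumptions, not taken from the neuroLIT source.

```python
from typing import Sequence, Tuple

# Assumed mapping: word -> (axis in a (D, H, W) volume, anchored at index 0?).
_AXIS_WORDS = {
    "front": (0, True), "back": (0, False),
    "top": (1, True), "bottom": (1, False),
    "left": (2, True), "right": (2, False),
}


def crop_offsets_3d(in_shape: Sequence[int], crop_shape: Sequence[int], position: str = "front_top_left") -> Tuple[int, ...]:
    """Return the (d, h, w) start index of the crop window inside a 3D volume."""
    if len(in_shape) != 3 or len(crop_shape) != 3:
        raise ValueError("expected 3D shapes (D, H, W)")
    slack = [s - c for s, c in zip(in_shape, crop_shape)]
    if min(slack) < 0:
        raise ValueError("crop window is larger than the input")
    if position == "center":
        return tuple(m // 2 for m in slack)
    words = position.split("_")
    # Every word must be known, and together they must name each axis exactly once.
    if any(wd not in _AXIS_WORDS for wd in words) or sorted(_AXIS_WORDS[wd][0] for wd in words) != [0, 1, 2]:
        raise ValueError(f"unknown crop_position {position!r}")
    offsets = [0, 0, 0]
    for wd in words:
        axis, at_start = _AXIS_WORDS[wd]
        offsets[axis] = 0 if at_start else slack[axis]
    return tuple(offsets)
```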

class neurolit.networks.interpolation_layer.OnlyZoom3d(interpolation_mode: str = 'nearest')

Bases: Module

Perform a crop and interpolation on a five-dimensional tensor (N, C, D, H, W), leaving the batch and channel dimensions unchanged.

__init__(interpolation_mode: str = 'nearest')

Construct OnlyZoom3d object.

Parameters:
  • target_shape (_T.Optional[_T.Sequence[int]]) – Target tensor size after this module, excluding the batch and channel dimensions.

  • interpolation_mode (str) – interpolation mode as in torch.nn.functional.interpolate (default: ‘nearest’)

  • crop_position (str) – crop position to use: ‘front_top_left’, ‘back_top_left’, ‘front_bottom_left’, ‘back_bottom_left’, ‘front_top_right’, ‘back_top_right’, ‘front_bottom_right’, ‘back_bottom_right’, or ‘center’ (default: ‘front_top_left’)

forward(input_tensor, out_shape)

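OnlyZoom3d.forward receives the output shape at call time. Assuming it resizes with torch.nn.functional.interpolate in the default 'nearest' mode, each output voxel copies the input voxel at floor(i * in_size / out_size) along every axis (the 'nearest-exact' mode samples pixel centers instead). The plain-Python sketch below reproduces that index rule on nested lists; nearest_indices and nearest_resize_3d are hypothetical helpers, and PyTorch's floating-point scale can differ from this integer version in rare edge cases.

```python
from typing import List, Sequence


def nearest_indices(in_size: int, out_size: int) -> List[int]:
    """Source index for each output index along one axis (floor rule of mode='nearest')."""
    return [min(i * in_size // out_size, in_size - 1) for i in range(out_size)]


def nearest_resize_3d(vol: Sequence[Sequence[Sequence[float]]], out_shape: Sequence[int]) -> List[List[List[float]]]:
    """Resize a nested-list volume (D, H, W) to out_shape by nearest-neighbour lookup."""
    d_idx = nearest_indices(len(vol), out_shape[0])
    h_idx = nearest_indices(len(vol[0]), out_shape[1])
    w_idx = nearest_indices(len(vol[0][0]), out_shape[2])
    return [[[vol[d][h][w] for w in w_idx] for h in h_idx] for d in d_idx]
```

Upsampling by 2 repeats every source index twice; downsampling by 2 keeps every second one.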

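Custom interpolation layers for spatial processing.

Key Classes:

  • Zoom2d: Crop and interpolation for four-dimensional (N, C, H, W) tensors

  • Zoom3d: Crop and interpolation for five-dimensional (N, C, D, H, W) tensors

  • OnlyZoom3d: Five-dimensional variant whose output shape is passed to forward()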

Architecture Details

DiffusionUNet Architecture

The DiffusionUNet is a modified U-Net with:

  • Encoder: Downsampling path with residual blocks

  • Bottleneck: Deepest layer with attention

  • Decoder: Upsampling path with skip connections

  • Time Conditioning: Timestep embeddings at each level

Input (1, 256, 256)                     Output (1, 256, 256)
     ↓                                          ↑
[Encoder Block 1] ─────── skip ───────→ [Decoder Block 1]
     ↓                                          ↑
[Encoder Block 2] ─────── skip ───────→ [Decoder Block 2]
     ↓                                          ↑
[Encoder Block 3] ─────── skip ───────→ [Decoder Block 3]
     ↓                                          ↑
     └─────────────→ [Bottleneck] ──────────────┘
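The shape bookkeeping behind the diagram can be computed directly. The sketch below assumes a generic U-Net in which every level except the deepest halves the spatial size, each encoder level outputs num_channels[i] channels, and each decoder block concatenates one skip and maps back to that skip's channel count. The real DiffusionModelUNetVINN blocks (several residual blocks per level, optional attention) differ in detail, and unet_shapes is hypothetical.

```python
from typing import List, Sequence, Tuple


def unet_shapes(size: Tuple[int, int], num_channels: Sequence[int]) -> Tuple[List[Tuple[int, int, int]], List[int]]:
    """Per-level encoder feature shapes (C, H, W) and decoder input channels after skip concat."""
    h, w = size
    encoder = []
    for level, c in enumerate(num_channels):
        encoder.append((c, h, w))
        if level < len(num_channels) - 1:  # no downsampling after the deepest level
            if h % 2 or w % 2:
                raise ValueError("spatial size must stay even at every downsampling step")
            h, w = h // 2, w // 2
    # Walk back up: each decoder block sees its input concatenated with the matching skip.
    decoder = []
    prev_c = encoder[-1][0]  # bottleneck output channels
    for c, _, _ in reversed(encoder[:-1]):
        decoder.append(prev_c + c)
        prev_c = c  # assumed: the decoder block maps back to the skip's channel count
    return encoder, decoder
```

With num_channels=(32, 64, 64, 64) and a 256×256 input, the deepest level works at 32×32 and the decoder blocks receive 128, 128 and 96 channels.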

Examples

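Using DiffusionModelUNetVINN

from neurolit.networks.DiffusionUnet import DiffusionModelUNetVINN
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create a 2D model; argument names follow the documented signature.
# internal_size=(256, 256) is a placeholder matching the example input size.
model = DiffusionModelUNetVINN(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    internal_size=(256, 256),
    num_channels=(32, 64, 64, 64),
    num_res_blocks=2,
    attention_levels=(False, False, True, True),  # one flag per level
)
model = model.to(device)
model.eval()

# Forward pass: input (N, C, H, W) and one timestep per sample, shape (N,)
x = torch.randn(1, 1, 256, 256, device=device)
t = torch.tensor([100], device=device)

with torch.no_grad():
    # scale_factors is required by forward(); 1.0 is a placeholder value
    output = model(x, scale_factors=1.0, timesteps=t)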

Loading Pre-trained Weights

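from neurolit.networks.DiffusionUnet import DiffusionModelUNetVINN
import torch

# Create a model with the configuration the checkpoint was trained with
# (the values below are placeholders)
model = DiffusionModelUNetVINN(
    spatial_dims=2, in_channels=1, out_channels=1, internal_size=(256, 256)
)

# Load checkpoint; map_location lets a GPU-saved checkpoint load on CPU
checkpoint = torch.load('model_axial.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()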
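Checkpoints are not always stored in the same layout: some hold the weights directly, others wrap them under a key such as 'model_state_dict', and weights saved from a model wrapped in nn.DataParallel or DistributedDataParallel carry a 'module.' prefix on every key. A minimal sketch for normalizing both cases; extract_state_dict is hypothetical and handles only these two layouts.

```python
from typing import Any, Dict


def extract_state_dict(checkpoint: Dict[str, Any], key: str = "model_state_dict") -> Dict[str, Any]:
    """Return a plain state dict from a raw or wrapped checkpoint, minus DataParallel prefixes."""
    state = checkpoint.get(key, checkpoint)  # fall back to treating the checkpoint as the state dict
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}
```

Usage: model.load_state_dict(extract_state_dict(torch.load('model_axial.pt', map_location='cpu'))).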

Custom Model Configuration

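from neurolit.networks.DiffusionUnet import DiffusionModelUNetVINN

# Create a larger custom model (the documented signature has no dropout argument)
model = DiffusionModelUNetVINN(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    internal_size=(256, 256),  # placeholder
    num_channels=(128, 256, 512, 1024),  # more channels per level
    num_res_blocks=3,  # more residual blocks per level
    attention_levels=(False, True, True, True),  # attention at the three deepest levels
)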
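attention_levels expects one boolean per entry of num_channels. If you track attention levels as 0-based level indices instead, a small conversion helper (hypothetical, not part of neuroLIT) produces the boolean tuple and rejects out-of-range indices.

```python
from typing import Iterable, Tuple


def attention_mask(levels: Iterable[int], num_levels: int) -> Tuple[bool, ...]:
    """Convert 0-based level indices (e.g. [2, 3]) into one bool per level."""
    chosen = set(levels)
    bad = [i for i in chosen if not 0 <= i < num_levels]
    if bad:
        raise ValueError(f"levels {sorted(bad)} out of range for {num_levels} levels")
    return tuple(i in chosen for i in range(num_levels))
```

For example, attention_mask([1, 2, 3], len(num_channels)) yields (False, True, True, True) for a four-level model.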

Training Example

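from neurolit.networks.DiffusionUnet import DiffusionModelUNetVINN
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"

# Setup (configuration values are placeholders)
model = DiffusionModelUNetVINN(
    spatial_dims=2, in_channels=1, out_channels=1, internal_size=(256, 256)
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Training loop; `dataloader` must be defined beforehand and yield dicts with
# 'image' (assumed to be the noised input), 'timestep' (shape (N,)) and 'noise'
model.train()
for batch in dataloader:
    x = batch['image'].to(device)
    t = batch['timestep'].to(device)
    noise = batch['noise'].to(device)

    # Forward pass: predict the noise that was added to x
    pred_noise = model(x, scale_factors=1.0, timesteps=t)
    loss = criterion(pred_noise, noise)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()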
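The loop above expects each batch to contain a noised image and the noise that produced it. In the standard DDPM formulation the noised input is x_t = sqrt(ᾱ_t) x_0 + sqrt(1 − ᾱ_t) ε, where ᾱ_t is the cumulative product of (1 − β_s) over a beta schedule. neuroLIT's schedule is not documented here, so the linear schedule from 1e-4 to 0.02 over 1000 steps below is an assumption, and linear_alpha_bars and add_noise are hypothetical helpers working on plain lists.

```python
import math
import random
from typing import List, Sequence, Tuple


def linear_alpha_bars(num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for betas spaced linearly from beta_start to beta_end."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    step = (beta_end - beta_start) / (num_steps - 1)
    alpha_bars, running = [], 1.0
    for t in range(num_steps):
        running *= 1.0 - (beta_start + t * step)
        alpha_bars.append(running)
    return alpha_bars


def add_noise(x0: Sequence[float], t: int, alpha_bars: Sequence[float], rng: random.Random) -> Tuple[List[float], List[float]]:
    """Sample eps ~ N(0, 1) per element; return (x_t, eps) with x_t = sqrt(ab) x0 + sqrt(1 - ab) eps."""
    ab = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(ab) * a + math.sqrt(1.0 - ab) * e for a, e in zip(x0, eps)]
    return x_t, eps
```

In the training loop, x_t plays the role of batch['image'] and eps the role of batch['noise'], the regression target for pred_noise.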

Model Summary

from neurolit.networks.DiffusionUnet import DiffusionModelUNetVINN

model = DiffusionModelUNetVINN(
    spatial_dims=2, in_channels=1, out_channels=1, internal_size=(256, 256)
)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")

# Count trainable parameters
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable:,}")
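To sanity-check these totals, per-layer counts follow simple formulas: a convolution has kernel^d · C_in · C_out weights plus C_out biases, and an affine GroupNorm learns one scale and one shift per channel regardless of norm_num_groups. For example, a 3×3 2D convolution from 1 to 32 channels has 3·3·1·32 + 32 = 320 parameters. conv_params and group_norm_params are hypothetical helpers that assume cubic kernels.

```python
def conv_params(in_ch: int, out_ch: int, kernel: int, spatial_dims: int = 2, bias: bool = True) -> int:
    """Parameters of a Conv{spatial_dims}d layer with a kernel of equal size along every axis."""
    return kernel ** spatial_dims * in_ch * out_ch + (out_ch if bias else 0)


def group_norm_params(channels: int, affine: bool = True) -> int:
    """An affine GroupNorm learns a weight and a bias of shape (channels,)."""
    return 2 * channels if affine else 0
```

Which layers contribute to the total depends on the configuration, so these formulas are most useful for checking individual blocks.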