Networks Module
Overview
The networks module contains neural network architectures used in neuroLIT.
Submodules
DiffusionUnet
- class neurolit.networks.DiffusionUnet.DiffusionModelUNetVINN(spatial_dims: int, in_channels: int, out_channels: int, internal_size: Sequence[int], num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-06, resblock_updown: bool = False, num_head_channels: int | Sequence[int] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, interpolation_mode: str = 'trilinear', use_fp16_VINN: bool = False, is_vinn: bool = False)[source]
Bases: Module
U-Net network with timestep embedding and attention mechanisms for conditioning, based on Rombach et al., “High-Resolution Image Synthesis with Latent Diffusion Models” (https://arxiv.org/abs/2112.10752) and Pinaya et al., “Brain Imaging Generation with Latent Diffusion Models” (https://arxiv.org/abs/2209.07162).
- Parameters:
spatial_dims (number of spatial dimensions.)
in_channels (number of input channels.)
out_channels (number of output channels.)
num_res_blocks (number of residual blocks (see ResnetBlock) per level.)
num_channels (tuple of block output channels.)
attention_levels (list of levels to add attention.)
norm_num_groups (number of groups for the normalization.)
norm_eps (epsilon for the normalization.)
resblock_updown (if True use residual blocks for up/downsampling.)
num_head_channels (number of channels in each attention head.)
with_conditioning (if True add spatial transformers to perform conditioning.)
transformer_num_layers (number of layers of Transformer blocks to use.)
cross_attention_dim (number of context dimensions to use.)
num_class_embeds (if specified (as an int), then this model will be class-conditional with num_class_embeds classes.)
upcast_attention (if True, upcast attention operations to full precision.)
use_flash_attention (if True, use flash attention for a memory efficient attention mechanism.)
- __init__(spatial_dims: int, in_channels: int, out_channels: int, internal_size: Sequence[int], num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), num_channels: Sequence[int] = (32, 64, 64, 64), attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-06, resblock_updown: bool = False, num_head_channels: int | Sequence[int] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, interpolation_mode: str = 'trilinear', use_fp16_VINN: bool = False, is_vinn: bool = False) None[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, scale_factors: float, timesteps: Tensor, context: Tensor | None = None, class_labels: Tensor | None = None, down_block_additional_residuals: tuple[Tensor] | None = None, mid_block_additional_residual: Tensor | None = None) Tensor[source]
- Parameters:
x (torch.Tensor) – Input tensor (N, C, SpatialDims).
timesteps (torch.Tensor) – Timestep tensor (N,).
context (torch.Tensor, optional) – Context tensor (N, 1, ContextDim).
class_labels (torch.Tensor, optional) – Class labels tensor (N, ).
down_block_additional_residuals (tuple[torch.Tensor], optional) – Additional residual tensors for down blocks (N, C, FeatureMapsDims).
mid_block_additional_residual (torch.Tensor, optional) – Additional residual tensor for mid block (N, C, FeatureMapsDims).
U-Net architecture for diffusion-based inpainting.
Key Classes:
DiffusionUNet: Main U-Net model
TimeEmbedding: Time embedding for diffusion steps
ResidualBlock: Residual block with time conditioning
interpolation_layer
- class neurolit.networks.interpolation_layer.Zoom2d(interpolation_mode: str = 'nearest', crop_position: str = 'top_left')[source]
Bases: _ZoomNd
Perform a crop and interpolation on a four-dimensional tensor, respecting batch and channel dimensions.
- _N
Number of dimensions (Here 2)
- _crop_position
Position to crop
- __init__(interpolation_mode: str = 'nearest', crop_position: str = 'top_left')[source]
Construct Zoom2d object.
- Parameters:
interpolation_mode (str) – interpolation mode as in torch.nn.functional.interpolate (default: ‘nearest’)
crop_position (str) – crop position to use from ‘top_left’, ‘bottom_left’, ‘top_right’, ‘bottom_right’, ‘center’ (default: ‘top_left’)
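The crop_position options above can be illustrated with a small, hypothetical helper that maps each position name to start offsets for a 2-D crop. This is a sketch for intuition only, not the library's internal code:

```python
def crop_offsets(in_shape, out_shape, crop_position="top_left"):
    """Compute (row, col) start offsets for cropping a 2-D shape.

    Hypothetical helper illustrating how Zoom2d's crop_position
    options could map to start indices; not the actual implementation.
    """
    dh = in_shape[0] - out_shape[0]  # rows to remove in total
    dw = in_shape[1] - out_shape[1]  # cols to remove in total
    if crop_position == "center":
        return dh // 2, dw // 2
    vertical = {"top": 0, "bottom": dh}
    horizontal = {"left": 0, "right": dw}
    v, h = crop_position.split("_")  # e.g. "top_left" -> ("top", "left")
    return vertical[v], horizontal[h]
```

For example, cropping a 10×10 slice to 6×6 with 'bottom_right' would start the crop at offset (4, 4), while 'center' would start it at (2, 2).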
- class neurolit.networks.interpolation_layer.Zoom3d(interpolation_mode: str = 'nearest', crop_position: str = 'front_top_left')[source]
Bases: _ZoomNd
Perform a crop and interpolation on a five-dimensional tensor, respecting batch and channel dimensions.
- __init__(interpolation_mode: str = 'nearest', crop_position: str = 'front_top_left')[source]
Construct Zoom3d object.
- Parameters:
interpolation_mode (str) – interpolation mode as in torch.nn.functional.interpolate (default: ‘nearest’)
crop_position (str) – crop position to use from ‘front_top_left’, ‘back_top_left’, ‘front_bottom_left’, ‘back_bottom_left’, ‘front_top_right’, ‘back_top_right’, ‘front_bottom_right’, ‘back_bottom_right’, ‘center’ (default: ‘front_top_left’)
- class neurolit.networks.interpolation_layer.OnlyZoom3d(interpolation_mode: str = 'nearest')[source]
Bases: Module
Perform interpolation (without cropping) on a five-dimensional tensor, respecting batch and channel dimensions.
- __init__(interpolation_mode: str = 'nearest')[source]
Construct OnlyZoom3d object.
- Parameters:
interpolation_mode (str) – interpolation mode as in torch.nn.functional.interpolate (default: ‘nearest’)
- forward(input_tensor, out_shape)[source]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
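As a toy illustration of the ‘nearest’ interpolation mode these layers default to, here is a self-contained nearest-neighbor resize over a 2-D nested list. The real layers operate on full 5-D tensors and delegate to torch.nn.functional.interpolate; this sketch only shows the per-slice indexing idea:

```python
def nearest_resize_2d(grid, out_shape):
    """Nearest-neighbor resize of a 2-D nested list to out_shape.

    Toy illustration of 'nearest' interpolation; not the library
    implementation, which uses torch.nn.functional.interpolate.
    """
    in_h, in_w = len(grid), len(grid[0])
    out_h, out_w = out_shape
    # Each output pixel copies the input pixel whose index scales down
    # proportionally (floor division).
    return [
        [grid[i * in_h // out_h][j * in_w // out_w] for j in range(out_w)]
        for i in range(out_h)
    ]
```

Upsampling a 2×2 grid to 4×4 with this scheme simply repeats each input value in a 2×2 block, which is the characteristic blocky look of nearest-neighbor interpolation.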
Custom interpolation layers for spatial processing.
Key Classes:
InterpolationLayer: Bilinear/trilinear interpolation
AdaptivePooling: Adaptive pooling layer
Architecture Details
DiffusionUNet Architecture
The DiffusionUNet is a modified U-Net with:
Encoder: Downsampling path with residual blocks
Bottleneck: Deepest layer with attention
Decoder: Upsampling path with skip connections
Time Conditioning: Timestep embeddings at each level
Input (1, 256, 256)
↓
[Encoder Block 1] → [Skip Connection] ──┐
↓ ↓
[Encoder Block 2] → [Skip Connection] ──┼─┐
↓ ↓ ↓
[Encoder Block 3] → [Skip Connection] ──┼─┼─┐
↓ ↓ ↓ ↓
[Bottleneck] ↓ ↓ ↓
↓ ↓ ↓ ↓
[Decoder Block 3] ←─────────────────────┘ ↓ ↓
↓ ↓ ↓
[Decoder Block 2] ←───────────────────────┘ ↓
↓ ↓
[Decoder Block 1] ←─────────────────────────┘
↓
Output (1, 256, 256)
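The Time Conditioning step above injects a timestep embedding at each level of the U-Net. A minimal sketch of the standard sinusoidal timestep embedding (as popularized by DDPM and Transformer positional encodings) is shown below; the exact scaling and base frequency used by the model may differ:

```python
import math

def timestep_embedding(t, dim):
    """Sinusoidal embedding of a scalar timestep into a dim-sized vector.

    Sketch of the standard scheme: half the dimensions are sines, half
    are cosines, over geometrically spaced frequencies. The model's
    actual embedding may use different constants.
    """
    half = dim // 2
    # Frequencies decay geometrically from 1 down to ~1/10000.
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]
```

The embedding vector is typically passed through a small MLP and added to the feature maps inside each residual block, which is how the network learns noise-level-dependent behavior.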
Examples
Using DiffusionUNet
from neurolit.networks.DiffusionUnet import DiffusionUNet
import torch
# Create model
model = DiffusionUNet(
    in_channels=1,
    out_channels=1,
    base_channels=64,
    num_res_blocks=2,
    attention_levels=[2, 3]
)
# Move to GPU
model = model.cuda()
model.eval()
# Forward pass
x = torch.randn(1, 1, 256, 256).cuda()
t = torch.tensor([100]).cuda() # Timestep
with torch.no_grad():
    output = model(x, t)
Loading Pre-trained Weights
from neurolit.networks.DiffusionUnet import DiffusionUNet
import torch
# Create model
model = DiffusionUNet()
# Load checkpoint
checkpoint = torch.load('model_axial.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
Custom Model Configuration
from neurolit.networks.DiffusionUnet import DiffusionUNet
# Create custom model
model = DiffusionUNet(
    in_channels=1,
    out_channels=1,
    base_channels=128,           # More channels
    channel_multipliers=[1, 2, 4, 8],
    num_res_blocks=3,            # More residual blocks
    attention_levels=[2, 3, 4],  # More attention
    dropout=0.1
)
Training Example
from neurolit.networks.DiffusionUnet import DiffusionUNet
import torch
import torch.nn as nn
import torch.optim as optim
# Setup
model = DiffusionUNet().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()
# Training loop
model.train()
for batch in dataloader:
    x = batch['image'].cuda()
    t = batch['timestep'].cuda()
    noise = batch['noise'].cuda()

    # Forward pass
    pred_noise = model(x, t)
    loss = criterion(pred_noise, noise)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Model Summary
from neurolit.networks.DiffusionUnet import DiffusionUNet
model = DiffusionUNet()
# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params:,}")
# Count trainable parameters
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_trainable:,}")