Utils Module

Overview

The utils module provides utility functions for various tasks including model downloading, plotting, and general helpers.

Submodules

download_checkpoints

neurolit.utils.download_checkpoints.download_checkpoint(checkpoint_name: str, checkpoint_path: str | Path, urls: list[str], verbose: bool = False, show_progress: bool = True) -> None

Download a checkpoint file, showing a progress bar unless show_progress is False.

Raises a RequestException if the remote file cannot be fetched or the download is incomplete. Propagates OSError if the checkpoint cannot be written to disk.

Parameters:
  • checkpoint_name (str) – Name of the checkpoint.

  • checkpoint_path (Path, str) – Path of the file in which the checkpoint will be saved.

  • urls (list[str]) – List of URLs of checkpoint hosting sites.

  • verbose (bool) – Whether to print verbose output.

  • show_progress (bool) – Whether to show download progress bar.
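The download behaviour described above (stream the file to disk, fail if the transfer is incomplete) can be illustrated with the standard library. The following is a minimal sketch, not neurolit's implementation: it uses urllib.request in place of the HTTP client that raises RequestException, signals a short read with a hypothetical IncompleteDownloadError, and omits the progress bar. Writing to a temporary .part file and renaming it only after the size check passes is an assumed design choice; it keeps an interrupted download from being mistaken for a valid checkpoint on the next run.

```python
from __future__ import annotations

import os
import urllib.request
from pathlib import Path


class IncompleteDownloadError(Exception):
    """Hypothetical error: fewer bytes arrived than the server announced."""


def stream_download(url: str, dest: str | Path, chunk_size: int = 1 << 20) -> Path:
    """Stream `url` into `dest` in chunks, publishing the file only when complete."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_name(dest.name + ".part")  # partial data lives here until verified
    try:
        with urllib.request.urlopen(url) as resp, open(tmp, "wb") as fh:
            # Content-Length can be missing (e.g. chunked responses); then skip the check
            expected = resp.headers.get("Content-Length")
            written = 0
            while chunk := resp.read(chunk_size):
                fh.write(chunk)
                written += len(chunk)
        if expected is not None and written != int(expected):
            raise IncompleteDownloadError(f"got {written} of {expected} bytes from {url}")
    except BaseException:
        tmp.unlink(missing_ok=True)  # never leave a half-written checkpoint behind
        raise
    os.replace(tmp, dest)  # atomic rename when tmp and dest share a filesystem
    return dest
```

A progress bar would be updated inside the read loop, using written and expected as the current and total byte counts. Errors raised while writing to disk are OSError subclasses and propagate unchanged.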

neurolit.utils.download_checkpoints.check_and_download_ckpts(checkpoint_path: Path | str, urls: list[str], verbose: bool = False, show_progress: bool = True) -> None

Download a checkpoint file from the given URLs if it does not already exist at checkpoint_path.

Parameters:
  • checkpoint_path (Path, str) – Path of the file in which the checkpoint will be saved.

  • urls (list[str]) – List of URLs of checkpoint hosting sites.

  • verbose (bool) – Whether to print verbose output.

  • show_progress (bool) – Whether to show download progress bar.
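check_and_download_ckpts skips the download when the file is already present, and its urls parameter, together with the name of fallback_multiple_urls, suggests that several hosting sites are tried in turn. The sketch below shows that control flow under those assumptions; ensure_checkpoint and the injected download callable are hypothetical names, and which errors actually trigger a fallback in neurolit is not documented here.

```python
from __future__ import annotations

from pathlib import Path
from typing import Callable, Sequence


def ensure_checkpoint(
    checkpoint_path: str | Path,
    urls: Sequence[str],
    download: Callable[[str, Path], None],
    verbose: bool = False,
) -> Path:
    """Hypothetical sketch: reuse an existing file, else try each URL in order."""
    path = Path(checkpoint_path)
    if path.exists():
        if verbose:
            print(f"Found existing checkpoint at {path}; skipping download.")
        return path

    errors: list[str] = []
    for url in urls:
        try:
            download(url, path)
            return path  # first mirror that succeeds wins
        except Exception as exc:  # assumption: any failure moves on to the next URL
            errors.append(f"{url}: {exc}")
            if verbose:
                print(f"Download from {url} failed ({exc}); trying next URL.")
    raise RuntimeError("all checkpoint URLs failed:\n" + "\n".join(errors))
```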

neurolit.utils.download_checkpoints.fallback_multiple_urls(checkpoint_path: Path | str, urls: list[str], verbose: bool = False, show_progress: bool = True) -> None

neurolit.utils.download_checkpoints.main(argv=None)

Download and manage model checkpoints.

Key Functions:

  • download_models(): Download all model checkpoints

  • get_model_path(): Get path to specific model

  • verify_checksum(): Verify model integrity
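verify_checksum() is listed above without a signature. As an illustration of what an integrity check typically involves, the following sketch hashes a file in chunks and compares the result with an expected digest. The choice of SHA-256, the helper sha256_of, and the verify_checksum(path, expected_sha256) signature are all assumptions.

```python
from __future__ import annotations

import hashlib
from pathlib import Path


def sha256_of(path: str | Path, chunk_size: int = 1 << 20) -> str:
    """Hex SHA-256 digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_checksum(path: str | Path, expected_sha256: str) -> bool:
    """Hypothetical signature: True if the file matches the expected hex digest."""
    return sha256_of(path) == expected_sha256.strip().lower()
```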

plotting

neurolit.utils.plotting.plot_batch(image_batch: Tensor, image_path: str, slice_cut: None | list[int] | Tensor = None) -> None

Plot a batch of 2D or 3D images with different slice views.

Parameters:
  • image_batch (torch.Tensor) – Input batch of images to plot, with shape (batch_size, channels, H, W, D) for 3D or (batch_size, channels, H, W) for 2D.

  • image_path (str) – Path at which to save the output plot.

  • slice_cut (list of int or torch.Tensor, optional) – Indices at which to cut the slices: [x, y, z] for 3D, [x, y] for 2D. If None, the middle slice is used.
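With slice_cut=None, plot_batch uses the middle slice. The sketch below makes that default concrete for a 3D batch: it resolves one index per spatial axis and extracts the three orthogonal 2D views. It uses NumPy instead of torch so it runs standalone; resolve_slice_cut and orthogonal_views are hypothetical helpers, and showing only the first sample and channel is an assumption about how plot_batch chooses what to draw.

```python
from __future__ import annotations

import numpy as np


def resolve_slice_cut(spatial_shape: tuple[int, ...], slice_cut=None) -> list[int]:
    """Return one cut index per spatial axis, defaulting to the middle slice."""
    if slice_cut is None:
        return [size // 2 for size in spatial_shape]
    cut = [int(i) for i in slice_cut]  # works for lists, NumPy arrays and 1-D tensors
    if len(cut) != len(spatial_shape):
        raise ValueError(f"expected {len(spatial_shape)} indices, got {len(cut)}")
    return cut


def orthogonal_views(batch: np.ndarray, slice_cut=None, sample: int = 0, channel: int = 0):
    """Three 2-D views of one volume from a (batch_size, channels, H, W, D) array."""
    volume = batch[sample, channel]
    if volume.ndim != 3:
        raise ValueError("this sketch only handles 3-D volumes")
    x, y, z = resolve_slice_cut(volume.shape, slice_cut)
    return [volume[x, :, :], volume[:, y, :], volume[:, :, z]]
```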

neurolit.utils.plotting.plot_learning_curve(n_epochs: int, epoch_loss_list: list[float], val_epoch_loss_list: list[float], args: Any, VAL_INTERVAL: int) -> None

Plot training and validation learning curves.

Parameters:
  • n_epochs (int) – Total number of epochs

  • epoch_loss_list (list of float) – List of training losses for each epoch

  • val_epoch_loss_list (list of float) – List of validation losses (recorded every VAL_INTERVAL)

  • args (Any) – Arguments object containing output directory information

  • VAL_INTERVAL (int) – Interval at which validation was performed
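Because val_epoch_loss_list holds one value per VAL_INTERVAL epochs, it is shorter than epoch_loss_list and needs its own x-axis. A minimal sketch of that alignment follows, assuming validation ran at epochs VAL_INTERVAL, 2*VAL_INTERVAL, and so on (1-based); validation_epochs is a hypothetical helper.

```python
from __future__ import annotations


def validation_epochs(n_epochs: int, val_interval: int, n_val: int | None = None) -> list[int]:
    """1-based epochs at which validation ran, assuming it ran whenever
    epoch % val_interval == 0 (epochs val_interval, 2*val_interval, ...)."""
    if val_interval <= 0:
        raise ValueError("val_interval must be positive")
    epochs = list(range(val_interval, n_epochs + 1, val_interval))
    # If fewer losses were recorded (e.g. training stopped early), trim to match
    return epochs if n_val is None else epochs[:n_val]
```

The returned epochs can serve as the x values for val_epoch_loss_list in matplotlib, for example plt.plot(validation_epochs(n_epochs, VAL_INTERVAL, len(val_epoch_loss_list)), val_epoch_loss_list), while the training curve uses range(1, n_epochs + 1).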

neurolit.utils.plotting.plot_scheduler(scheduler: Any, output_file: str) -> None

Plot the scheduler’s cumulative alpha values (alphas_cumprod) against the diffusion timestep.

Parameters:
  • scheduler (Any) – Diffusion scheduler object containing alphas_cumprod

  • output_file (str) – Path at which to save the output plot.
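alphas_cumprod is the running product alpha_bar_t = (1 - beta_1)(1 - beta_2)...(1 - beta_t) of the per-step noise schedule, so the plotted curve starts near 1 and decays towards 0 as t grows. The sketch below computes it for a linear beta schedule from 1e-4 to 0.02 over 1000 steps, the values used in the original DDPM paper; the schedule of any particular neurolit scheduler is an assumption here, and linear_alphas_cumprod is a hypothetical helper.

```python
import numpy as np


def linear_alphas_cumprod(num_steps: int = 1000, beta_start: float = 1e-4,
                          beta_end: float = 0.02) -> np.ndarray:
    """alpha_bar_t = prod_{s <= t} (1 - beta_s) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    alphas = 1.0 - betas
    return np.cumprod(alphas)
```

Passing the result to plt.plot should give a curve comparable to the one plot_scheduler draws from scheduler.alphas_cumprod.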

neurolit.utils.plotting.plot_inpainting(val_image: Tensor, val_image_masked: Tensor, val_image_inpainted: Tensor, out_file: str, SLICE_CUT: Tensor | Sequence[int], cut_dim: int = 0) -> None

Plot original, masked, and inpainted images side by side.

Parameters:
  • val_image (torch.Tensor) – Original image tensor

  • val_image_masked (torch.Tensor) – Masked version of the image

  • val_image_inpainted (torch.Tensor) – Inpainted result

  • out_file (str) – Path at which to save the output plot.

  • SLICE_CUT (torch.Tensor or Sequence[int]) – Index at which to cut the slice for visualization.

  • cut_dim (int, optional) – Dimension along which to take the slice (0=sagittal, 1=coronal, 2=axial). Defaults to 0.
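plot_inpainting shows the same slice of three volumes side by side, chosen by SLICE_CUT and cut_dim. The sketch below extracts those panels with numpy.take. How SLICE_CUT is combined with cut_dim is not documented above, so the sketch assumes that a single value is the slice index and that a per-axis sequence is indexed with cut_dim; inpainting_panels is a hypothetical helper.

```python
import numpy as np

# Plane names for cut_dim, as documented for plot_inpainting
PLANE_NAMES = {0: "sagittal", 1: "coronal", 2: "axial"}


def inpainting_panels(original, masked, inpainted, slice_cut, cut_dim: int = 0):
    """Take the same 2-D slice along cut_dim from each of the three volumes."""
    if cut_dim not in PLANE_NAMES:
        raise ValueError("cut_dim must be 0 (sagittal), 1 (coronal) or 2 (axial)")
    cut = np.asarray(slice_cut).reshape(-1)
    # Assumption: a single value is the index itself; a per-axis sequence is indexed by cut_dim
    index = int(cut[0] if cut.size == 1 else cut[cut_dim])
    return [np.take(np.asarray(vol), index, axis=cut_dim)
            for vol in (original, masked, inpainted)]
```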

Visualization utilities for brain MRI volumes and inpainting results.

Key Functions:

  • plot_slices(): Plot 2D slices

  • plot_comparison(): Compare before/after inpainting

  • plot_mask_overlay(): Overlay mask on image

Examples

Downloading Models

from neurolit.utils.download_checkpoints import download_models

# Download all models
download_models(force=False)  # Skip if already downloaded

Command-Line Download

# Download models
lit-download-models

# Force re-download
lit-download-models --force
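main(argv=None) follows a common pattern for command-line entry points: passing argv explicitly makes the parser testable, while None falls back to sys.argv. The sketch below shows a --force flag parsed that way with argparse. It is hypothetical: that lit-download-models is wired to main, and the parser contents, are assumptions based only on the commands shown above.

```python
from __future__ import annotations

import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser for a lit-download-models style command."""
    parser = argparse.ArgumentParser(prog="lit-download-models",
                                     description="Download model checkpoints.")
    parser.add_argument("--force", action="store_true",
                        help="re-download checkpoints even if they already exist")
    return parser


def main(argv: list[str] | None = None) -> int:
    args = build_parser().parse_args(argv)  # argv=None -> sys.argv[1:]
    # Placeholder: a real entry point would call the download helpers here
    print("Re-downloading all checkpoints" if args.force else "Downloading missing checkpoints")
    return 0
```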

Getting Model Paths

from neurolit.utils.download_checkpoints import get_model_path

# Get path to specific model
axial_path = get_model_path('axial')
coronal_path = get_model_path('coronal')
sagittal_path = get_model_path('sagittal')

print(f"Axial model: {axial_path}")

Plotting Results

from neurolit.utils.plotting import plot_comparison
import nibabel as nib

# Load images
original = nib.load('T1w.nii.gz').get_fdata()
inpainted = nib.load('T1w_inpainted.nii.gz').get_fdata()
mask = nib.load('lesion_mask.nii.gz').get_fdata()

# Plot comparison
plot_comparison(
    original=original,
    inpainted=inpainted,
    mask=mask,
    slice_idx=128,
    save_path='comparison.png'
)
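Alongside the visual comparison, a quick numerical check can confirm that inpainting changed voxels only inside the lesion mask. change_outside_mask is a hypothetical helper, not part of neurolit; it assumes original, inpainted, and mask share the same shape and voxel grid, and a small non-zero result may be legitimate if the pipeline rescales or resamples intensities.

```python
import numpy as np


def change_outside_mask(original: np.ndarray, inpainted: np.ndarray, mask: np.ndarray) -> float:
    """Largest absolute intensity change at voxels where mask == 0."""
    outside = np.asarray(mask) == 0
    diff = np.abs(np.asarray(inpainted, dtype=np.float64)
                  - np.asarray(original, dtype=np.float64))
    return float(diff[outside].max()) if outside.any() else 0.0
```

For example, change_outside_mask(original, inpainted, mask) on the arrays loaded above.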

Plotting Slices

from neurolit.utils.plotting import plot_slices
import nibabel as nib

# Load image
image = nib.load('T1w.nii.gz').get_fdata()

# Plot axial, coronal, and sagittal slices
plot_slices(
    image,
    title='T1-weighted MRI',
    save_path='slices.png'
)

Mask Overlay

from neurolit.utils.plotting import plot_mask_overlay
import nibabel as nib

# Load data
image = nib.load('T1w.nii.gz').get_fdata()
mask = nib.load('lesion_mask.nii.gz').get_fdata()

# Plot with overlay
plot_mask_overlay(
    image=image,
    mask=mask,
    slice_idx=128,
    alpha=0.5,  # Overlay transparency
    save_path='overlay.png'
)
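The alpha argument controls how strongly the mask tints the image. The blending itself can be sketched in NumPy: rescale a 2D slice to [0, 1], convert it to RGB, and mix masked pixels with a highlight color. blend_mask, its red default color, and the min-max rescaling are assumptions, not documented behaviour of plot_mask_overlay.

```python
import numpy as np


def blend_mask(slice_2d: np.ndarray, mask_2d: np.ndarray, alpha: float = 0.5,
               color=(1.0, 0.0, 0.0)) -> np.ndarray:
    """Return an (H, W, 3) RGB image with mask pixels tinted by `color`."""
    img = np.asarray(slice_2d, dtype=np.float64)
    lo, hi = img.min(), img.max()
    gray = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)  # rescale to [0, 1]
    rgb = np.repeat(gray[..., None], 3, axis=-1)
    inside = np.asarray(mask_2d) > 0
    # Alpha blending: (1 - alpha) * image + alpha * color, only where the mask is set
    rgb[inside] = (1.0 - alpha) * rgb[inside] + alpha * np.asarray(color)
    return rgb
```

The result can be displayed with plt.imshow(blend_mask(image[:, :, 128], mask[:, :, 128], alpha=0.5)).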

Custom Visualization

import matplotlib.pyplot as plt
import nibabel as nib

# Load multiple images
original = nib.load('T1w.nii.gz').get_fdata()
inpainted = nib.load('T1w_inpainted.nii.gz').get_fdata()

# Create custom figure
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

slice_idx = 128
axes[0].imshow(original[:, :, slice_idx], cmap='gray')
axes[0].set_title('Original')
axes[0].axis('off')

axes[1].imshow(inpainted[:, :, slice_idx], cmap='gray')
axes[1].set_title('Inpainted')
axes[1].axis('off')

plt.tight_layout()
plt.savefig('custom_comparison.png', dpi=300)
plt.close()

Batch Visualization

from neurolit.utils.plotting import plot_comparison
from pathlib import Path
import nibabel as nib

subjects = ['sub-01', 'sub-02', 'sub-03']

# Make sure the output directory exists before saving figures
out_dir = Path('visualizations')
out_dir.mkdir(parents=True, exist_ok=True)

for subject in subjects:
    original = nib.load(f'data/{subject}/T1w.nii.gz').get_fdata()
    inpainted = nib.load(f'output/{subject}/T1w_inpainted.nii.gz').get_fdata()
    mask = nib.load(f'data/{subject}/lesion_mask.nii.gz').get_fdata()

    plot_comparison(
        original=original,
        inpainted=inpainted,
        mask=mask,
        slice_idx=128,
        save_path=str(out_dir / f'{subject}_comparison.png')
    )