Utils Module
Overview
The utils module provides utility functions for various tasks including model downloading, plotting, and general helpers.
Submodules
download_checkpoints
- neurolit.utils.download_checkpoints.download_checkpoint(checkpoint_name: str, checkpoint_path: str | Path, urls: list[str], verbose: bool = False, show_progress: bool = True) -> None
Download a checkpoint file, optionally showing a progress bar.
Raises a RequestException if the remote file cannot be fetched or the download is incomplete. Propagates OSError if the checkpoint cannot be written to disk.
- Parameters:
checkpoint_name (str) – Name of the checkpoint.
checkpoint_path (str or Path) – Path to the file in which the checkpoint will be saved.
urls (list[str]) – URLs of the sites hosting the checkpoint.
verbose (bool) – Whether to print verbose output.
show_progress (bool) – Whether to show download progress bar.
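The description above ties a RequestException to an incomplete download. A common way to detect that case is to count the bytes actually received against the expected size, and to write into a temporary file that is renamed into place only on success. The sketch below is not neurolit's implementation: it reads from any binary file-like object so it runs without network access, and `save_stream`, `IncompleteDownloadError` and `on_progress` are hypothetical names (the documented function raises requests' RequestException instead).

```python
from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import BinaryIO, Callable, Optional


class IncompleteDownloadError(Exception):
    """Hypothetical error type; download_checkpoint itself raises RequestException."""


def save_stream(
    stream: BinaryIO,
    dest: Path | str,
    expected_size: Optional[int],
    chunk_size: int = 8192,
    on_progress: Optional[Callable[[int], None]] = None,
) -> None:
    """Write `stream` to `dest` atomically, failing if fewer bytes than expected arrive."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temporary file in the same directory so the final rename is atomic.
    fd, tmp_name = tempfile.mkstemp(dir=dest.parent, suffix=".part")
    received = 0
    try:
        with os.fdopen(fd, "wb") as tmp:
            while chunk := stream.read(chunk_size):
                tmp.write(chunk)
                received += len(chunk)
                if on_progress is not None:
                    on_progress(received)  # e.g. advance a progress bar
        if expected_size is not None and received != expected_size:
            raise IncompleteDownloadError(f"received {received} of {expected_size} bytes")
        os.replace(tmp_name, dest)  # only a complete file ever appears at dest
    except BaseException:
        # Remove the partial file on any failure (including an OSError while writing).
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

With a real HTTP download, `stream` would be the response body and `expected_size` the value of its Content-Length header, when the server sends one.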
- neurolit.utils.download_checkpoints.check_and_download_ckpts(checkpoint_path: Path | str, urls: list[str], verbose: bool = False, show_progress: bool = True) -> None
Check whether the checkpoint file exists and download it if it does not.
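A minimal sketch of the check-then-download pattern, assuming "exists" means a regular file is already present at checkpoint_path. `ensure_checkpoint` and the injected `download` callable are hypothetical stand-ins, not part of neurolit; passing the downloader in keeps the sketch testable offline.

```python
from __future__ import annotations

from pathlib import Path
from typing import Callable, Sequence


def ensure_checkpoint(
    checkpoint_path: Path | str,
    urls: Sequence[str],
    download: Callable[[Path, Sequence[str]], None],
) -> bool:
    """Call `download` only when no file exists at `checkpoint_path`.

    Returns True if a download was triggered, False if the file was already present.
    """
    path = Path(checkpoint_path)
    if path.is_file():
        return False  # already cached; nothing to do
    path.parent.mkdir(parents=True, exist_ok=True)
    download(path, urls)
    return True
```

An existence check alone cannot tell a complete file from a truncated one left by an interrupted download, which is one reason to write checkpoints atomically (temporary file plus rename).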
- neurolit.utils.download_checkpoints.fallback_multiple_urls(checkpoint_path: Path | str, urls: list[str], verbose: bool = False, show_progress: bool = True) -> None
Download and manage model checkpoints.
Key Functions:
- download_models(): Download all model checkpoints
- get_model_path(): Get the path to a specific model
- verify_checksum(): Verify model integrity
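fallback_multiple_urls has no description in this reference. Judging only from its name and its urls parameter, it presumably tries each hosting URL in turn until one succeeds; the sketch below illustrates that fallback pattern under this assumption. `try_urls_in_order` and the `fetch` callable are hypothetical, and the real function may order, retry, or report failures differently.

```python
from __future__ import annotations

from typing import Callable, Sequence


def try_urls_in_order(urls: Sequence[str], fetch: Callable[[str], bytes]) -> bytes:
    """Return the result of the first URL that `fetch` retrieves without raising."""
    if not urls:
        raise ValueError("no URLs given")
    errors: list[tuple[str, Exception]] = []
    for url in urls:
        try:
            return fetch(url)
        except Exception as exc:  # remember the failure and move on to the next mirror
            errors.append((url, exc))
    # Every mirror failed: raise one error listing all attempts, chained to the last failure.
    summary = "; ".join(f"{u}: {e}" for u, e in errors)
    raise RuntimeError(f"all URLs failed: {summary}") from errors[-1][1]
```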
plotting
- neurolit.utils.plotting.plot_batch(image_batch: Tensor, image_path: str, slice_cut: None | list[int] | Tensor = None) -> None
Plot a batch of 2D or 3D images with different slice views.
- Parameters:
image_batch (torch.Tensor) – Input batch of images to plot, with shape (batch_size, channels, H, W, D) for 3D images or (batch_size, channels, H, W) for 2D images.
image_path (str) – Path where the output plot is saved.
slice_cut (list of int or torch.Tensor, optional) – Indices at which to cut the slices: [x, y, z] for 3D or [x, y] for 2D. If None, the middle slice is used.
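For a 3D batch, plot_batch cuts each volume at the indices in slice_cut, defaulting to the middle slice. The sketch below shows that default with NumPy for a single (H, W, D) volume taken from a (batch_size, channels, H, W, D) array; `orthogonal_slices` is a hypothetical helper, not the plot_batch implementation. A torch.Tensor can be converted first with `tensor.detach().cpu().numpy()`.

```python
from __future__ import annotations

import numpy as np


def orthogonal_slices(volume: np.ndarray, slice_cut: list[int] | None = None) -> list[np.ndarray]:
    """Return one 2D slice per spatial axis of an (H, W, D) volume.

    If `slice_cut` is None, the middle index of each axis is used.
    """
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    if slice_cut is None:
        slice_cut = [size // 2 for size in volume.shape]
    # np.take with axis=k drops axis k, leaving the 2D plane through index slice_cut[k].
    return [np.take(volume, idx, axis=k) for k, idx in enumerate(slice_cut)]
```

For example, `orthogonal_slices(batch[b, c])` returns the three views for sample `b`, channel `c`.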
- neurolit.utils.plotting.plot_learning_curve(n_epochs: int, epoch_loss_list: list[float], val_epoch_loss_list: list[float], args: Any, VAL_INTERVAL: int) -> None
Plot training and validation learning curves.
- Parameters:
n_epochs (int) – Total number of epochs
epoch_loss_list (list of float) – List of training losses for each epoch
val_epoch_loss_list (list of float) – List of validation losses, recorded every VAL_INTERVAL epochs
args (Any) – Arguments object containing output directory information
VAL_INTERVAL (int) – Interval, in epochs, at which validation was performed
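Because val_epoch_loss_list holds only one value per VAL_INTERVAL epochs, the validation curve needs its own x coordinates to line up with the training curve. The sketch assumes validation ran after epochs VAL_INTERVAL, 2 * VAL_INTERVAL, and so on (1-based); if the training loop validates at a different offset, `val_x` must be adjusted. `learning_curve_points` is a hypothetical helper.

```python
from __future__ import annotations


def learning_curve_points(
    epoch_loss_list: list[float],
    val_epoch_loss_list: list[float],
    val_interval: int,
) -> tuple[tuple[list[int], list[float]], tuple[list[int], list[float]]]:
    """Return (x, y) pairs for the training and validation curves on a shared 1-based epoch axis."""
    if val_interval < 1:
        raise ValueError("val_interval must be a positive integer")
    train_x = list(range(1, len(epoch_loss_list) + 1))
    # Assumption: validation ran after epochs val_interval, 2 * val_interval, ...
    val_x = [val_interval * (i + 1) for i in range(len(val_epoch_loss_list))]
    if val_x and val_x[-1] > len(epoch_loss_list):
        raise ValueError("more validation losses than val_interval allows")
    return (train_x, list(epoch_loss_list)), (val_x, list(val_epoch_loss_list))
```

With matplotlib, `plt.plot(*train, label='train')` and `plt.plot(*val, label='validation')` then draw both curves on the same epoch axis.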
- neurolit.utils.plotting.plot_scheduler(scheduler: Any, output_file: str) -> None
Plot the scheduler’s alpha values (alphas_cumprod) across diffusion timesteps.
- Parameters:
scheduler (Any) – Diffusion scheduler object containing alphas_cumprod
output_file (str) – Path where to save the output plot
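In diffusion schedulers, alphas_cumprod holds the cumulative product ᾱ_t = ∏_{s≤t} (1 − β_s), which falls from close to 1 toward 0 as noise accumulates; that decay is what a plot of it shows. This reference does not say which beta schedule neurolit's scheduler uses, so the sketch assumes the linear schedule from DDPM (β rising from 1e-4 to 0.02 over 1000 steps). `linear_alphas_cumprod` is a hypothetical helper.

```python
from __future__ import annotations


def linear_alphas_cumprod(
    num_steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02
) -> list[float]:
    """Cumulative products of (1 - beta_t) for a linear beta schedule."""
    if num_steps < 2:
        raise ValueError("need at least two timesteps")
    step = (beta_end - beta_start) / (num_steps - 1)
    betas = [beta_start + t * step for t in range(num_steps)]
    alphas_cumprod: list[float] = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_t = 1 - beta_t, multiplied into the running product
        alphas_cumprod.append(running)
    return alphas_cumprod
```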
- neurolit.utils.plotting.plot_inpainting(val_image: Tensor, val_image_masked: Tensor, val_image_inpainted: Tensor, out_file: str, SLICE_CUT: Tensor | Sequence[int], cut_dim: int = 0) -> None
Plot original, masked, and inpainted images side by side.
- Parameters:
val_image (torch.Tensor) – Original image tensor
val_image_masked (torch.Tensor) – Masked version of the image
val_image_inpainted (torch.Tensor) – Inpainted result
out_file (str) – Path where to save the output plot
SLICE_CUT (torch.Tensor or Sequence[int]) – Index or indices at which to cut the slice for visualization
cut_dim (int, optional) – Dimension along which to take the slice (0=sagittal, 1=coronal, 2=axial). Defaults to 0.
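plot_inpainting cuts the original, masked and inpainted volumes at the same position so the three panels are directly comparable. The reference does not specify how SLICE_CUT is indexed; the sketch assumes it holds one index per spatial axis and that cut_dim selects which entry to use. The documented plane names (0 = sagittal, 1 = coronal, 2 = axial) only hold when the array axes follow a RAS-like orientation, which is also an assumption here. `slices_for_panel` is hypothetical.

```python
from __future__ import annotations

from typing import Sequence

import numpy as np

# Plane names for cut_dim as documented for plot_inpainting (valid for RAS-like axis order).
PLANE_NAMES = {0: "sagittal", 1: "coronal", 2: "axial"}


def slices_for_panel(
    volumes: Sequence[np.ndarray], slice_cut: Sequence[int], cut_dim: int = 0
) -> list[np.ndarray]:
    """Cut every volume at the same index along `cut_dim` so the panels line up."""
    if cut_dim not in PLANE_NAMES:
        raise ValueError(f"cut_dim must be 0, 1 or 2, got {cut_dim}")
    shapes = {v.shape for v in volumes}
    if len(shapes) != 1:
        raise ValueError(f"volumes must share a shape, got {shapes}")
    idx = int(slice_cut[cut_dim])  # assumption: one index per spatial axis
    return [np.take(v, idx, axis=cut_dim) for v in volumes]
```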
Visualization utilities for brain MRI volumes and inpainting results.
Key Functions:
- plot_slices(): Plot 2D slices
- plot_comparison(): Compare images before and after inpainting
- plot_mask_overlay(): Overlay a mask on an image
Examples
Downloading Models
from neurolit.utils.download_checkpoints import download_models
# Download all models
download_models(force=False) # Skip if already downloaded
Command-Line Download
# Download models
lit-download-models
# Force re-download
lit-download-models --force
Getting Model Paths
from neurolit.utils.download_checkpoints import get_model_path
# Get path to specific model
axial_path = get_model_path('axial')
coronal_path = get_model_path('coronal')
sagittal_path = get_model_path('sagittal')
print(f"Axial model: {axial_path}")
Plotting Results
from neurolit.utils.plotting import plot_comparison
import nibabel as nib
# Load images
original = nib.load('T1w.nii.gz').get_fdata()
inpainted = nib.load('T1w_inpainted.nii.gz').get_fdata()
mask = nib.load('lesion_mask.nii.gz').get_fdata()
# Plot comparison
plot_comparison(
    original=original,
    inpainted=inpainted,
    mask=mask,
    slice_idx=128,
    save_path='comparison.png'
)
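The examples hard-code slice_idx=128, which may miss a small lesion entirely. One option is to pick the slice with the largest lesion cross-section. `slice_with_most_lesion` is a hypothetical helper, not part of neurolit, and which axis plot_comparison's slice_idx refers to is not documented here, so the default axis=2 is an assumption.

```python
import numpy as np


def slice_with_most_lesion(mask: np.ndarray, axis: int = 2) -> int:
    """Index of the slice along `axis` containing the most nonzero mask voxels."""
    other_axes = tuple(a for a in range(mask.ndim) if a != axis)
    counts = np.count_nonzero(mask, axis=other_axes)  # lesion voxels per slice
    if counts.max() == 0:
        return mask.shape[axis] // 2  # empty mask: fall back to the middle slice
    return int(np.argmax(counts))
```

For example, `slice_idx = slice_with_most_lesion(mask)` can be computed before calling plot_comparison.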
Plotting Slices
from neurolit.utils.plotting import plot_slices
import nibabel as nib
# Load image
image = nib.load('T1w.nii.gz').get_fdata()
# Plot axial, coronal, and sagittal slices
plot_slices(
    image,
    title='T1-weighted MRI',
    save_path='slices.png'
)
Mask Overlay
from neurolit.utils.plotting import plot_mask_overlay
import nibabel as nib
# Load data
image = nib.load('T1w.nii.gz').get_fdata()
mask = nib.load('lesion_mask.nii.gz').get_fdata()
# Plot with overlay
plot_mask_overlay(
    image=image,
    mask=mask,
    slice_idx=128,
    alpha=0.5,  # overlay transparency
    save_path='overlay.png'
)
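The alpha argument above is described only as overlay transparency. A standard way to realize such an overlay is linear alpha blending: inside the mask each pixel becomes (1 − alpha) · gray + alpha · color. The sketch assumes that approach; plot_mask_overlay may draw the overlay differently (for example with a colormap), and `blend_mask_overlay` is hypothetical.

```python
import numpy as np


def blend_mask_overlay(image_slice, mask_slice, alpha=0.5, color=(1.0, 0.0, 0.0)):
    """Return an RGB array: grayscale image with `color` blended in where the mask is nonzero."""
    img = np.asarray(image_slice, dtype=float)
    lo, hi = img.min(), img.max()
    # Normalize intensities to [0, 1]; a constant image maps to all zeros.
    gray = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    rgb = np.repeat(gray[..., None], 3, axis=-1)
    m = np.asarray(mask_slice) > 0
    rgb[m] = (1.0 - alpha) * rgb[m] + alpha * np.asarray(color, dtype=float)
    return rgb
```

The returned array can be displayed with `plt.imshow(rgb)`.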
Custom Visualization
import matplotlib.pyplot as plt
import nibabel as nib
# Load multiple images
original = nib.load('T1w.nii.gz').get_fdata()
inpainted = nib.load('T1w_inpainted.nii.gz').get_fdata()
# Create custom figure
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
slice_idx = 128
axes[0].imshow(original[:, :, slice_idx], cmap='gray')
axes[0].set_title('Original')
axes[0].axis('off')
axes[1].imshow(inpainted[:, :, slice_idx], cmap='gray')
axes[1].set_title('Inpainted')
axes[1].axis('off')
plt.tight_layout()
plt.savefig('custom_comparison.png', dpi=300)
plt.close()
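Side-by-side panels make large differences obvious but can hide subtle changes outside the lesion. A hypothetical helper, `change_summary`, quantifies this by averaging |inpainted − original| inside and outside the mask; it is a sketch for checking results, not a neurolit function.

```python
import numpy as np


def change_summary(original, inpainted, mask):
    """Mean absolute intensity change inside and outside a binary mask."""
    original = np.asarray(original, dtype=float)
    inpainted = np.asarray(inpainted, dtype=float)
    inside = np.asarray(mask) > 0
    if not (original.shape == inpainted.shape == inside.shape):
        raise ValueError("original, inpainted and mask must have the same shape")
    diff = np.abs(inpainted - original)
    return {
        "inside": float(diff[inside].mean()) if inside.any() else 0.0,
        "outside": float(diff[~inside].mean()) if (~inside).any() else 0.0,
    }
```

The same difference image, e.g. `np.abs(inpainted - original)[:, :, slice_idx]`, can also be shown as a third panel.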
Batch Visualization
from neurolit.utils.plotting import plot_comparison
from pathlib import Path
import nibabel as nib
subjects = ['sub-01', 'sub-02', 'sub-03']
out_dir = Path('visualizations')
out_dir.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists
for subject in subjects:
    original = nib.load(f'data/{subject}/T1w.nii.gz').get_fdata()
    inpainted = nib.load(f'output/{subject}/T1w_inpainted.nii.gz').get_fdata()
    mask = nib.load(f'data/{subject}/lesion_mask.nii.gz').get_fdata()
    plot_comparison(
        original=original,
        inpainted=inpainted,
        mask=mask,
        slice_idx=128,
        save_path=str(out_dir / f'{subject}_comparison.png')
    )
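Rather than hard-coding the subject list, it can be derived from the data folder. The sketch assumes the data/<subject>/T1w.nii.gz and data/<subject>/lesion_mask.nii.gz layout used above and BIDS-style sub-* folder names; `find_subjects` is a hypothetical helper.

```python
from pathlib import Path


def find_subjects(data_dir, pattern="sub-*"):
    """Sorted names of subject folders that contain both T1w.nii.gz and lesion_mask.nii.gz."""
    root = Path(data_dir)
    return sorted(
        p.name
        for p in root.glob(pattern)
        if p.is_dir()
        and (p / "T1w.nii.gz").is_file()
        and (p / "lesion_mask.nii.gz").is_file()
    )
```

`subjects = find_subjects('data')` could then replace the hard-coded list, skipping folders with missing inputs.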