# Source code for neurolit.utils.plotting

# Copyright 2026 DeepMI Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections.abc import Sequence
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch


[docs] def plot_batch(image_batch: torch.Tensor, image_path: str, slice_cut: None | list[int] | torch.Tensor = None) -> None: """Plot a batch of 2D or 3D images with different slice views. Parameters ---------- image_batch : torch.Tensor Input batch of images to plot. Shape should be: - For 3D: (batch_size, channels, H, W, D) - For 2D: (batch_size, channels, H, W) image_path : str Path where to save the output plot slice_cut : list of int, optional List of indices where to cut the slices. If None, uses middle slice. For 3D: [x, y, z] indices For 2D: [x, y] indices """ dim = len(image_batch.shape) - 2 BATCH_SIZE = image_batch.shape[0] if slice_cut is None: if dim == 3: slice_cut = [ image_batch.shape[2] // 2, image_batch.shape[3] // 2, image_batch.shape[4] // 2, ] else: slice_cut = [image_batch.shape[2] // 2, image_batch.shape[3] // 2] else: if torch.is_tensor(slice_cut): slice_cut = slice_cut.detach().cpu().tolist() slice_cut = [int(v) for v in slice_cut] if dim == 3: slice_cut[0] = int(np.clip(slice_cut[0], 0, image_batch.shape[2] - 1)) slice_cut[1] = int(np.clip(slice_cut[1], 0, image_batch.shape[3] - 1)) slice_cut[2] = int(np.clip(slice_cut[2], 0, image_batch.shape[4] - 1)) elif dim == 2: slice_cut[0] = int(np.clip(slice_cut[0], 0, image_batch.shape[2] - 1)) slice_cut[1] = int(np.clip(slice_cut[1], 0, image_batch.shape[3] - 1)) if dim == 3: slices = [ [(i, 0, slice(None), slice(None), slice_cut[2]) for i in range(BATCH_SIZE)], [(i, 0, slice(None), slice_cut[1], slice(None)) for i in range(BATCH_SIZE)], [(i, 0, slice_cut[0], slice(None), slice(None)) for i in range(BATCH_SIZE)], ] elif dim == 2: middle_slice = image_batch.shape[1] // 2 slices = [ [(i, middle_slice, slice(None), slice(None)) for i in range(BATCH_SIZE)], [(i, slice(None), slice_cut[0], slice(None)) for i in range(BATCH_SIZE)], [(i, slice(None), slice(None), slice_cut[1]) for i in range(BATCH_SIZE)], ] else: raise ValueError("Invalid dimension") fig, axs = plt.subplots(3, BATCH_SIZE, figsize=(1.5 * 
BATCH_SIZE, 3)) if BATCH_SIZE == 1: axs = np.expand_dims(axs, axis=1) for s, ax in zip(slices, axs, strict=True): for i in range(BATCH_SIZE): ax[i].imshow(image_batch[s[i]].detach().cpu(), vmin=0, vmax=1, cmap="gray") # plt.axis("off") ax[i].set_xticks([]) ax[i].set_yticks([]) plt.tight_layout() plt.savefig(image_path, dpi=300, bbox_inches="tight") plt.close()
def plot_learning_curve(
    n_epochs: int,
    epoch_loss_list: list[float],
    val_epoch_loss_list: list[float],
    args: Any,
    VAL_INTERVAL: int,
) -> None:
    """Plot training and validation learning curves on a log-scaled loss axis.

    The ``"seaborn-v0_8"`` style is applied only while this figure is built
    and saved, so global matplotlib settings are left unchanged. The plot is
    written to ``<args.out_dir>/images/learning_curves.png``.

    Parameters
    ----------
    n_epochs : int
        Total number of epochs; training losses are plotted at epochs
        ``1 .. n_epochs``.
    epoch_loss_list : list of float
        Training loss for each epoch (one value per epoch).
    val_epoch_loss_list : list of float
        Validation losses recorded every ``VAL_INTERVAL`` epochs, plotted at
        epochs ``VAL_INTERVAL, 2 * VAL_INTERVAL, ...`` up to ``n_epochs``.
    args : Any
        Object with an ``out_dir`` attribute naming the output directory. The
        ``images`` sub-directory must already exist (it is not created here).
    VAL_INTERVAL : int
        Interval (in epochs) at which validation was performed; must be
        positive.

    Raises
    ------
    ValueError
        If ``VAL_INTERVAL`` is not positive.
    """
    if VAL_INTERVAL <= 0:
        raise ValueError(f"VAL_INTERVAL must be a positive integer, got {VAL_INTERVAL}")

    train_epochs = np.arange(1, n_epochs + 1)
    # Validation points sit at exact multiples of VAL_INTERVAL that do not
    # exceed n_epochs. This stays correct when n_epochs is not a multiple of
    # VAL_INTERVAL (e.g. n_epochs=10, VAL_INTERVAL=3 -> 3, 6, 9).
    val_epochs = np.arange(VAL_INTERVAL, n_epochs + 1, VAL_INTERVAL)

    # Scope the style to this figure instead of changing it globally.
    with plt.style.context("seaborn-v0_8"):
        # Use a dedicated figure rather than whatever pyplot figure is current.
        fig, ax = plt.subplots()
        try:
            ax.set_title("Learning Curves", fontsize=20)
            ax.plot(train_epochs, epoch_loss_list, color="C0", linewidth=2.0, label="Train")
            ax.plot(
                val_epochs,
                val_epoch_loss_list,
                color="C1",
                linewidth=2.0,
                label="Validation",
            )
            ax.tick_params(axis="both", labelsize=12)
            ax.set_xlabel("Epochs", fontsize=16)
            ax.set_ylabel("Loss", fontsize=16)
            ax.legend(prop={"size": 14})
            # Log scale makes late-training loss changes visible.
            ax.set_yscale("log")
            fig.savefig(
                os.path.join(args.out_dir, "images/learning_curves.png"),
                dpi=100,
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)


def plot_scheduler(scheduler: Any, output_file: str) -> None:
    """Plot the scheduler's cumulative product of alphas over the timesteps.

    Parameters
    ----------
    scheduler : Any
        Diffusion scheduler object exposing ``alphas_cumprod`` (tensor-like,
        with a ``.cpu()`` method). Its values are plotted against their index,
        which is labelled as the timestep.
    output_file : str
        Path where to save the output plot.
    """
    # Draw on a dedicated figure so that an already-open pyplot figure is
    # neither drawn over nor closed by this function.
    fig, ax = plt.subplots()
    try:
        ax.plot(
            scheduler.alphas_cumprod.cpu(),
            color=(2 / 255, 163 / 255, 163 / 255),
            linewidth=2,
        )
        ax.set_xlabel("Timestep [t]")
        ax.set_ylabel("alpha cumprod")
        fig.savefig(output_file, dpi=100, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


[docs] def plot_inpainting( val_image: torch.Tensor, val_image_masked: torch.Tensor, val_image_inpainted: torch.Tensor, out_file: str, SLICE_CUT: torch.Tensor | Sequence[int], cut_dim: int = 0, ) -> None: """Plot original, masked, and inpainted images side by side. Parameters ---------- val_image : torch.Tensor Original image tensor val_image_masked : torch.Tensor Masked version of the image val_image_inpainted : torch.Tensor Inpainted result out_file : str Path where to save the output plot SLICE_CUT : torch.Tensor Index where to cut the slice for visualization cut_dim : int, optional Dimension along which to take the slice (0=sagittal, 1=coronal, 2=axial). Defaults to 0. """ if len(val_image.shape) == 5: val_image = val_image[:, 0, ...] val_image_masked = val_image_masked[:, 0, ...] val_image_inpainted = val_image_inpainted[:, 0, ...] if torch.is_tensor(SLICE_CUT): slice_cut = SLICE_CUT.detach().cpu().tolist() else: slice_cut = [int(v) for v in SLICE_CUT] spatial_shape = val_image.shape[1:4] for axis in range(3): slice_cut[axis] = int(np.clip(slice_cut[axis], 0, spatial_shape[axis] - 1)) slicing = [0, slice(None), slice(None), slice(None)] slicing[cut_dim + 1] = slice_cut[cut_dim] slicing = tuple(slicing) # Convert to tuple for PyTorch 2.9+ compatibility fig, axs = plt.subplots(1, 3, figsize=(12, 4)) axs[0].imshow(val_image[slicing].cpu(), cmap="gray") axs[0].set_title("Original image") axs[0].axis("off") axs[1].imshow(val_image_masked[slicing].cpu(), cmap="gray") axs[1].axis("off") axs[1].set_title("Masked image") axs[2].imshow(val_image_inpainted[slicing].cpu(), cmap="gray") axs[2].axis("off") axs[2].set_title("Inpainted image") plt.savefig(out_file, dpi=300, bbox_inches="tight") plt.close()