# Copyright 2026 DeepMI Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections.abc import Sequence
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import torch
[docs]
def plot_batch(image_batch: torch.Tensor, image_path: str, slice_cut: None | list[int] | torch.Tensor = None) -> None:
"""Plot a batch of 2D or 3D images with different slice views.
Parameters
----------
image_batch : torch.Tensor
Input batch of images to plot. Shape should be:
- For 3D: (batch_size, channels, H, W, D)
- For 2D: (batch_size, channels, H, W)
image_path : str
Path where to save the output plot
slice_cut : list of int, optional
List of indices where to cut the slices. If None, uses middle slice.
For 3D: [x, y, z] indices
For 2D: [x, y] indices
"""
dim = len(image_batch.shape) - 2
BATCH_SIZE = image_batch.shape[0]
if slice_cut is None:
if dim == 3:
slice_cut = [
image_batch.shape[2] // 2,
image_batch.shape[3] // 2,
image_batch.shape[4] // 2,
]
else:
slice_cut = [image_batch.shape[2] // 2, image_batch.shape[3] // 2]
else:
if torch.is_tensor(slice_cut):
slice_cut = slice_cut.detach().cpu().tolist()
slice_cut = [int(v) for v in slice_cut]
if dim == 3:
slice_cut[0] = int(np.clip(slice_cut[0], 0, image_batch.shape[2] - 1))
slice_cut[1] = int(np.clip(slice_cut[1], 0, image_batch.shape[3] - 1))
slice_cut[2] = int(np.clip(slice_cut[2], 0, image_batch.shape[4] - 1))
elif dim == 2:
slice_cut[0] = int(np.clip(slice_cut[0], 0, image_batch.shape[2] - 1))
slice_cut[1] = int(np.clip(slice_cut[1], 0, image_batch.shape[3] - 1))
if dim == 3:
slices = [
[(i, 0, slice(None), slice(None), slice_cut[2]) for i in range(BATCH_SIZE)],
[(i, 0, slice(None), slice_cut[1], slice(None)) for i in range(BATCH_SIZE)],
[(i, 0, slice_cut[0], slice(None), slice(None)) for i in range(BATCH_SIZE)],
]
elif dim == 2:
middle_slice = image_batch.shape[1] // 2
slices = [
[(i, middle_slice, slice(None), slice(None)) for i in range(BATCH_SIZE)],
[(i, slice(None), slice_cut[0], slice(None)) for i in range(BATCH_SIZE)],
[(i, slice(None), slice(None), slice_cut[1]) for i in range(BATCH_SIZE)],
]
else:
raise ValueError("Invalid dimension")
fig, axs = plt.subplots(3, BATCH_SIZE, figsize=(1.5 * BATCH_SIZE, 3))
if BATCH_SIZE == 1:
axs = np.expand_dims(axs, axis=1)
for s, ax in zip(slices, axs, strict=True):
for i in range(BATCH_SIZE):
ax[i].imshow(image_batch[s[i]].detach().cpu(), vmin=0, vmax=1, cmap="gray")
# plt.axis("off")
ax[i].set_xticks([])
ax[i].set_yticks([])
plt.tight_layout()
plt.savefig(image_path, dpi=300, bbox_inches="tight")
plt.close()
[docs]
def _epoch_positions(n_points: int, interval: int) -> np.ndarray:
    """Return the 1-based epochs ``interval, 2*interval, ...`` for ``n_points`` recorded values."""
    return interval * np.arange(1, n_points + 1)


def plot_learning_curve(
    n_epochs: int,
    epoch_loss_list: list[float],
    val_epoch_loss_list: list[float],
    args: Any,
    VAL_INTERVAL: int,
) -> None:
    """Plot training and validation loss curves on a log-scaled y-axis.

    The figure is saved to ``<args.out_dir>/images/learning_curves.png``.

    Parameters
    ----------
    n_epochs : int
        Total number of epochs. Kept for backward compatibility; the
        x-positions are derived from the lengths of the loss lists.
    epoch_loss_list : list of float
        Training loss per epoch, plotted at epochs 1, 2, ...
    val_epoch_loss_list : list of float
        Validation losses recorded every ``VAL_INTERVAL`` epochs, plotted at
        epochs ``VAL_INTERVAL, 2*VAL_INTERVAL, ...``.
    args : Any
        Object whose ``out_dir`` attribute names the output directory.
    VAL_INTERVAL : int
        Interval (in epochs) at which validation was performed.
    """
    train_epochs = _epoch_positions(len(epoch_loss_list), 1)
    # Derive the validation x-positions from the number of recorded values so
    # the points stay on multiples of VAL_INTERVAL even when n_epochs is not a
    # multiple of it (assumes the first validation happened at VAL_INTERVAL).
    val_epochs = _epoch_positions(len(val_epoch_loss_list), VAL_INTERVAL)
    out_file = os.path.join(args.out_dir, "images/learning_curves.png")

    # Scope the style to this figure instead of changing global rcParams, and
    # draw on a dedicated figure rather than whatever figure is current.
    with plt.style.context("seaborn-v0_8"):
        fig, ax = plt.subplots()
        ax.set_title("Learning Curves", fontsize=20)
        ax.plot(train_epochs, epoch_loss_list, color="C0", linewidth=2.0, label="Train")
        ax.plot(val_epochs, val_epoch_loss_list, color="C1", linewidth=2.0, label="Validation")
        ax.tick_params(axis="both", labelsize=12)
        ax.set_xlabel("Epochs", fontsize=16)
        ax.set_ylabel("Loss", fontsize=16)
        ax.legend(prop={"size": 14})
        ax.set_yscale("log")
        fig.savefig(out_file, dpi=100, bbox_inches="tight")
        plt.close(fig)
[docs]
def plot_scheduler(scheduler: Any, output_file: str) -> None:
    """Plot a diffusion scheduler's ``alphas_cumprod`` against the timestep and save it.

    Parameters
    ----------
    scheduler : Any
        Scheduler object exposing ``alphas_cumprod`` (a torch tensor or any
        array-like accepted by ``torch.as_tensor``), plotted against its index.
    output_file : str
        Path where the output plot is saved.
    """
    alphas_cumprod = torch.as_tensor(scheduler.alphas_cumprod).detach().cpu().numpy()
    # Draw on a dedicated figure so a caller's currently active figure is
    # neither drawn on nor closed.
    fig, ax = plt.subplots()
    ax.plot(alphas_cumprod, color=(2 / 255, 163 / 255, 163 / 255), linewidth=2)
    ax.set_xlabel("Timestep [t]")
    ax.set_ylabel("alpha cumprod")
    fig.savefig(output_file, dpi=100, bbox_inches="tight")
    plt.close(fig)
[docs]
def plot_inpainting(
val_image: torch.Tensor,
val_image_masked: torch.Tensor,
val_image_inpainted: torch.Tensor,
out_file: str,
SLICE_CUT: torch.Tensor | Sequence[int],
cut_dim: int = 0,
) -> None:
"""Plot original, masked, and inpainted images side by side.
Parameters
----------
val_image : torch.Tensor
Original image tensor
val_image_masked : torch.Tensor
Masked version of the image
val_image_inpainted : torch.Tensor
Inpainted result
out_file : str
Path where to save the output plot
SLICE_CUT : torch.Tensor
Index where to cut the slice for visualization
cut_dim : int, optional
Dimension along which to take the slice (0=sagittal, 1=coronal, 2=axial).
Defaults to 0.
"""
if len(val_image.shape) == 5:
val_image = val_image[:, 0, ...]
val_image_masked = val_image_masked[:, 0, ...]
val_image_inpainted = val_image_inpainted[:, 0, ...]
if torch.is_tensor(SLICE_CUT):
slice_cut = SLICE_CUT.detach().cpu().tolist()
else:
slice_cut = [int(v) for v in SLICE_CUT]
spatial_shape = val_image.shape[1:4]
for axis in range(3):
slice_cut[axis] = int(np.clip(slice_cut[axis], 0, spatial_shape[axis] - 1))
slicing = [0, slice(None), slice(None), slice(None)]
slicing[cut_dim + 1] = slice_cut[cut_dim]
slicing = tuple(slicing) # Convert to tuple for PyTorch 2.9+ compatibility
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(val_image[slicing].cpu(), cmap="gray")
axs[0].set_title("Original image")
axs[0].axis("off")
axs[1].imshow(val_image_masked[slicing].cpu(), cmap="gray")
axs[1].axis("off")
axs[1].set_title("Masked image")
axs[2].imshow(val_image_inpainted[slicing].cpu(), cmap="gray")
axs[2].axis("off")
axs[2].set_title("Inpainted image")
plt.savefig(out_file, dpi=300, bbox_inches="tight")
plt.close()