deep learning package

deep_learning_tools

towbintools.deep_learning.deep_learning_tools.create_classification_model(architecture: str, input_channels: int, classes: list[str], learning_rate: float = 0.0001, checkpoint_path: str | None = None, normalization: dict = {'hi': 99, 'lo': 1, 'type': 'percentile'}) → ClassificationModel

Create a classification model.

Parameters:
  • architecture (str) – timm model name (e.g. "efficientnet_b0").

  • input_channels (int) – Number of input image channels.

  • classes (list[str]) – List of class label strings.

  • learning_rate (float, optional) – Learning rate for the Adam optimizer. (default: 1e-4)

  • checkpoint_path (str, optional) – Path to a .ckpt checkpoint; if provided, the model is loaded from the checkpoint and all other arguments are ignored. (default: None)

  • normalization (dict, optional) – Normalization config stored as a hyperparameter for inference. (default: percentile 1–99)

Returns:

Constructed or loaded classification model.

Return type:

ClassificationModel

towbintools.deep_learning.deep_learning_tools.create_keypoint_detection_model(architecture: str, input_channels: int, n_classes: int, learning_rate: float = 0.0001, checkpoint_path: str | None = None, criterion: Any | None = None, activation: str = 'relu') → KeypointDetection1DModel

Create a 1D keypoint detection model.

Parameters:
  • architecture (str) – Architecture name; one of "Unet", "AttentionUnet", or "UnetPlusPlus".

  • input_channels (int) – Number of input sequence channels.

  • n_classes (int) – Number of keypoint classes (output channels).

  • learning_rate (float, optional) – Learning rate for the Adam optimizer. (default: 1e-4)

  • checkpoint_path (str, optional) – Path to a .ckpt checkpoint; if provided, the model is loaded from the checkpoint and all other arguments are ignored. (default: None)

  • criterion (nn.Module, optional) – Loss function. If None, PeakWeightedMSELoss is used. (default: None)

  • activation (str, optional) – Output activation; one of "relu", "leaky_relu", "sigmoid", or "none". (default: "relu")

Returns:

Constructed or loaded keypoint detection model.

Return type:

KeypointDetection1DModel

towbintools.deep_learning.deep_learning_tools.create_segmentation_model(input_channels: int = 1, n_classes: int = 1, architecture: str = 'UnetPlusPlus', encoder: str = 'efficientnet-b4', pretrained_weights: str = 'image-micronet', normalization: dict = {'hi': 99, 'lo': 1, 'type': 'percentile'}, learning_rate: float = 1e-05, checkpoint_path: str | None = None, reset_optimizer: bool = True, criterion: Any | None = None) → SegmentationModel

Create a segmentation model with a pretrained encoder.

Parameters:
  • input_channels (int, optional) – Number of input image channels. (default: 1)

  • n_classes (int, optional) – Number of foreground segmentation classes. (default: 1)

  • architecture (str, optional) – smp architecture name. (default: "UnetPlusPlus")

  • encoder (str, optional) – Encoder backbone name. (default: "efficientnet-b4")

  • pretrained_weights (str, optional) – Dataset the encoder was pretrained on. (default: "image-micronet")

  • normalization (dict, optional) – Normalization config. (default: percentile 1–99)

  • learning_rate (float, optional) – Learning rate for the Adam optimizer. (default: 1e-5)

  • checkpoint_path (str, optional) – Path to a .ckpt checkpoint. If provided, the model is loaded from the checkpoint, then learning_rate and normalization are updated. (default: None)

  • reset_optimizer (bool, optional) – If True, discard the optimizer state from the checkpoint. (default: True)

  • criterion (nn.Module, optional) – Loss function. If None, a default is chosen based on n_classes. (default: None)

Returns:

Constructed or loaded segmentation model.

Return type:

SegmentationModel

Raises:

ValueError – If the checkpoint architecture or encoder does not match the requested one.

towbintools.deep_learning.deep_learning_tools.load_keypoint_detection_model_from_checkpoint(checkpoint_path: str) → KeypointDetection1DModel

Load a 1D keypoint detection model from a checkpoint.

Parameters:

checkpoint_path (str) – Path to a .ckpt checkpoint file.

Returns:

Loaded keypoint detection model.

Return type:

KeypointDetection1DModel

Raises:

ValueError – If the model cannot be loaded from the checkpoint.

towbintools.deep_learning.deep_learning_tools.load_segmentation_model_from_checkpoint(checkpoint_path: str) → SegmentationModel

Load a segmentation model from a checkpoint.

First tries a direct load; if that fails (e.g. mismatched input_channels in the checkpoint metadata), infers the channel count from the checkpoint weights and retries with pretrained_weights=None.

Parameters:

checkpoint_path (str) – Path to a .ckpt checkpoint file.

Returns:

Loaded segmentation model.

Return type:

SegmentationModel

Raises:

ValueError – If both loading attempts fail.
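The fallback path above needs the input channel count from the saved weights. Below is a minimal sketch of one way to recover it. It assumes that the first 4D tensor whose name ends in "weight" is the first convolution: PyTorch state dicts preserve layer order, and Conv2d weights have shape (out_channels, in_channels, kH, kW). The helper name and the key names are hypothetical, and NumPy arrays stand in for tensors. Lightning .ckpt files store the weights under their "state_dict" key.

```python
import numpy as np


def infer_input_channels(state_dict: dict) -> int:
    """Hypothetical helper: input channels from the first 4D conv weight."""
    for name, weight in state_dict.items():
        # Conv2d weights are (out_channels, in_channels, kH, kW).
        if name.endswith("weight") and np.ndim(weight) == 4:
            return int(np.shape(weight)[1])
    raise ValueError("no 4D convolution weight found")


# Hypothetical key names; real weights would come from the checkpoint's
# "state_dict" entry.
fake_state_dict = {
    "model.encoder.stem.weight": np.zeros((48, 3, 3, 3)),
    "model.encoder.stem_bn.weight": np.zeros(48),
}
print(infer_input_channels(fake_state_dict))  # 3
```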

dataset

class towbintools.deep_learning.utils.dataset.ClassificationDataset(dataset, channels, n_classes, class_column='class', image_column='image', transform=None)

Bases: Dataset

PyTorch Dataset for image classification training.

Loads images from file paths and their corresponding class labels. For multi-class problems (n_classes > 2) labels are one-hot encoded.

Parameters:
  • dataset (pd.DataFrame) – DataFrame with image_column and class_column columns.

  • channels (int or list[int]) – Channel index (or indices) to load.

  • n_classes (int) – Total number of classes. Labels are one-hot encoded when n_classes > 2.

  • class_column (str, optional) – Column name for class labels. (default: "class")

  • image_column (str, optional) – Column name for image paths. (default: "image")

  • transform (callable, optional) – MONAI Compose transform applied to {"image": ...}. (default: None)
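The label encoding rule above can be sketched as follows. This is a hypothetical helper, not part of towbintools. In particular, keeping binary labels as a one-element float array is an assumption; the dataset may return a scalar instead.

```python
import numpy as np


def encode_label(label: int, n_classes: int) -> np.ndarray:
    """Hypothetical helper: one-hot for n_classes > 2, else the raw label."""
    if n_classes > 2:
        one_hot = np.zeros(n_classes, dtype=np.float32)
        one_hot[label] = 1.0
        return one_hot
    # Assumption: binary labels stay a single value.
    return np.array([label], dtype=np.float32)


print(encode_label(1, 2))  # [1.]
print(encode_label(2, 4))  # [0. 0. 1. 0.]
```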

class towbintools.deep_learning.utils.dataset.KeypointDetection1DPredictionDataset(inputs, enforce_divisibility_by=32, resize_method='pad')

Bases: Dataset

PyTorch Dataset for 1D keypoint detection inference.

Stores input time-series for prediction. The collate function pads or crops all series to the same length and replaces NaN-containing series with zeros, returning their indices as invalid_series_index.

Parameters:
  • inputs (array-like) – Sequence of 1D (or 2D) input series arrays.

  • enforce_divisibility_by (int, optional) – Target batch length is rounded to a multiple of this value. (default: 32)

  • resize_method (str, optional) – "pad" or "crop". (default: "pad")
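Here is a minimal NumPy sketch of the NaN handling performed by the prediction collate function. The helper name is hypothetical. The sketch assumes the series were already padded or cropped to a common length, which the real collate function also does.

```python
import numpy as np


def collate_prediction_series(batch):
    """Hypothetical helper: zero out NaN-containing series and report them."""
    stacked = np.stack([np.asarray(s, dtype=np.float32) for s in batch])
    # A series is invalid if any of its values is NaN.
    invalid = np.isnan(stacked).reshape(len(batch), -1).any(axis=1)
    invalid_series_index = np.flatnonzero(invalid).tolist()
    stacked[invalid] = 0.0
    return stacked, invalid_series_index


series = [np.arange(4.0), np.array([1.0, np.nan, 3.0, 4.0])]
batch, invalid_series_index = collate_prediction_series(series)
print(invalid_series_index)  # [1]
```

Returning the invalid indices lets the caller mark those series as unpredicted rather than silently trusting outputs computed on zeros.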

collate_fn(batch)

class towbintools.deep_learning.utils.dataset.KeypointDetection1DTrainingDataset(inputs, targets, enforce_divisibility_by=32, resize_method='pad')

Bases: Dataset

PyTorch Dataset for 1D keypoint detection training.

Stores pairs of input time-series and target heatmaps. The collate function pads or crops all series in a batch to the same length (a multiple of enforce_divisibility_by) and drops samples containing NaN values.

Parameters:
  • inputs (array-like) – Sequence of 1D (or 2D) input series arrays.

  • targets (array-like) – Sequence of target heatmap arrays aligned with inputs.

  • enforce_divisibility_by (int, optional) – Target batch length is rounded to a multiple of this value. (default: 32)

  • resize_method (str, optional) – "pad" or "crop". (default: "pad")
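The training collate behavior described above can be sketched in NumPy. It first drops NaN samples and then brings every series to one length divisible by enforce_divisibility_by. The helper names are hypothetical. Two details are also assumptions because the documentation does not specify them. With "pad", the longest series is rounded up and the others are zero-padded at the end. With "crop", the shortest series is rounded down and the others are truncated at the end.

```python
import numpy as np


def target_length(lengths, divisor=32, resize_method="pad"):
    """Hypothetical helper: common batch length, a multiple of divisor."""
    if resize_method == "pad":
        return -(-max(lengths) // divisor) * divisor  # round up
    if resize_method == "crop":
        # Round down, but never below one multiple of divisor.
        return max(divisor, min(lengths) // divisor * divisor)
    raise ValueError(f"unknown resize_method: {resize_method!r}")


def fit_length(x, length):
    """Zero-pad or truncate at the end of the last axis."""
    x = np.asarray(x, dtype=np.float32)
    if x.shape[-1] >= length:
        return x[..., :length]
    pad = [(0, 0)] * (x.ndim - 1) + [(0, length - x.shape[-1])]
    return np.pad(x, pad)


def collate_training_pairs(batch, divisor=32, resize_method="pad"):
    """Drop NaN-containing (input, target) pairs, then stack at one length."""
    batch = [(x, y) for x, y in batch
             if not (np.isnan(x).any() or np.isnan(y).any())]
    length = target_length([np.shape(x)[-1] for x, _ in batch],
                           divisor, resize_method)
    inputs = np.stack([fit_length(x, length) for x, _ in batch])
    targets = np.stack([fit_length(y, length) for _, y in batch])
    return inputs, targets


pairs = [
    (np.ones(50), np.zeros(50)),
    (np.ones(70), np.zeros(70)),
    (np.full(40, np.nan), np.zeros(40)),  # dropped: contains NaN
]
inputs, targets = collate_training_pairs(pairs)
print(inputs.shape)  # (2, 96)
```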

collate_fn(batch)

class towbintools.deep_learning.utils.dataset.QualityControlDataset(image_paths, mask_paths, channels, labels, classes, enforce_divisibility_by=32, resize_method='pad', transform=None)

Bases: Dataset

PyTorch Dataset for quality-control classification training.

Loads image + mask pairs and their quality labels. When mask_paths is provided, image and mask are concatenated along the channel axis before the transform. The collate function discards samples whose mask has no foreground.

Parameters:
  • image_paths (list[str]) – Paths to image files.

  • mask_paths (list[str] or None) – Paths to mask files. Pass None or an empty list for image-only mode.

  • channels (int or list[int]) – Channel indices to load from images.

  • labels (list) – Class labels (integers or strings matching classes).

  • classes (list) – Ordered list of class names.

  • enforce_divisibility_by (int, optional) – Batch spatial dimensions are resized to multiples of this value. (default: 32)

  • resize_method (str, optional) – "pad" or "crop". (default: "pad")

  • transform (callable, optional) – MONAI Compose transform. (default: None)
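The sketch below shows how image and mask might be combined and filtered for QC training. It assumes a channel-first (C, H, W) image, a 2D mask appended as the last channel, and a per-sample check for any non-zero mask pixel. Both helper names are hypothetical.

```python
import numpy as np


def make_qc_sample(image, mask):
    """Append a 2D (H, W) mask to a (C, H, W) image as an extra last channel."""
    return np.concatenate([image, mask[np.newaxis].astype(image.dtype)], axis=0)


def drop_empty_masks(samples, labels):
    """Keep only samples whose mask channel (assumed last) has foreground."""
    kept = [(s, lab) for s, lab in zip(samples, labels) if s[-1].any()]
    return [s for s, _ in kept], [lab for _, lab in kept]


rng = np.random.default_rng(0)
image = rng.random((1, 8, 8))
samples = [make_qc_sample(image, np.ones((8, 8))),
           make_qc_sample(image, np.zeros((8, 8)))]
kept, kept_labels = drop_empty_masks(samples, ["good", "bad"])
print(len(kept), kept[0].shape, kept_labels)  # 1 (2, 8, 8) ['good']
```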

collate_fn(batch)

class towbintools.deep_learning.utils.dataset.QualityControlPredictionDataset(image_paths, mask_paths, channels, enforce_divisibility_by=32, resize_method='pad', transform=None)

Bases: Dataset

PyTorch Dataset for quality-control classification inference.

Loads image + mask pairs and concatenates them along the channel axis. The collate function tracks samples rejected due to empty or failed masks and returns their indices alongside valid batches.

Parameters:
  • image_paths (list[str]) – Paths to image files.

  • mask_paths (list[str]) – Paths to mask files.

  • channels (int or list[int]) – Channel indices to load from images.

  • enforce_divisibility_by (int, optional) – Batch spatial dimensions are resized to multiples of this value. (default: 32)

  • resize_method (str, optional) – "pad" or "crop". (default: "pad")

  • transform (callable, optional) – MONAI Compose transform. (default: None)

collate_fn(batch)

class towbintools.deep_learning.utils.dataset.SegmentationDataset(dataset, channels, mask_column='mask', image_column='image', transform=None, enforce_divisibility_by=32, pad_or_crop='pad', mask_pad_value=-1)

Bases: Dataset

PyTorch Dataset for full-image segmentation training.

Loads image and mask pairs from file paths stored in a DataFrame. Images are resized (padded or cropped) in the collate function to a divisibility-enforced common size within each batch.

Parameters:
  • dataset (pd.DataFrame) – DataFrame with image_column and mask_column columns containing file paths.

  • channels (int or list[int]) – Channel indices to load from the image.

  • mask_column (str, optional) – Column name for mask paths. (default: "mask")

  • image_column (str, optional) – Column name for image paths. (default: "image")

  • transform (callable, optional) – MONAI Compose transform. (default: None)

  • enforce_divisibility_by (int, optional) – Batch images are resized so their spatial dimensions are multiples of this value. (default: 32)

  • pad_or_crop (str, optional) – Whether to pad ("pad") or crop ("crop") images to the common batch size. (default: "pad")

  • mask_pad_value (int, optional) – Fill value used when padding masks. (default: -1)
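Here is a minimal sketch of batch padding for SegmentationDataset with pad_or_crop="pad". The helper name is hypothetical. The sketch assumes padding is added at the bottom and right edges, although the real collate function may distribute it differently. Padding masks with mask_pad_value=-1 lets a loss such as BCELossWithIgnore(ignore_index=-1) skip the padded pixels.

```python
import numpy as np


def pad_batch_2d(images, masks, divisor=32, mask_pad_value=-1):
    """Hypothetical helper: pad (C, H, W) images and (H, W) masks to one size.

    The common size is the largest H and W in the batch, rounded up to a
    multiple of divisor. Images are zero-padded; masks use mask_pad_value.
    Assumption: padding goes at the bottom and right edges.
    """
    h = max(img.shape[-2] for img in images)
    w = max(img.shape[-1] for img in images)
    h = -(-h // divisor) * divisor
    w = -(-w // divisor) * divisor
    batch_images, batch_masks = [], []
    for img, mask in zip(images, masks):
        dh, dw = h - img.shape[-2], w - img.shape[-1]
        batch_images.append(np.pad(img, ((0, 0), (0, dh), (0, dw))))
        batch_masks.append(np.pad(mask, ((0, dh), (0, dw)),
                                  constant_values=mask_pad_value))
    return np.stack(batch_images), np.stack(batch_masks)


images = [np.ones((1, 50, 40)), np.ones((1, 30, 70))]
masks = [np.ones((50, 40)), np.ones((30, 70))]
x, y = pad_batch_2d(images, masks)
print(x.shape, y.shape)  # (2, 1, 64, 96) (2, 64, 96)
```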

collate_fn(batch)

class towbintools.deep_learning.utils.dataset.SegmentationPredictionDataset(image_paths, channels, transform=None, enforce_divisibility_by=32, scale_factor=1.0, pad_or_crop='pad')

Bases: Dataset

PyTorch Dataset for segmentation inference.

Loads images from a list of file paths, optionally rescales them, and pads or crops batches to a common size that is a multiple of enforce_divisibility_by. The collate function returns image paths, resized tensors, original shapes, and indices of images that failed to load.

Parameters:
  • image_paths (list[str]) – Paths to the image files.

  • channels (int or list[int]) – Channel indices to load.

  • transform (callable, optional) – MONAI Compose transform. (default: None)

  • enforce_divisibility_by (int, optional) – Images are resized so their spatial dimensions are multiples of this value. (default: 32)

  • scale_factor (float, optional) – Isotropic rescaling factor applied to each image before batching. (default: 1.0)

  • pad_or_crop (str, optional) – "pad" or "crop". (default: "pad")

collate_fn(batch)

class towbintools.deep_learning.utils.dataset.StackPredictionDataset(stack, channels, transform=None, enforce_divisibility_by=32, pad_or_crop='pad', scale_factor=1.0)

Bases: Dataset

PyTorch Dataset for plane-by-plane inference on a z-stack.

Accepts a z-stack as a NumPy array or a file path and optionally rescales each plane by scale_factor. Each __getitem__ call returns a single plane, ready for inference.

Parameters:
  • stack (str or np.ndarray) – z-stack as an array of shape (N, H, W) or (N, C, H, W), or a path to a TIFF file.

  • channels (int, list[int], or None) – Channel indices to load when stack is a file path.

  • transform (callable, optional) – MONAI Compose transform applied per plane. (default: None)

  • enforce_divisibility_by (int, optional) – Spatial dimensions are resized to multiples of this value. (default: 32)

  • pad_or_crop (str, optional) – "pad" or "crop". (default: "pad")

  • scale_factor (float, optional) – Isotropic rescaling factor applied to each plane via cv2.resize. (default: 1.0)

class towbintools.deep_learning.utils.dataset.TiledSegmentationDataset(dataset, image_slicers, channels, mask_column='mask', image_column='image', transform=None)

Bases: Dataset

PyTorch Dataset for tiled segmentation training.

Loads image and mask pairs from file paths stored in a DataFrame and returns a randomly selected tile from each sample. Tile boundaries are pre-computed using pytorch_toolbelt.inference.ImageSlicer.

Parameters:
  • dataset (pd.DataFrame) – DataFrame with at least image_column and mask_column columns containing file paths.

  • image_slicers (dict) – Mapping from image shape to ImageSlicer instance.

  • channels (int or list[int]) – Channel indices to load from the image.

  • mask_column (str, optional) – DataFrame column name for mask paths. (default: "mask")

  • image_column (str, optional) – DataFrame column name for image paths. (default: "image")

  • transform (callable, optional) – MONAI Compose transform applied to the {"image": ..., "mask": ...} dictionary. (default: None)
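Random tile selection can be sketched with a simple sliding-window grid. This sketch swaps in a hand-written grid for pytorch_toolbelt's ImageSlicer, which handles padding and edge coverage in its own way. It also assumes integer tile_size and tile_step values, as in the tiler_params dictionary expected by create_segmentation_dataloaders(). The helper names are hypothetical.

```python
import random

import numpy as np


def tile_starts(length, tile_size, tile_step):
    """Start offsets of tiles covering [0, length); the last tile ends flush."""
    if length <= tile_size:
        return [0]
    starts = list(range(0, length - tile_size + 1, tile_step))
    if starts[-1] + tile_size < length:
        starts.append(length - tile_size)
    return starts


def random_tile(image, mask, tiler_params, rng=None):
    """Cut the same random tile from a (C, H, W) image and its (H, W) mask."""
    if rng is None:
        rng = random.Random()
    size, step = tiler_params["tile_size"], tiler_params["tile_step"]
    y = rng.choice(tile_starts(image.shape[-2], size, step))
    x = rng.choice(tile_starts(image.shape[-1], size, step))
    return (image[..., y:y + size, x:x + size],
            mask[..., y:y + size, x:x + size])


tiler_params = {"tile_size": 64, "tile_step": 32}
image, mask = np.zeros((1, 100, 128)), np.zeros((100, 128))
tile, tile_mask = random_tile(image, mask, tiler_params, random.Random(0))
print(tile.shape, tile_mask.shape)  # (1, 64, 64) (64, 64)
print(tile_starts(100, 64, 32))  # [0, 32, 36]
```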

towbintools.deep_learning.utils.dataset.create_classification_dataloaders(training_dataframe, validation_dataframe, channels, n_classes, batch_size=64, num_workers=32, pin_memory=True, training_transform=None, validation_transform=None)

Create training and validation DataLoaders for image classification.

Parameters:
  • training_dataframe (pd.DataFrame) – DataFrame with "image" and "class" columns for training data.

  • validation_dataframe (pd.DataFrame) – DataFrame with "image" and "class" columns for validation data.

  • channels (int or list[int]) – Channel indices to load.

  • n_classes (int) – Number of classes (labels are one-hot encoded when > 2).

  • batch_size (int, optional) – Batch size. (default: 64)

  • num_workers (int, optional) – DataLoader worker processes. (default: 32)

  • pin_memory (bool, optional) – Whether to pin memory. (default: True)

  • training_transform (callable, optional) – Override training transform. (default: None)

  • validation_transform (callable, optional) – Override validation transform. (default: None)

Returns:

(train_loader, val_loader).

Return type:

tuple[DataLoader, DataLoader]

towbintools.deep_learning.utils.dataset.create_classification_training_dataframes(ground_truth_csv_paths, image_columns, class_columns, save_dir, validation_set_ratio=0.25, test_set_ratio=0.1)

Build training and validation DataFrames for classification from CSV ground-truth files.

Reads one or more CSV files, extracts image path and class label columns, concatenates them, splits into train/val/test sets, and saves date-stamped CSV backups.

Parameters:
  • ground_truth_csv_paths (str or list[str]) – Paths to ground-truth CSV files.

  • image_columns (str or list[str]) – Column name(s) for image paths. A single string is broadcast to all CSVs.

  • class_columns (str or list[str]) – Column name(s) for class labels.

  • save_dir (str) – Directory where backup CSVs are written.

  • validation_set_ratio (float, optional) – Fraction for validation. (default: 0.25)

  • test_set_ratio (float, optional) – Fraction for testing. (default: 0.1)

Returns:

(training_dataframe, validation_dataframe).

Return type:

tuple[pd.DataFrame, pd.DataFrame]

towbintools.deep_learning.utils.dataset.create_segmentation_dataloaders(training_dataframe, validation_dataframe, channels, batch_size=5, num_workers=32, pin_memory=True, train_on_tiles=True, tiler_params=None, training_transform=None, validation_transform=None)

Create training and validation DataLoaders for segmentation.

When train_on_tiles is True, uses TiledSegmentationDataset with per-shape ImageSlicer objects; otherwise uses SegmentationDataset with full images. Default transforms (percentile normalization) are applied if no transform is supplied.

Parameters:
  • training_dataframe (pd.DataFrame) – DataFrame with "image" and "mask" columns for training data.

  • validation_dataframe (pd.DataFrame) – DataFrame with "image" and "mask" columns for validation data.

  • channels (int or list[int]) – Channel indices to load.

  • batch_size (int, optional) – Batch size. (default: 5)

  • num_workers (int, optional) – DataLoader worker processes. (default: 32)

  • pin_memory (bool, optional) – Whether to pin memory. (default: True)

  • train_on_tiles (bool, optional) – If True, sample random tiles; otherwise use full images. (default: True)

  • tiler_params (dict, optional) – Required when train_on_tiles is True; must contain "tile_size" and "tile_step" keys. (default: None)

  • training_transform (callable, optional) – Override transform for training. (default: None)

  • validation_transform (callable, optional) – Override transform for validation. (default: None)

Returns:

(train_loader, val_loader).

Return type:

tuple[DataLoader, DataLoader]

towbintools.deep_learning.utils.dataset.create_segmentation_dataloaders_from_filemap(filemap_path, save_dir, channels, image_column='image', mask_column='mask', validation_set_ratio=0.25, test_set_ratio=0.1, batch_size=5, num_workers=32, pin_memory=True, train_on_tiles=True, tiler_params=None, training_transform=None, validation_transform=None)

Build segmentation DataLoaders from a CSV filemap.

Reads a CSV at filemap_path, renames the specified columns to "image" and "mask", splits into train/val/test sets, saves date-stamped backups, and returns DataLoaders.

Parameters:
  • filemap_path (str) – Path to the CSV filemap.

  • save_dir (str) – Directory where backup CSVs are written.

  • channels (int or list[int]) – Channel indices to load.

  • image_column (str, optional) – CSV column name for image paths. (default: "image")

  • mask_column (str, optional) – CSV column name for mask paths. (default: "mask")

  • validation_set_ratio (float, optional) – Fraction for validation. (default: 0.25)

  • test_set_ratio (float, optional) – Fraction for testing. (default: 0.1)

  • batch_size (int, optional) – Batch size. (default: 5)

  • num_workers (int, optional) – DataLoader worker processes. (default: 32)

  • pin_memory (bool, optional) – Whether to pin memory. (default: True)

  • train_on_tiles (bool, optional) – Whether to train on tiles. (default: True)

  • tiler_params (dict, optional) – Tile parameters; see create_segmentation_dataloaders(). (default: None)

  • training_transform (callable, optional) – Override training transform. (default: None)

  • validation_transform (callable, optional) – Override validation transform. (default: None)

Returns:

(training_dataframe, validation_dataframe, train_loader, val_loader).

Return type:

tuple

towbintools.deep_learning.utils.dataset.create_segmentation_training_dataframes(image_directories, mask_directories, save_dir, validation_set_ratio=0.25, test_set_ratio=0.1)

Build training and validation DataFrames from image and mask directories.

Pairs files by sorted order within each directory pair. Saves date-stamped CSV backups of all three splits to save_dir/database_backup/.

Parameters:
  • image_directories (str or list[str]) – Directories containing image files.

  • mask_directories (str or list[str]) – Directories containing mask files, paired with image_directories.

  • save_dir (str) – Directory where backup CSVs are written.

  • validation_set_ratio (float, optional) – Fraction of data for validation. (default: 0.25)

  • test_set_ratio (float, optional) – Fraction of data for testing. (default: 0.1)

Returns:

(training_dataframe, validation_dataframe).

Return type:

tuple[pd.DataFrame, pd.DataFrame]

Raises:

AssertionError – If the number of images and masks in a directory pair does not match.
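Pairing by sorted order matches the n-th image with the n-th mask, so file names in both directories must sort consistently. Below is a hypothetical sketch on plain file-name lists; a real call would list each directory first.

```python
def pair_images_and_masks(image_files, mask_files):
    """Hypothetical helper: pair files by position after sorting each list."""
    image_files, mask_files = sorted(image_files), sorted(mask_files)
    assert len(image_files) == len(mask_files), (
        f"{len(image_files)} images but {len(mask_files)} masks"
    )
    return list(zip(image_files, mask_files))


pairs = pair_images_and_masks(["img_02.tif", "img_01.tif"],
                              ["img_01_mask.tif", "img_02_mask.tif"])
print(pairs)
# [('img_01.tif', 'img_01_mask.tif'), ('img_02.tif', 'img_02_mask.tif')]
```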

towbintools.deep_learning.utils.dataset.create_segmentation_training_dataframes_and_dataloaders(image_directories, mask_directories, save_dir, channels, validation_set_ratio=0.25, test_set_ratio=0.1, batch_size=5, num_workers=32, pin_memory=True, train_on_tiles=True, tiler_params=None, training_transform=None, validation_transform=None)

Build training DataFrames and DataLoaders for segmentation in one step.

Combines create_segmentation_training_dataframes() and create_segmentation_dataloaders().

Parameters:
  • image_directories (str or list[str]) – Directories containing image files.

  • mask_directories (str or list[str]) – Directories containing mask files.

  • save_dir (str) – Directory where backup CSVs are written.

  • channels (int or list[int]) – Channel indices to load.

  • validation_set_ratio (float, optional) – Fraction for validation. (default: 0.25)

  • test_set_ratio (float, optional) – Fraction for testing. (default: 0.1)

  • batch_size (int, optional) – Batch size. (default: 5)

  • num_workers (int, optional) – DataLoader worker processes. (default: 32)

  • pin_memory (bool, optional) – Whether to pin memory. (default: True)

  • train_on_tiles (bool, optional) – Whether to train on tiles. (default: True)

  • tiler_params (dict, optional) – Tile parameters; see create_segmentation_dataloaders(). (default: None)

  • training_transform (callable, optional) – Override training transform. (default: None)

  • validation_transform (callable, optional) – Override validation transform. (default: None)

Returns:

(training_dataframe, validation_dataframe, train_loader, val_loader).

Return type:

tuple

towbintools.deep_learning.utils.dataset.get_unique_shapes_from_tiffs(image_paths=list[str], channels_to_keep: list[int] | None = None) → ndarray

Get unique shapes from a list of TIFF images in parallel.

Parameters:
  • image_paths (list[str]) – Paths of the TIFF images to inspect.

  • channels_to_keep (list[int], optional) – Channel indices to keep. If None, all channels are considered. (default: None)

Returns:

Unique image shapes found among the given images.

Return type:

np.ndarray

towbintools.deep_learning.utils.dataset.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) → dst

Resizes an image. (This entry documents OpenCV's cv2.resize, which is available from this module.)

The function resize resizes the image src down to or up to the specified size. Note that the initial dst type or size is not taken into account. Instead, the size and type are derived from src, dsize, fx, and fy. If you want to resize src so that it fits the pre-created dst, you may call the function as follows:

    // explicitly specify dsize=dst.size(); fx and fy will be computed from that.
    resize(src, dst, dst.size(), 0, 0, interpolation);

If you want to decimate the image by a factor of 2 in each direction, you can call the function this way:

    // specify fx and fy and let the function compute the destination image size.
    resize(src, dst, Size(), 0.5, 0.5, interpolation);

To shrink an image, it will generally look best with INTER_AREA interpolation, whereas to enlarge an image, it will generally look best with INTER_CUBIC (slow) or INTER_LINEAR (faster but still looks OK).

Parameters:
  • src – Input image.

  • dst – Output image; it has the size dsize (when it is non-zero) or the size computed from src.size(), fx, and fy; the type of dst is the same as that of src.

  • dsize – Output image size; if it equals zero (None in Python), it is computed as dsize = Size(round(fx*src.cols), round(fy*src.rows)). Either dsize or both fx and fy must be non-zero.

  • fx – Scale factor along the horizontal axis; when it equals 0, it is computed as (double)dsize.width/src.cols.

  • fy – Scale factor along the vertical axis; when it equals 0, it is computed as (double)dsize.height/src.rows.

  • interpolation – Interpolation method; see InterpolationFlags.

See also: warpAffine, warpPerspective, remap.

towbintools.deep_learning.utils.dataset.split_dataset(dataframe, validation_size, test_size)

Split a DataFrame (or CSV path) into training, validation, and test sets.

Parameters:
  • dataframe (pd.DataFrame or str) – DataFrame or path to a CSV file.

  • validation_size (float) – Fraction of the total data for validation.

  • test_size (float) – Fraction of the total data for testing.

Returns:

(train_dataframe, validation_dataframe, test_dataframe).

Return type:

tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]

Raises:

ValueError – If validation_size + test_size >= 1.0.
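Here is a sketch of the three-way split on row indices. It assumes a random permutation with a fixed seed and rounding of each split size; the helper name and the seed parameter are hypothetical. Both fractions are taken relative to the total number of rows. For a DataFrame, each index array can be applied with df.iloc[idx].

```python
import numpy as np


def split_indices(n_rows, validation_size, test_size, seed=0):
    """Hypothetical helper: shuffle row indices and cut them into three splits."""
    if validation_size + test_size >= 1.0:
        raise ValueError("validation_size + test_size must be < 1.0")
    idx = np.random.default_rng(seed).permutation(n_rows)
    n_val = round(n_rows * validation_size)
    n_test = round(n_rows * test_size)
    val, test = idx[:n_val], idx[n_val:n_val + n_test]
    train = idx[n_val + n_test:]
    return train, val, test


train, val, test = split_indices(100, validation_size=0.25, test_size=0.1)
print(len(train), len(val), len(test))  # 65 25 10
```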

augmentation

class towbintools.deep_learning.utils.augmentation.CustomFlip(*args, **kwargs)

Bases: MapTransform, Randomizable

MONAI Randomizable MapTransform that randomly flips arrays along spatial axes.

With probability prob, flips along one of: height axis only, width axis only, or both height and width axes.

Parameters:
  • keys (sequence) – Keys in the data dictionary to apply the transform to.

  • prob (float, optional) – Probability of applying the flip. (default: 0.75)
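Below is a NumPy sketch of the CustomFlip behavior for channel-first (C, H, W) arrays. It is a hypothetical stand-in, not the MONAI transform. It uses numpy.random.Generator instead of self.R and makes one draw per call, so every key (for example image and mask) receives the same flip.

```python
import numpy as np

# Flip options from the description: height only, width only, or both.
FLIP_AXES = [(-2,), (-1,), (-2, -1)]


def custom_flip(data, keys, prob=0.75, rng=None):
    """Hypothetical sketch of CustomFlip for channel-first (C, H, W) arrays."""
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() >= prob:
        return data  # no flip this time
    axes = FLIP_AXES[rng.integers(len(FLIP_AXES))]
    # The same axes are used for every key, keeping image and mask aligned.
    return {**data, **{k: np.flip(data[k], axis=axes) for k in keys}}


image = np.arange(12).reshape(1, 3, 4)
out = custom_flip({"image": image, "mask": image.copy()}, ["image", "mask"],
                  prob=1.0, rng=np.random.default_rng(0))
print(np.array_equal(out["image"], out["mask"]))  # True
```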

randomize(data=None)

Within this method, self.R should be used, instead of np.random, to introduce random factors.

All self.R calls happen here, which makes errors in synchronizing the random state easier to identify.

This method can generate the random factors based on properties of the input data.

Raises:

NotImplementedError – When the subclass does not override this method.

class towbintools.deep_learning.utils.augmentation.CustomRotate90(*args, **kwargs)

Bases: MapTransform, Randomizable

MONAI Randomizable MapTransform that randomly rotates arrays by 90°, 180°, or 270°.

With probability prob, rotates in the spatial (H, W) plane by a randomly chosen multiple of 90°.

Parameters:
  • keys (sequence) – Keys in the data dictionary to apply the transform to.

  • prob (float, optional) – Probability of applying the rotation. (default: 0.75)
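CustomRotate90 can be sketched the same way with a hypothetical NumPy stand-in. One draw decides whether to rotate and by how many quarter turns, and all keys are rotated in the (H, W) plane with numpy.rot90.

```python
import numpy as np


def custom_rotate90(data, keys, prob=0.75, rng=None):
    """Hypothetical sketch of CustomRotate90 for channel-first (C, H, W) arrays."""
    if rng is None:
        rng = np.random.default_rng()
    if rng.random() >= prob:
        return data
    k = int(rng.integers(1, 4))  # 1, 2 or 3 quarter turns: 90, 180 or 270 degrees
    return {**data, **{key: np.rot90(data[key], k=k, axes=(-2, -1)) for key in keys}}


image = np.arange(6).reshape(1, 2, 3)
out = custom_rotate90({"image": image, "mask": image.copy()}, ["image", "mask"],
                      prob=1.0, rng=np.random.default_rng(0))
print(out["image"].shape in {(1, 3, 2), (1, 2, 3)})  # True
```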

randomize(data=None)

Within this method, self.R should be used, instead of np.random, to introduce random factors.

All self.R calls happen here, which makes errors in synchronizing the random state easier to identify.

This method can generate the random factors based on properties of the input data.

Raises:

NotImplementedError – When the subclass does not override this method.

class towbintools.deep_learning.utils.augmentation.EnforceNChannels(*args, **kwargs)

Bases: MapTransform

MONAI MapTransform that tiles channel data to reach exactly n_channels.

Delegates to _enforce_n_channels() for the actual tiling logic.

Parameters:
  • keys (sequence) – Keys in the data dictionary to apply the transform to.

  • n_channels (int) – Target number of channels.
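Here is a sketch of channel tiling. It assumes the channel axis comes first and that channels are repeated in order and then truncated to n_channels. The exact rule in _enforce_n_channels() may differ, and the helper name is hypothetical.

```python
import numpy as np


def enforce_n_channels(x, n_channels):
    """Hypothetical sketch: repeat channels (axis 0) in order, then truncate."""
    reps = -(-n_channels // x.shape[0])  # ceiling division
    tiled = np.tile(x, (reps,) + (1,) * (x.ndim - 1))
    return tiled[:n_channels]


gray = np.zeros((1, 4, 4))
print(enforce_n_channels(gray, 3).shape)  # (3, 4, 4)
two = np.stack([np.zeros((4, 4)), np.ones((4, 4))])
print(enforce_n_channels(two, 3)[:, 0, 0])  # [0. 1. 0.]
```

A typical use is feeding single-channel microscopy images to an encoder pretrained on three-channel input.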

class towbintools.deep_learning.utils.augmentation.NormalizeDataRange(*args, **kwargs)

Bases: MapTransform

MONAI MapTransform that normalizes arrays to [0, 1] by their per-sample min/max.

Parameters:

keys (sequence) – Keys in the data dictionary to apply the transform to.
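Per-sample min/max rescaling amounts to one line. The sketch below adds a small eps to avoid dividing by zero on constant images. The eps value is an assumption, not a documented parameter, and the helper name is hypothetical.

```python
import numpy as np


def normalize_data_range(x, eps=1e-8):
    """Rescale one sample to [0, 1] using its own min and max."""
    x = np.asarray(x, dtype=np.float32)
    lo, hi = x.min(), x.max()
    # eps is an assumption; it keeps constant images from dividing by zero.
    return (x - lo) / (hi - lo + eps)


print(normalize_data_range(np.array([2.0, 4.0, 6.0])))
```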

class towbintools.deep_learning.utils.augmentation.NormalizeMeanStd(*args, **kwargs)

Bases: MapTransform

MONAI MapTransform that standardizes arrays using a fixed mean and std.

Parameters:
  • keys (sequence) – Keys in the data dictionary to apply the transform to.

  • mean (float) – Mean to subtract.

  • std (float) – Standard deviation to divide by.

class towbintools.deep_learning.utils.augmentation.NormalizePercentile(*args, **kwargs)

Bases: MapTransform

MONAI MapTransform that normalizes arrays using percentile clipping (csbdeep).

Clips values at the lo-th and hi-th percentiles and rescales to [0, 1].

Parameters:
  • keys (sequence) – Keys in the data dictionary to apply the transform to.

  • lo (float) – Lower percentile for clipping (e.g. 1 for the 1st percentile).

  • hi (float) – Upper percentile for clipping (e.g. 99 for the 99th percentile).

  • axis (int or None, optional) – Axis over which to compute percentiles. None computes them over the whole array. (default: None)
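Below is a NumPy sketch of percentile normalization in the style of csbdeep's normalize(..., clip=True). With the default configuration {'hi': 99, 'lo': 1, 'type': 'percentile'}, values at or below the 1st percentile map to 0 and values at or above the 99th map to 1. The function name and eps are assumptions.

```python
import numpy as np


def normalize_percentile(x, lo=1.0, hi=99.0, axis=None, eps=1e-20):
    """Percentile normalization with clipping, in the style of csbdeep.

    axis=None computes both percentiles over the whole array.
    """
    x = np.asarray(x, dtype=np.float32)
    p_lo = np.percentile(x, lo, axis=axis, keepdims=True)
    p_hi = np.percentile(x, hi, axis=axis, keepdims=True)
    return np.clip((x - p_lo) / (p_hi - p_lo + eps), 0.0, 1.0)


values = np.arange(101)  # 1st percentile is 1, 99th percentile is 99
out = normalize_percentile(values)
print(out[0], out[50], out[100])
```

Clipping at percentiles rather than at the raw min and max keeps a few saturated or dead pixels from compressing the rest of the intensity range.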

towbintools.deep_learning.utils.augmentation.get_mean_and_std(image_path: str) → tuple[float, float]

Compute the mean and standard deviation of channel 2 in a TIFF image.

Parameters:

image_path (str) – Path to the TIFF image file.

Returns:

(mean, std) of the pixel values in channel 2.

Return type:

tuple[float, float]
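The sketch below pairs get_mean_and_std() with the standardization performed by NormalizeMeanStd. To stay self-contained, it works on an in-memory array instead of reading a TIFF and assumes a channel-first (C, H, W) layout. Both helper names are hypothetical.

```python
import numpy as np


def channel_mean_and_std(image, channel=2):
    """Mean and standard deviation of one channel of a (C, H, W) array."""
    plane = np.asarray(image[channel], dtype=np.float64)
    return float(plane.mean()), float(plane.std())


def normalize_mean_std(x, mean, std):
    """Standardize with a fixed mean and std, as NormalizeMeanStd does."""
    return (np.asarray(x, dtype=np.float64) - mean) / std


rng = np.random.default_rng(0)
image = rng.normal(loc=500.0, scale=40.0, size=(3, 64, 64))
mean, std = channel_mean_and_std(image)
standardized = normalize_mean_std(image[2], mean, std)
print(round(standardized.mean(), 6), round(standardized.std(), 6))
```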

towbintools.deep_learning.utils.augmentation.get_prediction_augmentation(normalization_type: str, **kwargs) → Compose

Build the MONAI transform pipeline for inference (normalization only).

Parameters:
  • normalization_type (str) – Normalization type passed to _build_normalization().

  • **kwargs – Keyword arguments forwarded to the normalization transform (e.g. lo and hi for "percentile"), plus an optional enforce_n_channels (int) that tiles channels to that count.

Returns:

MONAI Compose pipeline ready for prediction.

Return type:

Compose

towbintools.deep_learning.utils.augmentation.get_prediction_augmentation_from_model(model, enforce_n_channels=None) → Compose

Build the inference transform pipeline from a model’s stored normalization config.

Reads the normalization attribute of model (a dict with at least a "type" key) and delegates to get_prediction_augmentation().

Parameters:
  • model – A model instance exposing a normalization dict attribute (e.g. a PretrainedSegmentationModel).

  • enforce_n_channels (int, optional) – If not None, tile channels to this count via EnforceNChannels. (default: None)

Returns:

MONAI Compose pipeline ready for prediction.

Return type:

Compose
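Here is a sketch of the dispatch from a stored normalization dict to a normalization callable. The "type" key selects the transform, and the remaining keys become keyword arguments. The three functions are simplified, hypothetical stand-ins for the MONAI transforms. The dispatch table is an assumption about how _build_normalization() works.

```python
from functools import partial

import numpy as np


# Simplified, hypothetical stand-ins for the three normalization transforms.
def data_range(x):
    return (x - x.min()) / (x.max() - x.min())


def mean_std(x, mean, std):
    return (x - mean) / std


def percentile(x, lo, hi):
    p_lo, p_hi = np.percentile(x, [lo, hi])
    return np.clip((x - p_lo) / (p_hi - p_lo), 0.0, 1.0)


NORMALIZATIONS = {"data_range": data_range, "mean_std": mean_std,
                  "percentile": percentile}


def normalization_from_config(config):
    """Use config["type"] to pick a transform; pass the other keys as kwargs."""
    params = dict(config)  # copy so the stored config is not modified
    kind = params.pop("type")
    if kind not in NORMALIZATIONS:
        raise ValueError(f"unknown normalization type: {kind!r}")
    return partial(NORMALIZATIONS[kind], **params)


# The default config stored by create_classification_model().
config = {"hi": 99, "lo": 1, "type": "percentile"}
normalize = normalization_from_config(config)
out = normalize(np.arange(101.0))
```

Storing the config with the model means inference can rebuild the same normalization that training used, without the caller having to remember it.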

towbintools.deep_learning.utils.augmentation.get_qc_training_augmentation(normalization_type: str, **kwargs) → Compose

Build the MONAI augmentation pipeline for quality-control model training.

Lighter than get_training_augmentation(): includes only random flips and normalization (no rotations or intensity transforms).

Parameters:
  • normalization_type (str) – Normalization type passed to _build_normalization().

  • **kwargs – Keyword arguments forwarded to the normalization transform (e.g. lo and hi for "percentile"), plus an optional enforce_n_channels (int) that tiles channels to that count.

Returns:

MONAI Compose pipeline ready for QC training.

Return type:

Compose

towbintools.deep_learning.utils.augmentation.get_training_augmentation(normalization_type: str, **kwargs) → Compose

Build the MONAI augmentation pipeline for segmentation training.

Includes random flips, random 90° rotations, random Gaussian smoothing, random Gaussian sharpening, and normalization.

Parameters:
  • normalization_type (str) – Normalization type passed to _build_normalization() ("data_range", "mean_std", "percentile").

  • **kwargs – Keyword arguments forwarded to the normalization transform (e.g. lo and hi for "percentile"), plus an optional enforce_n_channels (int) that tiles channels to that count.

Returns:

MONAI Compose pipeline ready for training.

Return type:

Compose

loss

class towbintools.deep_learning.utils.loss.BCELossWithIgnore(ignore_index=-1)

Bases: Module

Binary cross-entropy loss that ignores a specified target value.

Computes element-wise BCE loss, zeroes out entries where target == ignore_index, and returns the mean over non-ignored elements.
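Below is a NumPy sketch of the masked mean. It assumes the input already holds probabilities; the documentation does not state whether the module expects logits or probabilities. The helper name and eps are hypothetical.

```python
import numpy as np


def bce_with_ignore(probs, target, ignore_index=-1, eps=1e-7):
    """Mean binary cross-entropy over elements whose target != ignore_index.

    Assumes probs already holds probabilities in (0, 1).
    """
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1 - eps)
    target = np.asarray(target, dtype=np.float64)
    keep = target != ignore_index
    t = np.where(keep, target, 0.0)  # neutral value for ignored entries
    loss = -(t * np.log(probs) + (1 - t) * np.log(1 - probs))
    loss = np.where(keep, loss, 0.0)  # zero out ignored entries
    return loss.sum() / max(int(keep.sum()), 1)


loss = bce_with_ignore([0.9, 0.2, 0.7], [1, 0, -1])  # third element ignored
```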

forward(input, target)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling forward() directly: calling the instance runs the registered hooks, whereas calling forward() silently ignores them.

class towbintools.deep_learning.utils.loss.FocalTverskyLoss(ignore_index=-1, smooth=100, alpha=0.3, beta=0.7, gamma=1.3333333333333333, activation=True)[source]

Bases: Module

Focal Tversky loss for binary segmentation with class-imbalance handling.

Combines the Tversky index (a generalization of Dice that independently weights false positives and false negatives) with a focal exponent to down-weight easy examples and focus training on hard ones.

Reference: Abraham & Khan (2019), arXiv:1810.07842.
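The Tversky index is TI = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth), and the focal variant penalizes (1 - TI) raised to a power. The NumPy sketch below evaluates this on probabilities, i.e. after the sigmoid that activation=True presumably applies. Formulations differ on which of alpha/beta weights FP versus FN, and on whether the exponent is gamma or 1/gamma, so both choices here are assumptions. Under this convention, the defaults (beta=0.7 on FN) make missed foreground cost more than spurious foreground.

```python
import numpy as np


def focal_tversky_loss(probs, targets, alpha=0.3, beta=0.7, gamma=4 / 3,
                       smooth=100.0, ignore_index=-1):
    """Focal Tversky loss for binary segmentation on probabilities in [0, 1]."""
    probs = np.asarray(probs, dtype=np.float64).ravel()
    targets = np.asarray(targets, dtype=np.float64).ravel()
    valid = targets != ignore_index
    p, g = probs[valid], targets[valid]
    tp = np.sum(p * g)
    fp = np.sum(p * (1.0 - g))  # predicted foreground, actually background
    fn = np.sum((1.0 - p) * g)  # predicted background, actually foreground
    # Assumption: alpha weights false positives, beta weights false negatives.
    tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    # Assumption: the focal exponent is applied as (1 - TI) ** gamma.
    return float((1.0 - tversky) ** gamma)
```

The large default smooth=100 keeps the ratio close to 1 on nearly empty masks, which damps the loss for images with little or no foreground.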

forward(inputs, targets)[source]

class towbintools.deep_learning.utils.loss.MultiClassFocalLoss(alpha: Tensor | None = None, gamma: float = 2.0, reduction: str = 'mean', ignore_index: int = -1)[source]

Bases: Module

Focal loss for multi-class classification.

Extends cross-entropy with a modulating factor (1 - p_t)^gamma that down-weights well-classified examples, focusing training on hard, misclassified examples. Supports per-class weights (alpha) and an ignore index.

Reference: Lin et al. (2017), arXiv:1708.02002.
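For each sample, the focal loss is -alpha_y * (1 - p_t)^gamma * log(p_t), where p_t is the softmax probability of the true class y. The NumPy sketch below averages this over non-ignored samples. Using a plain mean (rather than an alpha-weighted mean) and returning 0 for an all-ignored batch are assumptions. With gamma=0 and no alpha, it reduces to ordinary cross-entropy.

```python
import numpy as np


def multiclass_focal_loss(logits, y, alpha=None, gamma=2.0, ignore_index=-1):
    """Mean focal loss over samples; logits has shape (N, C), y has shape (N,)."""
    logits = np.asarray(logits, dtype=np.float64)
    y = np.asarray(y)
    valid = y != ignore_index
    logits, y = logits[valid], y[valid]
    if y.size == 0:
        return 0.0  # assumption: an all-ignored batch contributes zero loss
    # Numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(y)), y]  # log-probability of the true class
    pt = np.exp(log_pt)
    weight = 1.0 if alpha is None else np.asarray(alpha, dtype=np.float64)[y]
    # (1 - p_t) ** gamma shrinks the loss of confidently correct samples.
    loss = -weight * (1.0 - pt) ** gamma * log_pt
    return float(loss.mean())
```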

forward(x: Tensor, y: Tensor) Tensor[source]

class towbintools.deep_learning.utils.loss.PeakWeightedMSELoss(ignore_index=-1, peak_weight=3.0)[source]

Bases: Module

MSE loss with additional weight on high-value (peak) target positions.

Assigns a per-element weight of 1 + peak_weight * target so that positions with larger target values (peaks in a heatmap) contribute more to the loss. Designed for 1D keypoint heatmap regression.
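The weighting takes only a few lines of NumPy. This sketch assumes the weighted squared errors are simply averaged over non-ignored elements. The source does not state the actual reduction (for instance, it might normalize by the sum of weights).

```python
import numpy as np


def peak_weighted_mse(pred, target, peak_weight=3.0, ignore_index=-1):
    """Weighted MSE where each element's weight is 1 + peak_weight * target."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    valid = target != ignore_index
    weights = 1.0 + peak_weight * target[valid]
    sq_err = (pred[valid] - target[valid]) ** 2
    # Assumption: plain mean of the weighted squared errors over valid elements.
    return float(np.mean(weights * sq_err))
```

With the default peak_weight=3.0, an error at a target value of 1.0 counts four times as much as the same error on a zero background.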

forward(input, target)[source]

models and architectures

class towbintools.deep_learning.architectures.models.ClassificationModel(architecture, input_channels, classes, learning_rate, normalization)[source]

Bases: LightningModule

PyTorch Lightning module for image classification using a pretrained backbone.

Uses timm to load the specified architecture with ImageNet-pretrained weights. Applies BCEWithLogitsLoss + BinaryF1Score for binary tasks, or CrossEntropyLoss + MulticlassF1Score for multiclass tasks.

Parameters:
  • architecture (str) – timm model name (e.g. "efficientnet_b0").

  • input_channels (int) – Number of input image channels.

  • classes (list[str]) – Class labels; len(classes) determines binary vs multiclass.

  • learning_rate (float) – Learning rate for the Adam optimizer.

  • normalization (dict) – Normalization config stored as a hyperparameter and used at inference time to reconstruct the preprocessing pipeline.
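The default config of create_classification_model() is {'hi': 99, 'lo': 1, 'type': 'percentile'}. The NumPy sketch below shows how a hypothetical build_normalizer() might turn such a dict back into a callable. Clipping to [0, 1], the epsilon guard, and covering only the "percentile" and "mean_std" types are assumptions, not the library's code.

```python
import numpy as np


def build_normalizer(config: dict):
    """Return a function that normalizes an image according to `config` (hypothetical)."""
    kind = config["type"]
    if kind == "percentile":
        lo_q, hi_q = config.get("lo", 1), config.get("hi", 99)

        def normalize(img):
            lo, hi = np.percentile(img, [lo_q, hi_q])
            scaled = (img - lo) / max(hi - lo, 1e-8)
            return np.clip(scaled, 0.0, 1.0)  # assumption: out-of-range values are clipped

        return normalize
    if kind == "mean_std":

        def normalize(img):
            return (img - img.mean()) / max(float(img.std()), 1e-8)

        return normalize
    raise ValueError(f"unsupported normalization type: {kind!r}")


normalize = build_normalizer({"hi": 99, "lo": 1, "type": "percentile"})
image = np.random.default_rng(0).normal(100.0, 20.0, size=(64, 64))
out = normalize(image)
```

Storing the dict, not a fitted transform object, keeps the checkpoint's hyperparameters plain data that can be serialized and inspected.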

configure_optimizers()[source]

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of these 5 options.

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the metric specified by 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if it is not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

Metrics can be made available to monitor by simply logging them with self.log('metric_to_track', metric_val) in your LightningModule.

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

forward(x)[source]

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

log_tb_images(viz_batch) None[source]

training_step(batch)[source]

Here you compute and return the training loss and, optionally, additional metrics for the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

validation_step(batch)[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...


# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

import torch
import torchvision


# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})

If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

Note

If you don’t need to validate, you don’t need to implement this method.

Note

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.

class towbintools.deep_learning.architectures.models.KeypointDetection1DModel(input_channels, n_classes, learning_rate, architecture='UnetPlusPlus', activation='sigmoid', criterion=None)[source]

Bases: LightningModule

PyTorch Lightning module for 1D keypoint detection using a U-Net architecture.

Operates on 1D sequences (e.g. straightened worm fluorescence profiles). Supports "Unet", "AttentionUnet", and "UnetPlusPlus" 1D architectures. Uses PeakWeightedMSELoss by default.

Parameters:
  • input_channels (int) – Number of input sequence channels.

  • n_classes (int) – Number of keypoint classes (output channels).

  • learning_rate (float) – Learning rate for the Adam optimizer.

  • architecture (str, optional) – Architecture name; one of "Unet", "AttentionUnet", or "UnetPlusPlus". (default: "UnetPlusPlus")

  • activation (str, optional) – Output activation; one of "relu", "leaky_relu", "sigmoid", or "none". (default: "sigmoid")

  • criterion (nn.Module, optional) – Loss function. If None, PeakWeightedMSELoss is used. (default: None)

Raises:

ValueError – If architecture or activation is not supported.
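The model regresses one 1D heatmap per keypoint class, which is the setting PeakWeightedMSELoss is designed for. A common way to build such targets and read positions back out is to place a Gaussian bump at each keypoint and take a per-channel argmax. The NumPy sketch below shows that round trip. The Gaussian encoding, sigma, and both helper names are assumptions, not towbintools functions.

```python
import numpy as np


def encode_keypoints(positions, length, sigma=3.0):
    """Build a (n_classes, length) target with a unit-height Gaussian at each keypoint."""
    x = np.arange(length, dtype=np.float64)
    return np.stack([np.exp(-0.5 * ((x - p) / sigma) ** 2) for p in positions])


def decode_keypoints(heatmaps):
    """Return the index of the maximum of each channel of a (n_classes, length) heatmap."""
    return np.argmax(heatmaps, axis=-1).tolist()


target = encode_keypoints([12, 40, 75], length=100)
print(decode_keypoints(target))  # [12, 40, 75]
```

An activation that keeps outputs non-negative (such as "relu" or "sigmoid") matches targets of this form, whose values lie in [0, 1].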

configure_optimizers()[source]

forward(x)[source]

predict_step(batch)[source]

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

The predict_step() is used to scale inference across multiple devices.

To prevent an OOM error, it is possible to use BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.

Use the BasePredictionWriter with spawn-based strategies, for example Trainer(strategy="ddp_spawn") or running on 8 TPU cores with Trainer(accelerator="tpu", devices=8), because predictions are not returned in those cases.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

Predicted output (optional).

Example

from lightning.pytorch import LightningModule, Trainer


class MyModel(LightningModule):

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)

training_step(batch)[source]

validation_step(batch)[source]

class towbintools.deep_learning.architectures.models.SegmentationModel(input_channels, n_classes, learning_rate, architecture, encoder, pretrained_weights, normalization, criterion=None, ignore_index=None)[source]

Bases: LightningModule

PyTorch Lightning module for image segmentation using a pretrained encoder.

Uses segmentation_models_pytorch to build an encoder–decoder model. For binary tasks (n_classes == 1): sigmoid activation + FocalTverskyLoss + BinaryF1Score. For multiclass tasks: softmax activation + MultiClassFocalLoss + MulticlassF1Score.

Parameters:
  • input_channels (int) – Number of input image channels.

  • n_classes (int) – Number of foreground segmentation classes.

  • learning_rate (float) – Learning rate for the Adam optimizer.

  • architecture (str) – smp architecture name (e.g. "Unet").

  • encoder (str) – Encoder backbone name (e.g. "resnet34").

  • pretrained_weights (str) – Dataset the encoder was pretrained on (e.g. "imagenet").

  • normalization (dict) – Normalization config stored as a hyperparameter and used at inference time to reconstruct the preprocessing pipeline.

  • criterion (nn.Module, optional) – Loss function. If None, FocalTverskyLoss is used for binary tasks and MultiClassFocalLoss for multiclass. (default: None)

  • ignore_index (int, optional) – Target value to ignore in the loss and F1 score. (default: None)
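Binary models use a sigmoid and multiclass models a softmax, so converting raw network outputs into a label mask at inference could look like the NumPy sketch below. The 0.5 threshold, the (channels, H, W) output layout, and the helper name are assumptions.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Convert (channels, H, W) logits into an integer (H, W) label mask."""
    if logits.shape[0] == 1:
        # Binary case: sigmoid, then threshold the single foreground channel.
        probs = 1.0 / (1.0 + np.exp(-logits[0]))
        return (probs > threshold).astype(np.int64)
    # Multiclass case: softmax is monotonic, so argmax over raw logits gives the same labels.
    return np.argmax(logits, axis=0).astype(np.int64)
```

Skipping the softmax in the multiclass branch is safe because only the arg-max label is needed. Probabilities are only worth computing when they are consumed downstream.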

configure_optimizers()[source]

forward(x)[source]

predict_step(batch)[source]

training_step(batch)[source]

validation_step(batch)[source]

class towbintools.deep_learning.architectures.archs.AttentionBlock1D(F_g, F_l, F_int)[source]

Bases: Module

1D attention gate for use in AttentionUnet1D.

Computes a soft attention map from a gating signal g (from the decoder) and a skip-connection feature map x (from the encoder). The output is x weighted element-wise by the attention coefficients.

Parameters:
  • F_g (int) – Number of channels in the gating signal g.

  • F_l (int) – Number of channels in the skip-connection feature map x.

  • F_int (int) – Number of intermediate channels used to compute the attention map.
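Gates of this kind are usually additive. Pointwise convolutions project g and x to F_int channels, their sum passes through a ReLU, and a second pointwise convolution followed by a sigmoid yields one coefficient per position, by which x is multiplied. The NumPy sketch below implements that scheme, with kernel-size-1 convolutions written as matrix products. It assumes g has already been resized to x's length and omits any normalization layers the real block may contain. The parameter dictionary layout is hypothetical.

```python
import numpy as np


def conv1x1(weight, bias, x):
    """Pointwise (kernel size 1) Conv1d: weight (out_ch, in_ch), x (in_ch, length)."""
    return weight @ x + bias[:, None]


def attention_gate_1d(g, x, params):
    """Return x scaled by attention coefficients computed from g and x."""
    g1 = conv1x1(params["W_g"], params["b_g"], g)  # (F_int, length)
    x1 = conv1x1(params["W_x"], params["b_x"], x)  # (F_int, length)
    q = np.maximum(g1 + x1, 0.0)  # ReLU
    psi = 1.0 / (1.0 + np.exp(-conv1x1(params["W_psi"], params["b_psi"], q)))  # (1, length)
    return x * psi  # broadcast over the F_l channels of x


F_g, F_l, F_int, length = 8, 4, 2, 16
rng = np.random.default_rng(0)
params = {
    "W_g": rng.normal(size=(F_int, F_g)), "b_g": np.zeros(F_int),
    "W_x": rng.normal(size=(F_int, F_l)), "b_x": np.zeros(F_int),
    "W_psi": rng.normal(size=(1, F_int)), "b_psi": np.zeros(1),
}
g = rng.normal(size=(F_g, length))
x = rng.normal(size=(F_l, length))
out = attention_gate_1d(g, x, params)
```

Because the coefficient is shared across all F_l channels at a given position, the gate decides *where* along the sequence the skip features pass through, not *which* channels.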

forward(g, x)[source]

class towbintools.deep_learning.architectures.archs.AttentionUnet1D(num_classes, input_channels=1, **kwargs)[source]

Bases: Module

1D U-Net with attention gates for sequence segmentation and keypoint detection.

Extends Unet1D by inserting an AttentionBlock1D at each decoder stage. The attention gates suppress irrelevant activations in the encoder skip connections before concatenation. Filter counts follow [64, 128, 256, 512, 1024].

Parameters:
  • num_classes (int) – Number of output classes (output channels).

  • input_channels (int, optional) – Number of input sequence channels. (default: 1)

  • **kwargs – Ignored; accepted for API compatibility.

forward(input)[source]

class towbintools.deep_learning.architectures.archs.Unet1D(num_classes, input_channels=1, **kwargs)[source]

Bases: Module

1D U-Net for sequence segmentation and keypoint detection.

The 1D analogue of Unet, operating on 1D sequences with MaxPool1d downsampling and linear upsampling. Filter counts follow [64, 128, 256, 512, 1024].

Parameters:
  • num_classes (int) – Number of output classes (output channels).

  • input_channels (int, optional) – Number of input sequence channels. (default: 1)

  • **kwargs – Ignored; accepted for API compatibility.
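The documented filter counts imply five encoder levels separated by four MaxPool1d downsamplings. The sketch below traces channel counts and sequence lengths through such an encoder. It assumes one VGGBlock1D per level and MaxPool1d(2) between levels (which floors odd lengths); the helper name trace_unet1d_encoder is hypothetical and not part of towbintools.

```python
FILTERS = [64, 128, 256, 512, 1024]  # documented filter counts per level


def trace_unet1d_encoder(
    length: int, input_channels: int = 1
) -> list[tuple[int, int, int]]:
    """Return (in_channels, out_channels, sequence_length) for each level.

    Assumes a VGGBlock1D at each level and MaxPool1d(2) between levels.
    """
    levels = []
    in_ch = input_channels
    for level, out_ch in enumerate(FILTERS):
        if level > 0:
            length //= 2  # MaxPool1d(kernel_size=2) floors odd lengths
        levels.append((in_ch, out_ch, length))
        in_ch = out_ch
    return levels


for in_ch, out_ch, n in trace_unet1d_encoder(160, input_channels=2):
    print(f"VGGBlock1D: {in_ch} -> {out_ch} channels at length {n}")
```

A length of 160 halves cleanly down to 10. A length of 100 is floored to 50, 25, 12, 6, and doubling 12 on the way back up gives 24 rather than 25, so that skip connection would not line up without extra padding or cropping. Padding inputs to a multiple of 2**4 = 16 sidesteps this under the stated assumptions.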

forward(input)[source]

Run the network on input and return an output with num_classes channels.

Note

Although the forward pass is defined in this method, call the module instance itself (model(input)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

class towbintools.deep_learning.architectures.archs.UnetPlusPlus1D(num_classes, input_channels=1, deep_supervision=False, **kwargs)[source]

Bases: Module

1D UNet++ for sequence segmentation with optional deep supervision.

The 1D analogue of UnetPlusPlus, with dense nested skip connections between all encoder and decoder nodes at the same resolution. Uses Conv1d, MaxPool1d, and linear upsampling. Filter counts follow [64, 128, 256, 512, 1024]. When deep_supervision=True, returns a list of four outputs from intermediate decoder nodes; otherwise returns a single output.

Parameters:
  • num_classes (int) – Number of output classes (output channels).

  • input_channels (int, optional) – Number of input sequence channels. (default: 1)

  • deep_supervision (bool, optional) – If True, return outputs from all intermediate decoder stages. (default: False)

  • **kwargs – Ignored; accepted for API compatibility.
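Because forward() returns a list when deep_supervision=True and a single output otherwise, training code has to handle both cases. One common choice, assumed here rather than taken from towbintools, is to average the loss over all supervised outputs. The sketch is framework-agnostic: supervised_loss and toy_mse are hypothetical, and the toy predictions are tuples so that a single prediction is never mistaken for the list of deep-supervision outputs.

```python
from collections.abc import Callable
from typing import Any


def supervised_loss(
    outputs: Any, target: Any, criterion: Callable[[Any, Any], float]
) -> float:
    """Average criterion over deep-supervision outputs, or apply it once.

    `outputs` is either a single prediction or a list of predictions,
    mirroring the two return modes of UnetPlusPlus1D.forward().
    """
    preds = outputs if isinstance(outputs, list) else [outputs]
    return sum(criterion(pred, target) for pred in preds) / len(preds)


def toy_mse(pred: tuple[float, ...], target: tuple[float, ...]) -> float:
    """Toy mean-squared error on tuples, standing in for a real loss."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


target = (0.0, 1.0)
single = supervised_loss((0.0, 1.0), target, toy_mse)
deep = supervised_loss(
    [(0.0, 1.0), (1.0, 1.0), (0.0, 0.0), (1.0, 0.0)], target, toy_mse
)
print(single, deep)
```

Averaging keeps the loss on the same scale whether or not deep supervision is enabled; weighting the deepest output more heavily is an equally valid alternative.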

forward(input)[source]

Run the network on input. Returns a list of four outputs from intermediate decoder nodes when deep_supervision=True, otherwise a single output.

Note

Although the forward pass is defined in this method, call the module instance itself (model(input)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

class towbintools.deep_learning.architectures.archs.VGGBlock1D(in_channels, middle_channels, out_channels)[source]

Bases: Module

One-dimensional two-layer convolutional block with BatchNorm and ReLU.

The 1D analogue of VGGBlock, using Conv1d. Used as the basic building block in Unet1D, AttentionUnet1D, and UnetPlusPlus1D.

Parameters:
  • in_channels (int) – Number of input channels.

  • middle_channels (int) – Number of channels after the first convolution.

  • out_channels (int) – Number of output channels.

forward(x)[source]

Apply the two convolutional layers, with BatchNorm and ReLU, to x and return a tensor with out_channels channels.

Note

Although the forward pass is defined in this method, call the module instance itself (block(x)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

util

towbintools.deep_learning.utils.util.adjust_tensor_dimensions(source_tensor: Tensor, target_tensor_shape: tuple) Tensor[source]

Squeeze source_tensor then unsqueeze it to match target_tensor_shape.

Parameters:
  • source_tensor (Tensor) – Tensor to reshape.

  • target_tensor_shape (tuple[int, ...]) – Desired shape of the output tensor.

Returns:

Reshaped tensor compatible with target_tensor_shape.

Return type:

Tensor
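The axes at which the function unsqueezes are not documented. The shape-only sketch below assumes that all size-1 axes are squeezed out and that trailing singleton axes (as with unsqueeze(-1)) are appended until the rank matches target_tensor_shape. adjusted_shape is a hypothetical helper working on plain tuples; the real function operates on torch Tensors and may choose different axes.

```python
def adjusted_shape(
    source_shape: tuple[int, ...], target_shape: tuple[int, ...]
) -> tuple[int, ...]:
    """Shape after squeezing size-1 axes, then appending size-1 axes.

    Assumption: unsqueezing happens at the last axis until ranks match.
    """
    shape = tuple(d for d in source_shape if d != 1)  # like Tensor.squeeze()
    while len(shape) < len(target_shape):
        shape = shape + (1,)  # like Tensor.unsqueeze(-1)
    return shape


# A 1x1-conv classifier weight mapped onto a linear layer, and back.
print(adjusted_shape((10, 512, 1, 1), (10, 512)))
print(adjusted_shape((10, 512), (10, 512, 1, 1)))
```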

towbintools.deep_learning.utils.util.create_lightweight_checkpoint(input_path: str, output_path: str) dict[source]

Load a PyTorch Lightning checkpoint and save a lightweight version.

Keeps only state_dict, hyper_parameters, and a small set of optional metadata keys (epoch, global_step, pytorch-lightning_version), discarding optimizer state and other large tensors.

Parameters:
  • input_path (str) – Path to the full .ckpt checkpoint file.

  • output_path (str) – Destination path for the lightweight checkpoint.

Returns:

The lightweight checkpoint dictionary that was saved to output_path.

Return type:

dict
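The key-filtering step can be sketched on a plain dictionary. This assumes the checkpoint has already been loaded; the real function reads input_path and writes output_path itself (presumably via torch.load and torch.save). lightweight_subset and the toy checkpoint are hypothetical, and raising KeyError when state_dict or hyper_parameters is missing is a choice of this sketch, not documented behavior.

```python
REQUIRED_KEYS = ("state_dict", "hyper_parameters")
OPTIONAL_KEYS = ("epoch", "global_step", "pytorch-lightning_version")


def lightweight_subset(checkpoint: dict) -> dict:
    """Keep only the keys needed to rebuild the model for inference."""
    # Required keys: a KeyError here means the checkpoint is unusable.
    light = {key: checkpoint[key] for key in REQUIRED_KEYS}
    # Optional metadata is copied only when present.
    light.update(
        {key: checkpoint[key] for key in OPTIONAL_KEYS if key in checkpoint}
    )
    return light


full = {
    "state_dict": {"model.conv.weight": [0.1, 0.2]},
    "hyper_parameters": {"learning_rate": 1e-4},
    "optimizer_states": [{"exp_avg": [0.0, 0.0]}],  # dropped
    "lr_schedulers": [],  # dropped
    "epoch": 12,
}
print(sorted(lightweight_subset(full)))
```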

towbintools.deep_learning.utils.util.divide_batch(batch: Tensor, n: int) Iterator[Tensor][source]

Yield successive mini-batches of size n from batch.

Parameters:
  • batch (Tensor) – Input batch tensor with the batch dimension as axis 0.

  • n (int) – Mini-batch size.

Yields:

Tensor – Slice of batch along axis 0 containing at most n samples.
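A sketch of the slicing pattern, for example to run inference on a large batch in chunks that fit in memory. It works on any sequence that supports len() and slicing along the first axis, which includes torch Tensors. divide_batch_sketch is a hypothetical stand-in, and its ValueError guard is an addition that the documented function may not have.

```python
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def divide_batch_sketch(batch: Sequence[T], n: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most n items along the first axis."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    for start in range(0, len(batch), n):
        # The final slice is shorter when len(batch) is not divisible by n.
        yield batch[start : start + n]


for chunk in divide_batch_sketch(list(range(7)), 3):
    print(chunk)
```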

towbintools.deep_learning.utils.util.get_closest_lower_multiple(dim: int | float, multiple: int) int[source]

Round dim down to the nearest multiple of multiple.

Parameters:
  • dim (int or float) – Value to round down.

  • multiple (int) – The multiple to round to.

Returns:

Largest multiple of multiple that is <= dim.

Return type:

int
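One way to compute this is with floor division, which handles both int and float inputs. closest_lower_multiple is a hypothetical stand-in for the documented function and assumes multiple is a positive integer.

```python
def closest_lower_multiple(dim: float, multiple: int) -> int:
    """Largest multiple of `multiple` that is <= dim (multiple > 0)."""
    # Floor division rounds towards negative infinity, also for floats.
    return int(dim // multiple) * multiple


print(closest_lower_multiple(100, 16))
print(closest_lower_multiple(100.7, 16))
```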

towbintools.deep_learning.utils.util.get_closest_upper_multiple(dim: int | float, multiple: int) int[source]

Round dim up to the nearest multiple of multiple.

Parameters:
  • dim (int or float) – Value to round up.

  • multiple (int) – The multiple to round to.

Returns:

Smallest multiple of multiple that is >= dim.

Return type:

int
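A sketch using negated floor division, which rounds up without the precision loss of math.ceil(dim / multiple) on very large integers. closest_upper_multiple is a hypothetical stand-in and assumes multiple is a positive integer. The usage line computes how much padding would bring a length of 100 up to a multiple of 16, as one might do before a four-level downsampling network; that use is illustrative, not taken from towbintools.

```python
def closest_upper_multiple(dim: float, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= dim (multiple > 0)."""
    # -(-a // b) is the ceiling of a / b.
    return -int(-dim // multiple) * multiple


length = 100
padding = closest_upper_multiple(length, 16) - length
print(closest_upper_multiple(length, 16), padding)
```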

towbintools.deep_learning.utils.util.get_input_channels_from_checkpoint(checkpoint_path: str) int[source]

Infer the number of input channels from a PyTorch Lightning checkpoint.

Searches the state dict for the first convolutional weight tensor and returns its in_channels dimension (weight.shape[1]).

Parameters:
  • checkpoint_path (str) – Path to the .ckpt checkpoint file.

Returns:

Number of input channels, or 0 if no convolutional layer was found.

Return type:

int
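The search can be sketched over an already-loaded state dict, for example torch.load(checkpoint_path)["state_dict"]. The exact test for a "convolutional weight" is not documented; this sketch assumes any key ending in .weight whose tensor has at least three dimensions (Conv1d weights are (out, in, k), Conv2d weights (out, in, kh, kw)). input_channels_from_state_dict is hypothetical, and SimpleNamespace objects stand in for tensors because only .shape is inspected.

```python
from types import SimpleNamespace  # only used for the toy state dict


def input_channels_from_state_dict(state_dict: dict) -> int:
    """Return weight.shape[1] of the first conv-like weight, or 0 if none."""
    for key, tensor in state_dict.items():
        shape = tuple(getattr(tensor, "shape", ()))
        # BatchNorm and Linear weights have fewer than three dimensions.
        if key.endswith(".weight") and len(shape) >= 3:
            return shape[1]
    return 0


toy_state = {
    "model.encoder.bn.weight": SimpleNamespace(shape=(64,)),
    "model.encoder.conv.weight": SimpleNamespace(shape=(64, 3, 3, 3)),
    "model.head.weight": SimpleNamespace(shape=(10, 64)),
}
print(input_channels_from_state_dict(toy_state))
```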

towbintools.deep_learning.utils.util.rename_keys_and_adjust_dimensions(model: Module, pretrained_model: dict) dict[source]

Map pretrained weights into a model with differently named or shaped parameters.

Pairs keys from model.state_dict() with keys from pretrained_model by position, adjusting tensor shapes via adjust_tensor_dimensions() when they differ.

Parameters:
  • model (nn.Module) – Target model whose state dict keys define the mapping.

  • pretrained_model (dict) – Source state dict (e.g. from torch.load(...)["state_dict"]).

Returns:

New state dict compatible with model.load_state_dict().

Return type:

dict

Raises:

AssertionError – If model and pretrained_model have different numbers of keys.
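A sketch of positional key matching on plain dictionaries. Here model_state stands in for model.state_dict(), values only need a .shape attribute, and the adjust callback stands in for adjust_tensor_dimensions(). remap_by_position, reshape_stub, and the toy tensors are hypothetical.

```python
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any


def remap_by_position(
    model_state: dict,
    pretrained_state: dict,
    adjust: Callable[[Any, tuple], Any],
) -> dict:
    """Pair keys by position, renaming and reshaping pretrained tensors."""
    assert len(model_state) == len(pretrained_state), "key counts differ"
    new_state = {}
    for (model_key, model_tensor), pretrained_tensor in zip(
        model_state.items(), pretrained_state.values()
    ):
        target_shape = tuple(model_tensor.shape)
        if tuple(pretrained_tensor.shape) != target_shape:
            pretrained_tensor = adjust(pretrained_tensor, target_shape)
        # The model's key name wins; the pretrained name is discarded.
        new_state[model_key] = pretrained_tensor
    return new_state


def reshape_stub(tensor: Any, target_shape: tuple) -> Any:
    """Toy stand-in for adjust_tensor_dimensions: adopts the target shape."""
    return SimpleNamespace(shape=target_shape)


model_state = {
    "encoder.weight": SimpleNamespace(shape=(8,)),
    "head.weight": SimpleNamespace(shape=(10, 512)),
}
pretrained = {
    "backbone.w": SimpleNamespace(shape=(8,)),
    "fc.w": SimpleNamespace(shape=(10, 512, 1, 1)),
}
new_state = remap_by_position(model_state, pretrained, reshape_stub)
print(list(new_state), new_state["head.weight"].shape)
```

Pairing by position only gives a sensible mapping when both state dicts list corresponding parameters in the same order; the key-count assertion catches gross mismatches but not reordered layers.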