deep learning package
deep_learning_tools
- towbintools.deep_learning.deep_learning_tools.create_classification_model(architecture: str, input_channels: int, classes: list[str], learning_rate: float = 0.0001, checkpoint_path: str | None = None, normalization: dict = {'hi': 99, 'lo': 1, 'type': 'percentile'}) → ClassificationModel
Create a classification model.
- Parameters:
architecture (str) – timm model name (e.g. "efficientnet_b0").
input_channels (int) – Number of input image channels.
classes (list[str]) – List of class label strings.
learning_rate (float, optional) – Learning rate for the Adam optimizer. (default: 1e-4)
checkpoint_path (str, optional) – Path to a .ckpt checkpoint; if provided, the model is loaded from the checkpoint and all other arguments are ignored. (default: None)
normalization (dict, optional) – Normalization config stored as a hyperparameter for inference. (default: percentile 1–99)
- Returns:
Constructed or loaded classification model.
- Return type:
ClassificationModel
- towbintools.deep_learning.deep_learning_tools.create_keypoint_detection_model(architecture: str, input_channels: int, n_classes: int, learning_rate: float = 0.0001, checkpoint_path: str | None = None, criterion: Any | None = None, activation: str = 'relu') → KeypointDetection1DModel
Create a 1D keypoint detection model.
- Parameters:
architecture (str) – Architecture name; one of "Unet", "AttentionUnet", or "UnetPlusPlus".
input_channels (int) – Number of input sequence channels.
n_classes (int) – Number of keypoint classes (output channels).
learning_rate (float, optional) – Learning rate for the Adam optimizer. (default: 1e-4)
checkpoint_path (str, optional) – Path to a .ckpt checkpoint; if provided, the model is loaded from the checkpoint and all other arguments are ignored. (default: None)
criterion (nn.Module, optional) – Loss function. If None, PeakWeightedMSELoss is used. (default: None)
activation (str, optional) – Output activation; one of "relu", "leaky_relu", "sigmoid", or "none". (default: "relu")
- Returns:
Constructed or loaded keypoint detection model.
- Return type:
KeypointDetection1DModel
- towbintools.deep_learning.deep_learning_tools.create_segmentation_model(input_channels: int = 1, n_classes: int = 1, architecture: str = 'UnetPlusPlus', encoder: str = 'efficientnet-b4', pretrained_weights: str = 'image-micronet', normalization: dict = {'hi': 99, 'lo': 1, 'type': 'percentile'}, learning_rate: float = 1e-05, checkpoint_path: str | None = None, reset_optimizer: bool = True, criterion: Any | None = None) → SegmentationModel
Create a segmentation model with a pretrained encoder.
- Parameters:
input_channels (int, optional) – Number of input image channels. (default: 1)
n_classes (int, optional) – Number of foreground segmentation classes. (default: 1)
architecture (str, optional) – smp architecture name. (default: "UnetPlusPlus")
encoder (str, optional) – Encoder backbone name. (default: "efficientnet-b4")
pretrained_weights (str, optional) – Dataset the encoder was pretrained on. (default: "image-micronet")
normalization (dict, optional) – Normalization config. (default: percentile 1–99)
learning_rate (float, optional) – Learning rate for the Adam optimizer. (default: 1e-5)
checkpoint_path (str, optional) – Path to a .ckpt checkpoint. If provided, the model is loaded from the checkpoint, then learning_rate and normalization are updated. (default: None)
reset_optimizer (bool, optional) – If True, discard the optimizer state from the checkpoint. (default: True)
criterion (nn.Module, optional) – Loss function. If None, a default is chosen based on n_classes. (default: None)
- Returns:
Constructed or loaded segmentation model.
- Return type:
SegmentationModel
- Raises:
ValueError – If the checkpoint architecture or encoder does not match the requested one.
- towbintools.deep_learning.deep_learning_tools.load_keypoint_detection_model_from_checkpoint(checkpoint_path: str) → KeypointDetection1DModel
Load a 1D keypoint detection model from a checkpoint.
- Parameters:
checkpoint_path (str) – Path to a .ckpt checkpoint file.
- Returns:
Loaded keypoint detection model.
- Return type:
KeypointDetection1DModel
- Raises:
ValueError – If the model cannot be loaded from the checkpoint.
- towbintools.deep_learning.deep_learning_tools.load_segmentation_model_from_checkpoint(checkpoint_path: str) → SegmentationModel
Load a segmentation model from a checkpoint.
First tries a direct load; if that fails (e.g. mismatched input_channels in the checkpoint metadata), infers the channel count from the checkpoint weights and retries with pretrained_weights=None.
- Parameters:
checkpoint_path (str) – Path to a .ckpt checkpoint file.
- Returns:
Loaded segmentation model.
- Return type:
SegmentationModel
- Raises:
ValueError – If both loading attempts fail.
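The channel-inference fallback can be illustrated with a short sketch. It assumes that the first 4D weight in the checkpoint's state_dict belongs to the encoder's input convolution, whose PyTorch layout is (out_channels, in_channels, kH, kW); the helper infer_input_channels is hypothetical and NumPy arrays stand in for tensors. With a real Lightning checkpoint the dictionary would come from torch.load(checkpoint_path, map_location="cpu")["state_dict"].

```python
import numpy as np


def infer_input_channels(state_dict):
    """Return in_channels of the first 4D convolution weight (hypothetical helper).

    Assumes the first 4D "*.weight" entry in insertion order is the encoder's
    input convolution, laid out as (out_channels, in_channels, kH, kW).
    """
    for name, weight in state_dict.items():
        if name.endswith("weight") and np.ndim(weight) == 4:
            return int(np.shape(weight)[1])
    raise ValueError("No 4D convolution weight found in state_dict")


# Stand-in for checkpoint["state_dict"] from a Lightning checkpoint.
fake_state_dict = {
    "model.encoder.conv_stem.weight": np.zeros((48, 3, 3, 3)),
    "model.encoder.bn1.weight": np.zeros(48),
    "model.decoder.conv.weight": np.zeros((16, 48, 3, 3)),
}
print(infer_input_channels(fake_state_dict))  # 3
```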
dataset
- class towbintools.deep_learning.utils.dataset.ClassificationDataset(dataset, channels, n_classes, class_column='class', image_column='image', transform=None)
Bases: Dataset
PyTorch Dataset for image classification training.
Loads images from file paths and their corresponding class labels. For multi-class problems (n_classes > 2) labels are one-hot encoded.
- Parameters:
dataset (pd.DataFrame) – DataFrame with image_column and class_column columns.
channels (int or list[int]) – Channel index (or indices) to load.
n_classes (int) – Total number of classes. Labels are one-hot encoded when n_classes > 2.
class_column (str, optional) – Column name for class labels. (default: "class")
image_column (str, optional) – Column name for image paths. (default: "image")
transform (callable, optional) – MONAI Compose transform applied to {"image": ...}. (default: None)
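The labeling rule can be sketched as follows, assuming labels are already integer class indices. Binary labels stay scalar (matching the BCEWithLogitsLoss that ClassificationModel uses for binary tasks), while labels for more than two classes become one-hot vectors. The helper encode_label is hypothetical, not part of towbintools.

```python
import numpy as np


def encode_label(label, n_classes):
    """Hypothetical helper mirroring the rule: one-hot only when n_classes > 2."""
    if n_classes > 2:
        one_hot = np.zeros(n_classes, dtype=np.float32)
        one_hot[int(label)] = 1.0
        return one_hot
    # Binary case: keep a single float target.
    return np.float32(label)


print(encode_label(1, 2))  # 1.0
print(encode_label(2, 4))  # [0. 0. 1. 0.]
```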
- class towbintools.deep_learning.utils.dataset.KeypointDetection1DPredictionDataset(inputs, enforce_divisibility_by=32, resize_method='pad')
Bases: Dataset
PyTorch Dataset for 1D keypoint detection inference.
Stores input time-series for prediction. The collate function pads or crops all series to the same length and replaces NaN-containing series with zeros, returning their indices as invalid_series_index.
- Parameters:
inputs (array-like) – Sequence of 1D (or 2D) input series arrays.
enforce_divisibility_by (int, optional) – Target batch length is rounded to a multiple of this value. (default: 32)
resize_method (str, optional) – "pad" or "crop". (default: "pad")
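The prediction collate behavior can be sketched for 1D series. The following are assumptions, not confirmed by the source: padding uses zeros at the end and rounds the longest length up to the next multiple of enforce_divisibility_by, cropping rounds the shortest length down, and any series containing a NaN becomes an all-zero row whose index is recorded. The function collate_series is hypothetical.

```python
import numpy as np


def collate_series(batch, enforce_divisibility_by=32, resize_method="pad"):
    """Hypothetical sketch of a 1D prediction collate function."""
    d = enforce_divisibility_by
    lengths = [len(s) for s in batch]
    if resize_method == "pad":
        target = -(-max(lengths) // d) * d  # round the longest length up
    else:
        target = max(d, (min(lengths) // d) * d)  # round the shortest length down
    out = np.zeros((len(batch), target), dtype=np.float32)
    invalid_series_index = []
    for i, series in enumerate(batch):
        series = np.asarray(series, dtype=np.float32)
        if np.isnan(series).any():
            invalid_series_index.append(i)  # row stays all zeros
            continue
        clipped = series[:target]  # crop if longer than target
        out[i, : len(clipped)] = clipped  # zero-pad if shorter
    return out, invalid_series_index


batch = [np.ones(50), np.ones(70), np.array([1.0, np.nan, 3.0])]
x, invalid = collate_series(batch)
print(x.shape, invalid)  # (3, 96) [2]
```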
- class towbintools.deep_learning.utils.dataset.KeypointDetection1DTrainingDataset(inputs, targets, enforce_divisibility_by=32, resize_method='pad')
Bases: Dataset
PyTorch Dataset for 1D keypoint detection training.
Stores pairs of input time-series and target heatmaps. The collate function pads or crops all series in a batch to the same length (a multiple of enforce_divisibility_by) and drops samples containing NaN values.
- Parameters:
inputs (array-like) – Sequence of 1D (or 2D) input series arrays.
targets (array-like) – Sequence of target heatmap arrays aligned with inputs.
enforce_divisibility_by (int, optional) – Target batch length is rounded to a multiple of this value. (default: 32)
resize_method (str, optional) – "pad" or "crop". (default: "pad")
- class towbintools.deep_learning.utils.dataset.QualityControlDataset(image_paths, mask_paths, channels, labels, classes, enforce_divisibility_by=32, resize_method='pad', transform=None)
Bases: Dataset
PyTorch Dataset for quality-control classification training.
Loads image + mask pairs and their quality labels. When mask_paths is provided, image and mask are concatenated along the channel axis before the transform. The collate function discards samples whose mask has no foreground.
- Parameters:
image_paths (list[str]) – Paths to image files.
mask_paths (list[str] or None) – Paths to mask files. Pass None or an empty list for image-only mode.
channels (int or list[int]) – Channel indices to load from images.
labels (list) – Class labels (integers or strings matching classes).
classes (list) – Ordered list of class names.
enforce_divisibility_by (int, optional) – Batch spatial dimensions are resized to multiples of this value. (default: 32)
resize_method (str, optional) – "pad" or "crop". (default: "pad")
transform (callable, optional) – MONAI Compose transform. (default: None)
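The two steps described for QualityControlDataset can be sketched as follows, assuming channel-first (C, H, W) images and 2D (H, W) masks: the mask is appended as an extra channel, and a collate-style filter drops samples whose mask channel has no non-zero pixel. The helpers stack_image_and_mask and drop_empty_masks are hypothetical.

```python
import numpy as np


def stack_image_and_mask(image, mask):
    """Concatenate a (C, H, W) image and an (H, W) mask along the channel axis."""
    return np.concatenate([image, mask[np.newaxis].astype(image.dtype)], axis=0)


def drop_empty_masks(samples):
    """Keep (stacked, label) pairs whose mask channel (the last one) has foreground."""
    return [(x, y) for x, y in samples if np.any(x[-1] > 0)]


image = np.random.rand(1, 64, 64).astype(np.float32)
full_mask = np.zeros((64, 64), dtype=np.uint8)
full_mask[20:40, 20:40] = 1
empty_mask = np.zeros((64, 64), dtype=np.uint8)

samples = [
    (stack_image_and_mask(image, full_mask), 0),
    (stack_image_and_mask(image, empty_mask), 1),
]
kept = drop_empty_masks(samples)
print(len(kept), kept[0][0].shape)  # 1 (2, 64, 64)
```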
- class towbintools.deep_learning.utils.dataset.QualityControlPredictionDataset(image_paths, mask_paths, channels, enforce_divisibility_by=32, resize_method='pad', transform=None)
Bases: Dataset
PyTorch Dataset for quality-control classification inference.
Loads image + mask pairs and concatenates them along the channel axis. The collate function tracks samples rejected due to empty or failed masks and returns their indices alongside valid batches.
- Parameters:
image_paths (list[str]) – Paths to image files.
mask_paths (list[str]) – Paths to mask files.
channels (int or list[int]) – Channel indices to load from images.
enforce_divisibility_by (int, optional) – Batch spatial dimensions are resized to multiples of this value. (default: 32)
resize_method (str, optional) – "pad" or "crop". (default: "pad")
transform (callable, optional) – MONAI Compose transform. (default: None)
- class towbintools.deep_learning.utils.dataset.SegmentationDataset(dataset, channels, mask_column='mask', image_column='image', transform=None, enforce_divisibility_by=32, pad_or_crop='pad', mask_pad_value=-1)
Bases: Dataset
PyTorch Dataset for full-image segmentation training.
Loads image and mask pairs from file paths stored in a DataFrame. Images are resized (padded or cropped) in the collate function to a common, divisibility-enforced size within each batch.
- Parameters:
dataset (pd.DataFrame) – DataFrame with image_column and mask_column columns containing file paths.
channels (int or list[int]) – Channel indices to load from the image.
mask_column (str, optional) – Column name for mask paths. (default: "mask")
image_column (str, optional) – Column name for image paths. (default: "image")
transform (callable, optional) – MONAI Compose transform. (default: None)
enforce_divisibility_by (int, optional) – Batch images are resized so their spatial dimensions are multiples of this value. (default: 32)
pad_or_crop (str, optional) – Whether to pad ("pad") or crop ("crop") images to the common batch size. (default: "pad")
mask_pad_value (int, optional) – Fill value used when padding masks. (default: -1)
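The batch-level padding can be sketched like this, assuming (C, H, W) images, (H, W) masks, and padding placed at the bottom/right. Each spatial size is rounded up to the next multiple of enforce_divisibility_by; images are zero-padded, while masks are padded with mask_pad_value, so the padded region can be skipped by losses that default to ignore_index=-1 (as the losses in the loss module do). The helpers round_up and pad_batch are hypothetical.

```python
import numpy as np


def round_up(n, d):
    """Smallest multiple of d that is >= n."""
    return -(-n // d) * d


def pad_batch(images, masks, d=32, mask_pad_value=-1):
    """Hypothetical sketch: pad (C, H, W) images and (H, W) masks to a common size."""
    H = round_up(max(img.shape[-2] for img in images), d)
    W = round_up(max(img.shape[-1] for img in images), d)
    n_channels = images[0].shape[0]
    batch_images = np.zeros((len(images), n_channels, H, W), dtype=np.float32)
    batch_masks = np.full((len(masks), H, W), mask_pad_value, dtype=np.int64)
    for i, (img, msk) in enumerate(zip(images, masks)):
        h, w = img.shape[-2:]
        batch_images[i, :, :h, :w] = img  # zero padding at bottom/right
        batch_masks[i, :h, :w] = msk  # padded area keeps mask_pad_value
    return batch_images, batch_masks


images = [np.ones((1, 100, 90)), np.ones((1, 64, 128))]
masks = [np.zeros((100, 90)), np.ones((64, 128))]
x, y = pad_batch(images, masks)
print(x.shape, y.shape)  # (2, 1, 128, 128) (2, 128, 128)
print((y == -1).sum())  # 15576 padded mask pixels
```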
- class towbintools.deep_learning.utils.dataset.SegmentationPredictionDataset(image_paths, channels, transform=None, enforce_divisibility_by=32, scale_factor=1.0, pad_or_crop='pad')
Bases: Dataset
PyTorch Dataset for segmentation inference.
Loads images from a list of file paths, optionally rescales them, and pads or crops batches to a common size that is a multiple of enforce_divisibility_by. The collate function returns image paths, resized tensors, original shapes, and indices of images that failed to load.
- Parameters:
image_paths (list[str]) – Paths to the image files.
channels (int or list[int]) – Channel indices to load.
transform (callable, optional) – MONAI Compose transform. (default: None)
enforce_divisibility_by (int, optional) – Images are resized so their spatial dimensions are multiples of this value. (default: 32)
scale_factor (float, optional) – Isotropic rescaling factor applied to each image before batching. (default: 1.0)
pad_or_crop (str, optional) – "pad" or "crop". (default: "pad")
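Returning the original shapes lets predictions made on a padded batch be cropped back to each image's size. A minimal sketch, assuming padding was added at the bottom/right and scale_factor=1.0 (with another scale factor the cropped prediction would also have to be resized back, which is not shown); crop_to_original is a hypothetical helper.

```python
import numpy as np


def crop_to_original(batch_predictions, original_shapes):
    """Hypothetical sketch: undo bottom/right padding for each image in a batch."""
    return [
        pred[..., :h, :w]  # keep only the region covered by the original image
        for pred, (h, w) in zip(batch_predictions, original_shapes)
    ]


predictions = np.random.rand(2, 1, 128, 128)  # model output on a padded batch
original_shapes = [(100, 90), (64, 128)]
cropped = crop_to_original(predictions, original_shapes)
print([c.shape for c in cropped])  # [(1, 100, 90), (1, 64, 128)]
```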
- class towbintools.deep_learning.utils.dataset.StackPredictionDataset(stack, channels, transform=None, enforce_divisibility_by=32, pad_or_crop='pad', scale_factor=1.0)
Bases: Dataset
PyTorch Dataset for plane-by-plane inference on a z-stack.
Accepts a z-stack as a NumPy array or a file path. Optionally downscales planes before loading. Each __getitem__ call returns a single plane, ready for inference.
- Parameters:
stack (str or np.ndarray) – z-stack as an array of shape (N, H, W) or (N, C, H, W), or a path to a TIFF file.
channels (int, list[int], or None) – Channel indices to load when stack is a file path.
transform (callable, optional) – MONAI Compose transform applied per plane. (default: None)
enforce_divisibility_by (int, optional) – Spatial dimensions are resized to multiples of this value. (default: 32)
pad_or_crop (str, optional) – "pad" or "crop". (default: "pad")
scale_factor (float, optional) – Isotropic rescaling factor applied to each plane via cv2.resize. (default: 1.0)
- class towbintools.deep_learning.utils.dataset.TiledSegmentationDataset(dataset, image_slicers, channels, mask_column='mask', image_column='image', transform=None)
Bases: Dataset
PyTorch Dataset for tiled segmentation training.
Loads image and mask pairs from file paths stored in a DataFrame and returns a randomly selected tile from each sample. Tile boundaries are pre-computed using pytorch_toolbelt.inference.ImageSlicer.
- Parameters:
dataset (pd.DataFrame) – DataFrame with at least image_column and mask_column columns containing file paths.
image_slicers (dict) – Mapping from image shape to ImageSlicer instance.
channels (int or list[int]) – Channel indices to load from the image.
mask_column (str, optional) – DataFrame column name for mask paths. (default: "mask")
image_column (str, optional) – DataFrame column name for image paths. (default: "image")
transform (callable, optional) – MONAI Compose transform applied to the {"image": ..., "mask": ...} dictionary. (default: None)
- towbintools.deep_learning.utils.dataset.create_classification_dataloaders(training_dataframe, validation_dataframe, channels, n_classes, batch_size=64, num_workers=32, pin_memory=True, training_transform=None, validation_transform=None)
Create training and validation DataLoaders for image classification.
- Parameters:
training_dataframe (pd.DataFrame) – DataFrame with "image" and "class" columns for training data.
validation_dataframe (pd.DataFrame) – DataFrame with "image" and "class" columns for validation data.
channels (int or list[int]) – Channel indices to load.
n_classes (int) – Number of classes (labels are one-hot encoded when > 2).
batch_size (int, optional) – Batch size. (default: 64)
num_workers (int, optional) – DataLoader worker processes. (default: 32)
pin_memory (bool, optional) – Whether to pin memory. (default: True)
training_transform (callable, optional) – Override training transform. (default: None)
validation_transform (callable, optional) – Override validation transform. (default: None)
- Returns:
(train_loader, val_loader).
- Return type:
tuple[DataLoader, DataLoader]
- towbintools.deep_learning.utils.dataset.create_classification_training_dataframes(ground_truth_csv_paths, image_columns, class_columns, save_dir, validation_set_ratio=0.25, test_set_ratio=0.1)
Build training and validation DataFrames for classification from CSV ground-truth files.
Reads one or more CSV files, extracts the image path and class label columns, concatenates them, splits the result into train/val/test sets, and saves date-stamped CSV backups.
- Parameters:
ground_truth_csv_paths (str or list[str]) – Paths to ground-truth CSV files.
image_columns (str or list[str]) – Column name(s) for image paths. A single string is broadcast to all CSVs.
class_columns (str or list[str]) – Column name(s) for class labels.
save_dir (str) – Directory where backup CSVs are written.
validation_set_ratio (float, optional) – Fraction for validation. (default: 0.25)
test_set_ratio (float, optional) – Fraction for testing. (default: 0.1)
- Returns:
(training_dataframe, validation_dataframe).
- Return type:
tuple[pd.DataFrame, pd.DataFrame]
- towbintools.deep_learning.utils.dataset.create_segmentation_dataloaders(training_dataframe, validation_dataframe, channels, batch_size=5, num_workers=32, pin_memory=True, train_on_tiles=True, tiler_params=None, training_transform=None, validation_transform=None)
Create training and validation DataLoaders for segmentation.
When train_on_tiles is True, uses TiledSegmentationDataset with per-shape ImageSlicer objects; otherwise uses SegmentationDataset with full images. Default transforms (percentile normalization) are applied if no transform is supplied.
- Parameters:
training_dataframe (pd.DataFrame) – DataFrame with "image" and "mask" columns for training data.
validation_dataframe (pd.DataFrame) – DataFrame with "image" and "mask" columns for validation data.
channels (int or list[int]) – Channel indices to load.
batch_size (int, optional) – Batch size. (default: 5)
num_workers (int, optional) – DataLoader worker processes. (default: 32)
pin_memory (bool, optional) – Whether to pin memory. (default: True)
train_on_tiles (bool, optional) – If True, sample random tiles; otherwise use full images. (default: True)
tiler_params (dict, optional) – Required when train_on_tiles is True; must contain "tile_size" and "tile_step" keys. (default: None)
training_transform (callable, optional) – Override transform for training. (default: None)
validation_transform (callable, optional) – Override transform for validation. (default: None)
- Returns:
(train_loader, val_loader).
- Return type:
tuple[DataLoader, DataLoader]
- towbintools.deep_learning.utils.dataset.create_segmentation_dataloaders_from_filemap(filemap_path, save_dir, channels, image_column='image', mask_column='mask', validation_set_ratio=0.25, test_set_ratio=0.1, batch_size=5, num_workers=32, pin_memory=True, train_on_tiles=True, tiler_params=None, training_transform=None, validation_transform=None)
Build segmentation DataLoaders from a CSV filemap.
Reads a CSV at filemap_path, renames the specified columns to "image" and "mask", splits into train/val/test sets, saves date-stamped backups, and returns DataLoaders.
- Parameters:
filemap_path (str) – Path to the CSV filemap.
save_dir (str) – Directory where backup CSVs are written.
channels (int or list[int]) – Channel indices to load.
image_column (str, optional) – CSV column name for image paths. (default: "image")
mask_column (str, optional) – CSV column name for mask paths. (default: "mask")
validation_set_ratio (float, optional) – Fraction for validation. (default: 0.25)
test_set_ratio (float, optional) – Fraction for testing. (default: 0.1)
batch_size (int, optional) – Batch size. (default: 5)
num_workers (int, optional) – DataLoader worker processes. (default: 32)
pin_memory (bool, optional) – Whether to pin memory. (default: True)
train_on_tiles (bool, optional) – Whether to train on tiles. (default: True)
tiler_params (dict, optional) – Tile parameters; see create_segmentation_dataloaders(). (default: None)
training_transform (callable, optional) – Override training transform. (default: None)
validation_transform (callable, optional) – Override validation transform. (default: None)
- Returns:
(training_dataframe, validation_dataframe, train_loader, val_loader).
- Return type:
tuple
- towbintools.deep_learning.utils.dataset.create_segmentation_training_dataframes(image_directories, mask_directories, save_dir, validation_set_ratio=0.25, test_set_ratio=0.1)
Build training and validation DataFrames from image and mask directories.
Pairs files by sorted order within each directory pair. Saves date-stamped CSV backups of all three splits to save_dir/database_backup/.
- Parameters:
image_directories (str or list[str]) – Directories containing image files.
mask_directories (str or list[str]) – Directories containing mask files, paired with image_directories.
save_dir (str) – Directory where backup CSVs are written.
validation_set_ratio (float, optional) – Fraction of data for validation. (default: 0.25)
test_set_ratio (float, optional) – Fraction of data for testing. (default: 0.1)
- Returns:
(training_dataframe, validation_dataframe).
- Return type:
tuple[pd.DataFrame, pd.DataFrame]
- Raises:
AssertionError – If the number of images and masks in a directory pair does not match.
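Because files are paired purely by position after sorting, image and mask filenames must sort in the same order (shared stems or zero-padded indices help). A minimal sketch of that pairing and the count check, assuming every entry returned by os.listdir is a file to pair; pair_by_sorted_order is a hypothetical helper.

```python
import os
import tempfile


def pair_by_sorted_order(image_dir, mask_dir):
    """Pair files from two directories by their position in sorted order."""
    images = sorted(os.listdir(image_dir))
    masks = sorted(os.listdir(mask_dir))
    assert len(images) == len(masks), (
        f"{len(images)} images vs {len(masks)} masks in {image_dir} / {mask_dir}"
    )
    return [
        (os.path.join(image_dir, i), os.path.join(mask_dir, m))
        for i, m in zip(images, masks)
    ]


with tempfile.TemporaryDirectory() as root:
    img_dir, msk_dir = os.path.join(root, "img"), os.path.join(root, "msk")
    os.makedirs(img_dir)
    os.makedirs(msk_dir)
    for name in ["t002.tiff", "t001.tiff"]:  # created out of order on purpose
        open(os.path.join(img_dir, name), "w").close()
        open(os.path.join(msk_dir, name.replace(".tiff", "_mask.tiff")), "w").close()
    pairs = pair_by_sorted_order(img_dir, msk_dir)
    print([(os.path.basename(a), os.path.basename(b)) for a, b in pairs])
    # [('t001.tiff', 't001_mask.tiff'), ('t002.tiff', 't002_mask.tiff')]
```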
- towbintools.deep_learning.utils.dataset.create_segmentation_training_dataframes_and_dataloaders(image_directories, mask_directories, save_dir, channels, validation_set_ratio=0.25, test_set_ratio=0.1, batch_size=5, num_workers=32, pin_memory=True, train_on_tiles=True, tiler_params=None, training_transform=None, validation_transform=None)
Build training DataFrames and DataLoaders for segmentation in one step.
Combines create_segmentation_training_dataframes() and create_segmentation_dataloaders().
- Parameters:
image_directories (str or list[str]) – Directories containing image files.
mask_directories (str or list[str]) – Directories containing mask files.
save_dir (str) – Directory where backup CSVs are written.
channels (int or list[int]) – Channel indices to load.
validation_set_ratio (float, optional) – Fraction for validation. (default: 0.25)
test_set_ratio (float, optional) – Fraction for testing. (default: 0.1)
batch_size (int, optional) – Batch size. (default: 5)
num_workers (int, optional) – DataLoader worker processes. (default: 32)
pin_memory (bool, optional) – Whether to pin memory. (default: True)
train_on_tiles (bool, optional) – Whether to train on tiles. (default: True)
tiler_params (dict, optional) – Tile parameters; see create_segmentation_dataloaders(). (default: None)
training_transform (callable, optional) – Override training transform. (default: None)
validation_transform (callable, optional) – Override validation transform. (default: None)
- Returns:
(training_dataframe, validation_dataframe, train_loader, val_loader).
- Return type:
tuple
- towbintools.deep_learning.utils.dataset.get_unique_shapes_from_tiffs(image_paths=list[str], channels_to_keep: list[int] | None = None) → ndarray
Get the unique shapes from a list of TIFF images, reading them in parallel.
- Parameters:
image_paths (list[str]) – Paths of the images to extract shapes from.
channels_to_keep (list[int] or None, optional) – Channel indices to keep. If None, all channels are considered.
- Returns:
Unique image shapes found among the input images.
- Return type:
np.ndarray
- towbintools.deep_learning.utils.dataset.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]]) → dst
Resizes an image. (This entry documents OpenCV's cv2.resize, which is imported into the module namespace.)
The function resize resizes the image src down to or up to the specified size. Note that the initial dst type or size are not taken into account. Instead, the size and type are derived from src, dsize, fx, and fy. If you want to resize src so that it fits the pre-created dst, you may call the function as follows:
    // explicitly specify dsize=dst.size(); fx and fy will be computed from that.
    resize(src, dst, dst.size(), 0, 0, interpolation);
If you want to decimate the image by a factor of 2 in each direction, you can call the function this way:
    // specify fx and fy and let the function compute the destination image size.
    resize(src, dst, Size(), 0.5, 0.5, interpolation);
To shrink an image, it will generally look best with INTER_AREA interpolation, whereas to enlarge an image, it will generally look best with INTER_CUBIC (slow) or INTER_LINEAR (faster but still looks OK).
- Parameters:
src – input image.
dst – output image; it has the size dsize (when it is non-zero) or the size computed from src.size(), fx, and fy; the type of dst is the same as of src.
dsize – output image size; if it equals zero (None in Python), it is computed as dsize = Size(round(fx*src.cols), round(fy*src.rows)). Either dsize or both fx and fy must be non-zero.
fx – scale factor along the horizontal axis; when it equals 0, it is computed as (double)dsize.width/src.cols.
fy – scale factor along the vertical axis; when it equals 0, it is computed as (double)dsize.height/src.rows.
interpolation – interpolation method; see InterpolationFlags.
- See also:
warpAffine, warpPerspective, remap
- towbintools.deep_learning.utils.dataset.split_dataset(dataframe, validation_size, test_size)
Split a DataFrame (or CSV path) into training, validation, and test sets.
- Parameters:
dataframe (pd.DataFrame or str) – DataFrame or path to a CSV file.
validation_size (float) – Fraction of the total data for validation.
test_size (float) – Fraction of the total data for testing.
- Returns:
(train_dataframe, validation_dataframe, test_dataframe).
- Return type:
tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
- Raises:
ValueError – If validation_size + test_size >= 1.0.
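Both sizes are fractions of the total data rather than of what remains after a first split, so with validation_size=0.25 and test_size=0.1, 65% of the rows stay in training. A minimal sketch of that arithmetic on index lists, assuming a seeded random shuffle; the library may use a different splitter (for instance scikit-learn's train_test_split), and split_indices is a hypothetical name.

```python
import random


def split_indices(n, validation_size, test_size, seed=42):
    """Hypothetical sketch: split range(n) into train/val/test index lists."""
    if validation_size + test_size >= 1.0:
        raise ValueError("validation_size + test_size must be < 1.0")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    # Both counts are computed from the total n, not from a remainder.
    n_test = round(n * test_size)
    n_val = round(n * validation_size)
    test = indices[:n_test]
    val = indices[n_test : n_test + n_val]
    train = indices[n_test + n_val :]
    return train, val, test


train, val, test = split_indices(200, validation_size=0.25, test_size=0.1)
print(len(train), len(val), len(test))  # 130 50 20
```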
augmentation
- class towbintools.deep_learning.utils.augmentation.CustomFlip(*args, **kwargs)
Bases: MapTransform, Randomizable
MONAI Randomizable MapTransform that randomly flips arrays along spatial axes.
With probability prob, flips along one of: height axis only, width axis only, or both height and width axes.
- Parameters:
keys (sequence) – Keys in the data dictionary to apply the transform to.
prob (float, optional) – Probability of applying the flip. (default: 0.75)
- randomize(data=None)
Within this method, self.R should be used, instead of np.random, to introduce random factors.
All self.R calls happen here so that we have a better chance to identify errors in syncing the random state.
This method can generate the random factors based on properties of the input data.
- Raises:
NotImplementedError – When the subclass does not override this method.
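A NumPy sketch of the flip logic, assuming channel-first arrays whose last two axes are (H, W) and that one random choice is shared by every key, so an image and its mask stay aligned. A seeded numpy.random.RandomState stands in for MONAI's self.R; random_flip is a hypothetical function, not the transform's actual code.

```python
import numpy as np


def random_flip(data, keys, prob=0.75, rng=None):
    """Flip every array in `keys` with the same random choice of axes."""
    if rng is None:
        rng = np.random.RandomState()
    if rng.random_sample() >= prob:
        return data  # no flip for this sample
    # 0 -> height axis only, 1 -> width axis only, 2 -> both axes.
    axes = [(-2,), (-1,), (-2, -1)][rng.randint(3)]
    out = dict(data)
    for key in keys:
        out[key] = np.flip(data[key], axis=axes).copy()
    return out


image = np.arange(6).reshape(1, 2, 3)
sample = {"image": image, "mask": image.copy()}
flipped = random_flip(sample, keys=["image", "mask"], prob=1.0, rng=np.random.RandomState(0))
print(np.array_equal(flipped["image"], flipped["mask"]))  # True
```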
- class towbintools.deep_learning.utils.augmentation.CustomRotate90(*args, **kwargs)
Bases: MapTransform, Randomizable
MONAI Randomizable MapTransform that randomly rotates arrays by 90°, 180°, or 270°.
With probability prob, rotates in the spatial (H, W) plane by a randomly chosen multiple of 90°.
- Parameters:
keys (sequence) – Keys in the data dictionary to apply the transform to.
prob (float, optional) – Probability of applying the rotation. (default: 0.75)
- randomize(data=None)
Within this method, self.R should be used, instead of np.random, to introduce random factors.
All self.R calls happen here so that we have a better chance to identify errors in syncing the random state.
This method can generate the random factors based on properties of the input data.
- Raises:
NotImplementedError – When the subclass does not override this method.
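The rotation can be sketched the same way with numpy.rot90 over the last two axes, again assuming channel-first arrays and one shared random choice per sample; random_rotate90 is hypothetical. Note that a 90° or 270° turn swaps H and W for non-square images, which the padding/cropping collate functions then absorb.

```python
import numpy as np


def random_rotate90(data, keys, prob=0.75, rng=None):
    """Rotate every array in `keys` by the same multiple of 90 degrees in (H, W)."""
    if rng is None:
        rng = np.random.RandomState()
    if rng.random_sample() >= prob:
        return data  # no rotation for this sample
    k = rng.randint(1, 4)  # 1, 2 or 3 quarter turns: 90, 180 or 270 degrees
    out = dict(data)
    for key in keys:
        out[key] = np.rot90(data[key], k=k, axes=(-2, -1)).copy()
    return out


image = np.zeros((1, 32, 64))
rotated = random_rotate90({"image": image}, keys=["image"], prob=1.0, rng=np.random.RandomState(1))
print(rotated["image"].shape)  # (1, 64, 32) for 90/270 degrees, (1, 32, 64) for 180
```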
- class towbintools.deep_learning.utils.augmentation.EnforceNChannels(*args, **kwargs)
Bases: MapTransform
MONAI MapTransform that tiles channel data to reach exactly n_channels.
Delegates to _enforce_n_channels() for the actual tiling logic.
- Parameters:
keys (sequence) – Keys in the data dictionary to apply the transform to.
n_channels (int) – Target number of channels.
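The tiling idea can be sketched as repeating the existing channels cyclically until at least n_channels are available, then keeping the first n_channels. This is an assumption about what _enforce_n_channels() does (the source only says it tiles), including the truncation when there are already more channels than requested; enforce_n_channels here is a hypothetical stand-alone function.

```python
import numpy as np


def enforce_n_channels(array, n_channels):
    """Tile a channel-first (C, ...) array along C, then truncate to n_channels."""
    c = array.shape[0]
    reps = -(-n_channels // c)  # ceil(n_channels / c)
    tiled = np.tile(array, (reps,) + (1,) * (array.ndim - 1))
    return tiled[:n_channels]


gray = np.random.rand(1, 8, 8)
rgb_like = enforce_n_channels(gray, 3)
print(rgb_like.shape)  # (3, 8, 8)
print(enforce_n_channels(np.random.rand(2, 8, 8), 3).shape)  # (3, 8, 8)
```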
- class towbintools.deep_learning.utils.augmentation.NormalizeDataRange(*args, **kwargs)
Bases: MapTransform
MONAI MapTransform that normalizes arrays to [0, 1] by their per-sample min/max.
- Parameters:
keys (sequence) – Keys in the data dictionary to apply the transform to.
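Per-sample min/max normalization amounts to (x - min) / (max - min). A minimal sketch, assuming a small epsilon in the denominator to avoid dividing by zero on constant images (the source does not say how that case is handled); normalize_data_range is a hypothetical function.

```python
import numpy as np


def normalize_data_range(x, eps=1e-8):
    """Rescale one sample to [0, 1] using its own min and max."""
    x = np.asarray(x, dtype=np.float32)
    lo, hi = x.min(), x.max()
    return (x - lo) / (hi - lo + eps)  # eps guards against constant images


sample = np.array([[100, 150], [200, 300]], dtype=np.uint16)
print(normalize_data_range(sample))  # values 0, 0.25, 0.5 and (almost exactly) 1
```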
- class towbintools.deep_learning.utils.augmentation.NormalizeMeanStd(*args, **kwargs)
Bases: MapTransform
MONAI MapTransform that standardizes arrays using a fixed mean and std.
- Parameters:
keys (sequence) – Keys in the data dictionary to apply the transform to.
mean (float) – Mean to subtract.
std (float) – Standard deviation to divide by.
- class towbintools.deep_learning.utils.augmentation.NormalizePercentile(*args, **kwargs)
Bases: MapTransform
MONAI MapTransform that normalizes arrays using percentile clipping (csbdeep).
Clips values at the lo-th and hi-th percentiles and rescales to [0, 1].
- Parameters:
keys (sequence) – Keys in the data dictionary to apply the transform to.
lo (float) – Lower percentile for clipping (e.g. 1 for the 1st percentile).
hi (float) – Upper percentile for clipping (e.g. 99 for the 99th percentile).
axis (int or None, optional) – Axis over which to compute percentiles. None computes them over the entire array. (default: None)
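Percentile normalization maps the lo-th percentile to 0 and the hi-th percentile to 1, then clips, so a few bright outliers no longer compress the useful intensity range. A NumPy sketch under the assumption that a tiny epsilon is added to the denominator (csbdeep's normalize uses a similar guard); normalize_percentile is a hypothetical function, not the transform's code.

```python
import numpy as np


def normalize_percentile(x, lo=1.0, hi=99.0, axis=None, eps=1e-20):
    """Clip at the lo-th/hi-th percentiles and rescale to [0, 1]."""
    x = np.asarray(x, dtype=np.float32)
    p_lo = np.percentile(x, lo, axis=axis, keepdims=True)
    p_hi = np.percentile(x, hi, axis=axis, keepdims=True)
    scaled = (x - p_lo) / (p_hi - p_lo + eps)  # eps avoids division by zero
    return np.clip(scaled, 0.0, 1.0)


# 100 ordinary values plus one bright outlier.
x = np.append(np.arange(100, dtype=np.float32), 10_000.0)
out = normalize_percentile(x)
print(out[0], out[50], out[-1])  # roughly 0.0 0.5 1.0
```

Here the 1st and 99th percentiles are about 1 and 99, so the value 50 lands near 0.5 and the outlier is clipped to 1 instead of squeezing everything else toward 0.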
- towbintools.deep_learning.utils.augmentation.get_mean_and_std(image_path: str) → tuple[float, float]
Compute the mean and standard deviation of channel 2 in a TIFF image.
- Parameters:
image_path (str) – Path to the TIFF image file.
- Returns:
(mean, std) of the pixel values in channel 2.
- Return type:
tuple[float, float]
- towbintools.deep_learning.utils.augmentation.get_prediction_augmentation(normalization_type: str, **kwargs) → Compose
Build the MONAI transform pipeline for inference (normalization only).
- Parameters:
normalization_type (str) – Normalization type passed to _build_normalization().
**kwargs – Additional parameters forwarded to the normalization transform, and optionally enforce_n_channels (int).
- Returns:
MONAI Compose pipeline ready for prediction.
- Return type:
Compose
- towbintools.deep_learning.utils.augmentation.get_prediction_augmentation_from_model(model, enforce_n_channels=None) → Compose
Build the inference transform pipeline from a model’s stored normalization config.
Reads the normalization attribute of model (a dict with at least a "type" key) and delegates to get_prediction_augmentation().
- Parameters:
model – A model instance exposing a normalization dict attribute (e.g. a PretrainedSegmentationModel).
enforce_n_channels (int, optional) – If not None, tile channels to this count via EnforceNChannels. (default: None)
- Returns:
MONAI Compose pipeline ready for prediction.
- Return type:
Compose
- towbintools.deep_learning.utils.augmentation.get_qc_training_augmentation(normalization_type: str, **kwargs) → Compose
Build the MONAI augmentation pipeline for quality-control model training.
Lighter than get_training_augmentation(): includes only random flips and normalization (no rotations or intensity transforms).
- Parameters:
normalization_type (str) – Normalization type passed to _build_normalization().
**kwargs – Additional parameters forwarded to the normalization transform, and optionally enforce_n_channels (int).
- Returns:
MONAI Compose pipeline ready for QC training.
- Return type:
Compose
- towbintools.deep_learning.utils.augmentation.get_training_augmentation(normalization_type: str, **kwargs) → Compose
Build the MONAI augmentation pipeline for segmentation training.
Includes random flips, random 90° rotations, random Gaussian smoothing, random Gaussian sharpening, and normalization.
- Parameters:
normalization_type (str) – Normalization type passed to _build_normalization() ("data_range", "mean_std", or "percentile").
**kwargs – Additional parameters forwarded to the normalization transform, and optionally enforce_n_channels (int) to tile channels.
- Returns:
MONAI Compose pipeline ready for training.
- Return type:
Compose
loss
- class towbintools.deep_learning.utils.loss.BCELossWithIgnore(ignore_index=-1)
Bases: Module
Binary cross-entropy loss that ignores a specified target value.
Computes element-wise BCE loss, zeroes out entries where target == ignore_index, and returns the mean over non-ignored elements.
- forward(input, target)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
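The masking logic can be sketched in NumPy, assuming input holds probabilities (as with torch.nn.BCELoss) rather than logits; if the module works on logits, a sigmoid would be applied first. Ignored positions, such as mask padding filled with -1, contribute neither to the sum nor to the count. bce_with_ignore is a hypothetical function.

```python
import numpy as np


def bce_with_ignore(probs, target, ignore_index=-1, eps=1e-7):
    """Mean binary cross-entropy over elements whose target != ignore_index."""
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    valid = target != ignore_index
    t = np.where(valid, target, 0).astype(np.float64)  # placeholder in ignored slots
    loss = -(t * np.log(probs) + (1.0 - t) * np.log(1.0 - probs))
    loss = np.where(valid, loss, 0.0)  # zero out ignored entries
    return loss.sum() / max(valid.sum(), 1)  # mean over non-ignored elements


probs = np.array([0.9, 0.2, 0.7])
target = np.array([1, 0, -1])  # last element is padding and is ignored
print(round(bce_with_ignore(probs, target), 4))  # 0.1643
```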
- class towbintools.deep_learning.utils.loss.FocalTverskyLoss(ignore_index=-1, smooth=100, alpha=0.3, beta=0.7, gamma=1.3333333333333333, activation=True)
Bases: Module
Focal Tversky loss for binary segmentation with class-imbalance handling.
Combines the Tversky index (a generalization of Dice that independently weights false positives and false negatives) with a focal exponent to down-weight easy examples and focus training on hard ones.
Reference: Abraham & Khan (2019), arXiv:1810.07842.
- forward(inputs, targets)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
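The Focal Tversky computation can be sketched as TI = (TP + smooth) / (TP + alpha·FP + beta·FN + smooth) and loss = (1 - TI) ** gamma. Several details are assumptions not confirmed by the source: activation=True applies a sigmoid, alpha weights false positives and beta false negatives, and pixels with target == ignore_index are excluded. Abraham & Khan write the exponent as 1/γ, and implementations differ on this convention, so check it before interpreting gamma. focal_tversky_loss is a hypothetical function.

```python
import numpy as np


def focal_tversky_loss(inputs, targets, ignore_index=-1, smooth=100.0,
                       alpha=0.3, beta=0.7, gamma=4.0 / 3.0, activation=True):
    """(1 - Tversky index) ** gamma over pixels whose target != ignore_index."""
    probs = 1.0 / (1.0 + np.exp(-inputs)) if activation else inputs  # sigmoid
    valid = targets != ignore_index
    p = probs[valid]
    t = targets[valid].astype(np.float64)
    tp = np.sum(p * t)
    fp = np.sum(p * (1.0 - t))  # predicted foreground on true background
    fn = np.sum((1.0 - p) * t)  # predicted background on true foreground
    tversky = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    return (1.0 - tversky) ** gamma


t = np.array([1.0, 1.0, 0.0, 0.0])
print(focal_tversky_loss(t.copy(), t, activation=False))  # 0.0
print(focal_tversky_loss(1.0 - t, t, activation=False) > 0)  # True
```

With beta > alpha, missed foreground costs more than spurious foreground, which suits thin or small structures; the large smooth term keeps the loss small and stable on tiny batches.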
- class towbintools.deep_learning.utils.loss.MultiClassFocalLoss(alpha: Tensor | None = None, gamma: float = 2.0, reduction: str = 'mean', ignore_index: int = -1)[source]
Bases: Module
Focal loss for multi-class classification.
Extends cross-entropy with a modulating factor (1 - p_t)^gamma that down-weights well-classified examples, focusing training on hard negatives. Supports per-class weights and an ignore index.
Reference: Lin et al. (2017), arXiv:1708.02002.
- forward(x: Tensor, y: Tensor) Tensor[source]
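The focal term can be sketched in NumPy as follows. `multiclass_focal_loss` is a hypothetical helper rather than the towbintools code; it assumes `x` holds raw scores of shape (N, C), `y` holds integer labels of shape (N,), `alpha` is an optional per-class weight vector, and `reduction="mean"` averages over the non-ignored samples.

```python
import numpy as np


def multiclass_focal_loss(logits, labels, alpha=None, gamma=2.0, ignore_index=-1):
    """Hypothetical helper: mean focal loss over samples whose label != ignore_index."""
    x = np.asarray(logits, dtype=float)   # shape (N, C): raw class scores
    y = np.asarray(labels, dtype=int)     # shape (N,): integer class labels
    keep = y != ignore_index
    x, y = x[keep], y[keep]
    if y.size == 0:
        return 0.0
    # Numerically stable log-softmax over the class axis.
    shifted = x - x.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    log_pt = log_probs[np.arange(len(y)), y]  # log-probability of the true class
    pt = np.exp(log_pt)
    # Optional per-class weights, looked up by true label.
    weights = np.ones_like(log_pt) if alpha is None else np.asarray(alpha, dtype=float)[y]
    loss = -weights * (1.0 - pt) ** gamma * log_pt
    return float(loss.mean())
```

With `gamma=0` and no `alpha`, the expression reduces to ordinary cross-entropy, which is a quick sanity check for any implementation.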
- class towbintools.deep_learning.utils.loss.PeakWeightedMSELoss(ignore_index=-1, peak_weight=3.0)[source]
Bases: Module
MSE loss with additional weight on high-value (peak) target positions.
Assigns a per-element weight of 1 + peak_weight * target so that positions with larger target values (peaks in a heatmap) contribute more to the loss. Designed for 1D keypoint heatmap regression.
- forward(input, target)[source]
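The weighting rule can be written in a few lines of NumPy. `peak_weighted_mse` is a hypothetical helper; excluding positions whose target equals `ignore_index` and taking a plain mean of the weighted squared errors are assumptions, since the docstring only specifies the weight `1 + peak_weight * target`.

```python
import numpy as np


def peak_weighted_mse(pred, target, peak_weight=3.0, ignore_index=-1):
    """Hypothetical helper: MSE where each element is weighted by 1 + peak_weight * target."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    keep = target != ignore_index  # assumption: ignored positions are dropped
    pred, target = pred[keep], target[keep]
    if target.size == 0:
        return 0.0
    weights = 1.0 + peak_weight * target  # larger targets (peaks) weigh more
    return float(np.mean(weights * (pred - target) ** 2))
```

With the default `peak_weight=3.0`, missing a peak of height 1 costs four times as much as the same error on background.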
models and architectures
- class towbintools.deep_learning.architectures.models.ClassificationModel(architecture, input_channels, classes, learning_rate, normalization)[source]
Bases: LightningModule
PyTorch Lightning module for image classification using a pretrained backbone.
Uses timm to load the specified architecture with ImageNet-pretrained weights. Applies BCEWithLogitsLoss + BinaryF1Score for binary tasks, or CrossEntropyLoss + MulticlassF1Score for multiclass tasks.
- Parameters:
architecture (str) – timm model name (e.g. "efficientnet_b0").
input_channels (int) – Number of input image channels.
classes (list[str]) – Class labels; len(classes) determines binary vs multiclass.
learning_rate (float) – Learning rate for the Adam optimizer.
normalization (dict) – Normalization config stored as a hyperparameter and used at inference time to reconstruct the preprocessing pipeline.
- configure_optimizers()[source]
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these 6 options.
Single optimizer.
List or Tuple of optimizers.
Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
Lightning calls .backward() and .step() automatically in case of automatic optimization.
If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.
If you need to control how often the optimizer steps, override the optimizer_step() hook.
- forward(x)[source]
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- training_step(batch)[source]
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()
    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch)[source]
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Tensor - The loss tensor
dict - A dictionary. Can include any keys, but must include the key 'loss'.
None - Skip to the next batch.
# if you have one val dataloader:
def validation_step(self, batch, batch_idx): ...

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch
    # implement your own
    out = self(x)
    loss = self.loss(out, y)
    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)
    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    # log the outputs!
    self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
# CASE 2: multiple validation dataloaders
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch
    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)
    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
    # log the outputs separately for each dataloader
    self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate you don’t need to implement this method.
Note
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- class towbintools.deep_learning.architectures.models.KeypointDetection1DModel(input_channels, n_classes, learning_rate, architecture='UnetPlusPlus', activation='sigmoid', criterion=None)[source]
Bases: LightningModule
PyTorch Lightning module for 1D keypoint detection using a U-Net architecture.
Operates on 1D sequences (e.g. straightened worm fluorescence profiles). Supports "Unet", "AttentionUnet", and "UnetPlusPlus" 1D architectures. Uses PeakWeightedMSELoss by default.
- Parameters:
input_channels (int) – Number of input sequence channels.
n_classes (int) – Number of keypoint classes (output channels).
learning_rate (float) – Learning rate for the Adam optimizer.
architecture (str, optional) – Architecture name; one of "Unet", "AttentionUnet", or "UnetPlusPlus". (default: "UnetPlusPlus")
activation (str, optional) – Output activation; one of "relu", "leaky_relu", "sigmoid", or "none". (default: "sigmoid")
criterion (nn.Module, optional) – Loss function. If None, PeakWeightedMSELoss is used. (default: None)
- Raises:
ValueError – If architecture or activation is not supported.
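The documented validation and output activations can be mirrored in a small sketch. The supported values come from the parameter list above, but `check_keypoint_config` and `apply_output_activation` are hypothetical helpers, NumPy stands in for the PyTorch activation modules, and the leaky-ReLU slope of 0.01 (PyTorch's default) is an assumption.

```python
import numpy as np

# Hypothetical names: the supported values are from the docstring,
# but these constants and helpers are not part of the towbintools API.
SUPPORTED_ARCHITECTURES = ("Unet", "AttentionUnet", "UnetPlusPlus")
SUPPORTED_ACTIVATIONS = ("relu", "leaky_relu", "sigmoid", "none")


def check_keypoint_config(architecture="UnetPlusPlus", activation="sigmoid"):
    """Raise ValueError for unsupported choices, mirroring the documented behaviour."""
    if architecture not in SUPPORTED_ARCHITECTURES:
        raise ValueError(f"Unsupported architecture: {architecture!r}")
    if activation not in SUPPORTED_ACTIVATIONS:
        raise ValueError(f"Unsupported activation: {activation!r}")
    return architecture, activation


def apply_output_activation(x, activation):
    """Apply the named activation to raw network outputs (NumPy stand-in)."""
    x = np.asarray(x, dtype=float)
    if activation == "relu":
        return np.maximum(x, 0.0)
    if activation == "leaky_relu":
        # 0.01 is PyTorch's default negative slope; the slope used here is an assumption.
        return np.where(x > 0, x, 0.01 * x)
    if activation == "sigmoid":
        return 1.0 / (1.0 + np.exp(-x))
    if activation == "none":
        return x
    raise ValueError(f"Unsupported activation: {activation!r}")
```

Note that create_keypoint_detection_model defaults to activation="relu" while this class defaults to "sigmoid", so results may differ between the two entry points unless the activation is passed explicitly.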
- configure_optimizers()[source]
- forward(x)[source]
- predict_step(batch)[source]
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multi-devices.
To prevent an OOM error, it is possible to use a BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.
The BasePredictionWriter should be used while using a spawn based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won’t be returned.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
Predicted output (optional).
Example
class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

dm = ...
model = MyModel()
trainer = Trainer(accelerator="gpu", devices=2)
predictions = trainer.predict(model, dm)
- training_step(batch)[source]
- validation_step(batch)[source]
- class towbintools.deep_learning.architectures.models.SegmentationModel(input_channels, n_classes, learning_rate, architecture, encoder, pretrained_weights, normalization, criterion=None, ignore_index=None)[source]
Bases: LightningModule
PyTorch Lightning module for image segmentation using a pretrained encoder.
Uses segmentation_models_pytorch to build an encoder–decoder model. For binary tasks (n_classes == 1): sigmoid activation + FocalTverskyLoss + BinaryF1Score. For multiclass tasks: softmax activation + MultiClassFocalLoss + MulticlassF1Score.
- Parameters:
input_channels (int) – Number of input image channels.
n_classes (int) – Number of foreground segmentation classes.
learning_rate (float) – Learning rate for the Adam optimizer.
architecture (str) – smp architecture name (e.g. "Unet").
encoder (str) – Encoder backbone name (e.g. "resnet34").
pretrained_weights (str) – Dataset the encoder was pretrained on (e.g. "imagenet").
normalization (dict) – Normalization config stored as a hyperparameter and used at inference time to reconstruct the preprocessing pipeline.
criterion (nn.Module, optional) – Loss function. If None, FocalTverskyLoss is used for binary tasks and MultiClassFocalLoss for multiclass. (default: None)
ignore_index (int, optional) – Target value to ignore in the loss and F1 score. (default: None)
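How the documented activations might turn raw outputs into a label map is sketched below. `logits_to_label_map` is hypothetical, and the (C, H, W) layout, the 0.5 threshold for the binary case, and the argmax for the multiclass case are assumptions; the docstring only states that binary models use a sigmoid and multiclass models a softmax.

```python
import numpy as np


def logits_to_label_map(logits, threshold=0.5):
    """Hypothetical helper. logits: array of shape (C, H, W) for a single image."""
    logits = np.asarray(logits, dtype=float)
    if logits.shape[0] == 1:
        # Binary case (one output channel): sigmoid, then an assumed 0.5 threshold.
        probs = 1.0 / (1.0 + np.exp(-logits[0]))
        return (probs > threshold).astype(np.int64)
    # Multiclass case: softmax over the channel axis, then pick the most likely class.
    shifted = logits - logits.max(axis=0, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=0, keepdims=True)
    return probs.argmax(axis=0).astype(np.int64)
```

Because softmax is monotonic, the multiclass argmax gives the same labels as an argmax on the raw logits; the softmax matters for calibrated probabilities, not for the label map.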
- configure_optimizers()[source]
- forward(x)[source]
- predict_step(batch)[source]
- training_step(batch)[source]
- validation_step(batch)[source]
- class towbintools.deep_learning.architectures.archs.AttentionBlock1D(F_g, F_l, F_int)[source]
Bases: Module
1D attention gate for use in AttentionUnet1D.
Computes a soft attention map from a gating signal g (from the decoder) and a skip-connection feature map x (from the encoder). The output is x weighted element-wise by the attention coefficients.
- Parameters:
F_g (int) – Number of channels in the gating signal g.
F_l (int) – Number of channels in the skip-connection feature map x.
F_int (int) – Number of intermediate channels used to compute the attention map.
- forward(g, x)[source]
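The gating computation can be sketched in NumPy, following the additive attention gate popularized by Attention U-Net, which this block appears to follow. `attention_gate_1d` is hypothetical; it treats the 1x1 convolutions as matrix products over channels, omits biases and BatchNorm, and assumes `g` has already been resampled to the length of `x`.

```python
import numpy as np


def attention_gate_1d(g, x, w_g, w_x, w_psi):
    """Hypothetical additive attention gate on 1D feature maps.

    g: (F_g, L) gating signal; x: (F_l, L) skip features.
    w_g: (F_int, F_g); w_x: (F_int, F_l); w_psi: (1, F_int), i.e. 1x1 conv weights as matrices.
    """
    a = np.maximum(w_g @ g + w_x @ x, 0.0)       # ReLU(W_g g + W_x x), shape (F_int, L)
    alpha = 1.0 / (1.0 + np.exp(-(w_psi @ a)))   # sigmoid, shape (1, L), values in (0, 1)
    return x * alpha, alpha                      # broadcast the map over the F_l channels
```

The result keeps the shape of `x`, which is why the gated skip connection can be concatenated with the decoder features exactly as in a plain U-Net.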
- class towbintools.deep_learning.architectures.archs.AttentionUnet1D(num_classes, input_channels=1, **kwargs)[source]
Bases: Module
1D U-Net with attention gates for sequence segmentation and keypoint detection.
Extends Unet1D by inserting an AttentionBlock1D at each decoder stage. The attention gates suppress irrelevant activations in the encoder skip connections before concatenation. Filter counts follow [64, 128, 256, 512, 1024].
- Parameters:
num_classes (int) – Number of output classes (output channels).
input_channels (int, optional) – Number of input sequence channels. (default: 1)
**kwargs – Ignored; accepted for API compatibility.
- forward(input)[source]
- class towbintools.deep_learning.architectures.archs.Unet1D(num_classes, input_channels=1, **kwargs)[source]
Bases: Module
1D U-Net for sequence segmentation and keypoint detection.
The 1D analogue of Unet, operating on 1D sequences with MaxPool1d downsampling and linear upsampling. Filter counts follow [64, 128, 256, 512, 1024].
- Parameters:
num_classes (int) – Number of output classes (output channels).
input_channels (int, optional) – Number of input sequence channels. (default: 1)
**kwargs – Ignored; accepted for API compatibility.
- forward(input)[source]
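If each of the five filter levels is separated by a factor-2 MaxPool1d (an assumption; the docstring gives the filter counts but not the pooling factor), the input length has to survive four halvings for linear upsampling to restore it exactly. The sketch below, using hypothetical helpers `unet1d_level_shapes` and `largest_valid_length`, lists the (channels, length) of each encoder level under that assumption and crops a length down to the nearest usable multiple.

```python
FILTERS = [64, 128, 256, 512, 1024]  # filter counts from the docstring


def unet1d_level_shapes(length, filters=FILTERS):
    """Hypothetical helper: (channels, length) per encoder level, assuming 2x pooling."""
    factor = 2 ** (len(filters) - 1)  # four poolings for five levels
    if length % factor != 0:
        raise ValueError(f"length {length} is not divisible by {factor}")
    return [(n_filters, length // 2 ** level) for level, n_filters in enumerate(filters)]


def largest_valid_length(length, n_levels=len(FILTERS)):
    """Hypothetical helper: round length down to a multiple of 2 ** (n_levels - 1)."""
    factor = 2 ** (n_levels - 1)
    return length - length % factor


print(unet1d_level_shapes(256))
print(largest_valid_length(1000))
```

The same bookkeeping would apply to AttentionUnet1D and UnetPlusPlus1D, which share these filter counts, provided they use the same pooling.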
- class towbintools.deep_learning.architectures.archs.UnetPlusPlus1D(num_classes, input_channels=1, deep_supervision=False, **kwargs)[source]
Bases: Module
1D UNet++ for sequence segmentation with optional deep supervision.
The 1D analogue of UnetPlusPlus, with dense nested skip connections between all encoder and decoder nodes at the same resolution. Uses Conv1d, MaxPool1d, and linear upsampling. Filter counts follow [64, 128, 256, 512, 1024]. When deep_supervision=True, returns a list of four outputs from intermediate decoder nodes; otherwise returns a single output.
- Parameters:
num_classes (int) – Number of output classes (output channels).
input_channels (int, optional) – Number of input sequence channels. (default: 1)
deep_supervision (bool, optional) – If True, return outputs from all intermediate decoder stages. (default: False)
**kwargs – Ignored; accepted for API compatibility.
- forward(input)[source]
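The nested skip pattern is easiest to see as channel bookkeeping. Assuming the common UNet++ layout, where decoder node X(i, j) concatenates all earlier nodes on level i with the upsampled node X(i+1, j-1), the sketch below computes each node's input channels from the documented filter counts; `nested_node_in_channels` and `deep_supervision_nodes` are hypothetical helpers. Under this layout the top-row nodes X(0, 1) to X(0, 4) are the natural candidates for the four deep-supervision outputs.

```python
FILTERS = [64, 128, 256, 512, 1024]  # filter counts from the docstring


def nested_node_in_channels(filters=FILTERS):
    """Hypothetical helper: input channels of each decoder node X(i, j), j >= 1.

    Assumes X(i, j) concatenates X(i, 0), ..., X(i, j-1) (filters[i] channels each)
    with the upsampled node X(i+1, j-1) (filters[i+1] channels).
    """
    depth = len(filters)
    in_channels = {}
    for j in range(1, depth):
        for i in range(depth - j):
            in_channels[(i, j)] = j * filters[i] + filters[i + 1]
    return in_channels


def deep_supervision_nodes(depth=len(FILTERS)):
    """Top-row decoder nodes whose outputs would be returned with deep supervision."""
    return [(0, j) for j in range(1, depth)]
```

For instance, X(0, 1) would receive 64 + 128 = 192 channels and X(0, 4) would receive 4 * 64 + 128 = 384, which shows how the dense skips grow the decoder's input width along the top row.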
- class towbintools.deep_learning.architectures.archs.VGGBlock1D(in_channels, middle_channels, out_channels)[source]
Bases: Module
One-dimensional two-layer convolutional block with BatchNorm and ReLU.
The 1D analogue of VGGBlock, using Conv1d. Used as the basic building block in Unet1D, AttentionUnet1D, and UnetPlusPlus1D.
- Parameters:
in_channels (int) – Number of input channels.
middle_channels (int) – Number of channels after the first convolution.
out_channels (int) – Number of output channels.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
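As a rough check on the block's size, the sketch below counts its trainable parameters, assuming each of the two layers is a Conv1d with kernel size 3 and a bias term, followed by an affine BatchNorm1d. The kernel size and bias setting are assumptions not stated above, and vgg_block1d_param_count is a hypothetical helper.

```python
def vgg_block1d_param_count(
    in_channels: int, middle_channels: int, out_channels: int, kernel_size: int = 3
) -> int:
    """Hypothetical count of trainable parameters in a VGGBlock1D-style block.

    Assumes: Conv1d(in, mid, k, bias=True) -> BatchNorm1d(mid) -> ReLU
             -> Conv1d(mid, out, k, bias=True) -> BatchNorm1d(out) -> ReLU.
    ReLU has no parameters; BatchNorm running statistics are buffers, not parameters.
    """
    conv1 = in_channels * middle_channels * kernel_size + middle_channels  # weights + bias
    bn1 = 2 * middle_channels  # affine scale (gamma) and shift (beta)
    conv2 = middle_channels * out_channels * kernel_size + out_channels
    bn2 = 2 * out_channels
    return conv1 + bn1 + conv2 + bn2


# First encoder block for a single-channel input with 64 filters.
print(vgg_block1d_param_count(1, 64, 64))  # 12864
```

If these assumptions match the real block, sum(p.numel() for p in VGGBlock1D(1, 64, 64).parameters()) should give the same number.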
util
- towbintools.deep_learning.utils.util.adjust_tensor_dimensions(source_tensor: Tensor, target_tensor_shape: tuple) → Tensor
Squeeze source_tensor, then unsqueeze it to match target_tensor_shape.
- Parameters:
source_tensor (Tensor) – Tensor to reshape.
target_tensor_shape (tuple[int, ...]) – Desired shape of the output tensor.
- Returns:
Reshaped tensor compatible with target_tensor_shape.
- Return type:
Tensor
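The docstring leaves open where the new axes are inserted. One plausible reading, sketched below on shapes only, is to drop every size-1 axis and then re-insert size-1 axes at the positions where target_tensor_shape has them. This is an assumption about the behaviour, and adjusted_shape is a hypothetical helper.

```python
def adjusted_shape(source_shape: tuple, target_shape: tuple) -> tuple:
    """Hypothetical shape-level analogue of squeeze-then-unsqueeze.

    Assumes size-1 axes are re-inserted wherever target_shape has a 1.
    """
    # squeeze(): remove every axis of size 1
    shape = [d for d in source_shape if d != 1]
    # unsqueeze(i) for each size-1 axis of the target, left to right
    for i, d in enumerate(target_shape):
        if d == 1:
            shape.insert(i, 1)
    return tuple(shape)


print(adjusted_shape((1, 64, 3), (64, 1, 3)))  # (64, 1, 3)
```

With tensors, the same steps would be source_tensor.squeeze() followed by one .unsqueeze(i) call per size-1 target axis. Under this reading, the result only matches the target when both shapes agree on their non-singleton axes.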
- towbintools.deep_learning.utils.util.create_lightweight_checkpoint(input_path: str, output_path: str) → dict
Load a PyTorch Lightning checkpoint and save a lightweight version.
Keeps only state_dict, hyper_parameters, and a small set of optional metadata keys (epoch, global_step, pytorch-lightning_version), discarding optimizer state and other large tensors.
- Parameters:
input_path (str) – Path to the full .ckpt checkpoint file.
output_path (str) – Destination path for the lightweight checkpoint.
- Returns:
The lightweight checkpoint dictionary that was saved to output_path.
- Return type:
dict
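The key filtering described above amounts to a dictionary selection. A minimal sketch, assuming the checkpoint has already been loaded into a dict and that state_dict and hyper_parameters are always present; strip_checkpoint is a hypothetical name, and file I/O is left out so the example runs without PyTorch.

```python
REQUIRED_KEYS = ("state_dict", "hyper_parameters")
OPTIONAL_KEYS = ("epoch", "global_step", "pytorch-lightning_version")


def strip_checkpoint(checkpoint: dict) -> dict:
    """Hypothetical sketch: keep weights, hyperparameters and light metadata only."""
    light = {key: checkpoint[key] for key in REQUIRED_KEYS}  # KeyError if missing
    light.update({key: checkpoint[key] for key in OPTIONAL_KEYS if key in checkpoint})
    return light


full = {
    "state_dict": {"layer.weight": [0.1, 0.2]},
    "hyper_parameters": {"learning_rate": 1e-4},
    "epoch": 12,
    "optimizer_states": [{"state": {}}],  # dropped
    "lr_schedulers": [],                  # dropped
}
print(sorted(strip_checkpoint(full)))  # ['epoch', 'hyper_parameters', 'state_dict']
```

Around it, torch.load(input_path, map_location="cpu") and torch.save(light, output_path) would handle reading and writing the files.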
- towbintools.deep_learning.utils.util.divide_batch(batch: Tensor, n: int) → Iterator[Tensor]
Yield successive mini-batches of size n from batch.
- Parameters:
batch (Tensor) – Input batch tensor with the batch dimension as axis 0.
n (int) – Mini-batch size.
- Yields:
Tensor – Slice of batch along axis 0 containing at most n samples.
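Slicing along the first axis is enough to implement this. The sketch below, divide_batch_sketch (a hypothetical name), works on any object that supports len() and slicing, which includes torch tensors, where len() returns the size of dimension 0.

```python
from typing import Iterator, Sequence


def divide_batch_sketch(batch: Sequence, n: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most n items along the first axis."""
    if n <= 0:
        raise ValueError("mini-batch size n must be positive")
    for start in range(0, len(batch), n):
        yield batch[start:start + n]  # the last slice may be shorter than n


print(list(divide_batch_sketch(list(range(5)), 2)))  # [[0, 1], [2, 3], [4]]
```

A typical use is running inference on a large batch in memory-sized chunks, for example torch.cat([model(chunk) for chunk in divide_batch(batch, 8)]).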
- towbintools.deep_learning.utils.util.get_closest_lower_multiple(dim: int | float, multiple: int) → int
Round dim down to the nearest multiple of multiple.
- Parameters:
dim (int or float) – Value to round down.
multiple (int) – The multiple to round to.
- Returns:
Largest multiple of multiple that is <= dim.
- Return type:
int
- towbintools.deep_learning.utils.util.get_closest_upper_multiple(dim: int | float, multiple: int) → int
Round dim up to the nearest multiple of multiple.
- Parameters:
dim (int or float) – Value to round up.
multiple (int) – The multiple to round to.
- Returns:
Smallest multiple of multiple that is >= dim.
- Return type:
int
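Both roundings reduce to floor division. A minimal sketch of the two helpers under hypothetical names, accepting int or float input as the signatures above do:

```python
def closest_lower_multiple(dim: float, multiple: int) -> int:
    """Largest multiple of `multiple` that is <= dim."""
    return int(dim // multiple) * multiple


def closest_upper_multiple(dim: float, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= dim (ceil via negated floor division)."""
    return int(-(-dim // multiple)) * multiple


print(closest_lower_multiple(100, 16), closest_upper_multiple(100, 16))  # 96 112
print(closest_upper_multiple(96, 16))  # 96
```

For example, a sequence of length 100 would be cropped to 96 or padded to 112 to make it divisible by 16.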
- towbintools.deep_learning.utils.util.get_input_channels_from_checkpoint(checkpoint_path: str) → int
Infer the number of input channels from a PyTorch Lightning checkpoint.
Searches the state dict for the first convolutional weight tensor and returns its in_channels dimension (weight.shape[1]).
- Parameters:
checkpoint_path (str) – Path to the .ckpt checkpoint file.
- Returns:
Number of input channels, or 0 if no convolutional layer was found.
- Return type:
int
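The search can be sketched on a mapping from parameter names to shapes. How the real function recognises a convolutional weight is not stated; the sketch assumes a key ending in "weight" with at least three dimensions, since Conv1d weights have shape (out_channels, in_channels, kernel_size) and Conv2d weights add one more spatial axis. first_conv_in_channels is a hypothetical name.

```python
def first_conv_in_channels(shapes: dict) -> int:
    """Hypothetical sketch: in_channels of the first conv-like weight, or 0.

    `shapes` maps state-dict keys to shape tuples, in state-dict order.
    Assumes conv weights are the first "...weight" entries with >= 3 dims,
    laid out as (out_channels, in_channels, *kernel_size).
    """
    for name, shape in shapes.items():
        if name.endswith("weight") and len(shape) >= 3:
            return shape[1]
    return 0


shapes = {
    "model.bn0.weight": (3,),             # 1D: BatchNorm, skipped
    "model.conv1.weight": (64, 3, 3, 3),  # first Conv2d weight -> 3 input channels
    "model.fc.weight": (10, 64),          # 2D: Linear, skipped
}
print(first_conv_in_channels(shapes))  # 3
```

With a real checkpoint, the mapping could be built as {k: tuple(v.shape) for k, v in torch.load(checkpoint_path, map_location="cpu")["state_dict"].items()}.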
- towbintools.deep_learning.utils.util.rename_keys_and_adjust_dimensions(model: Module, pretrained_model: dict) → dict
Map pretrained weights into a model with differently named or shaped parameters.
Pairs keys from model.state_dict() with keys from pretrained_model by position, adjusting tensor shapes via adjust_tensor_dimensions() when they differ.
- Parameters:
model (nn.Module) – Target model whose state dict keys define the mapping.
pretrained_model (dict) – Source state dict (e.g. from torch.load(...)["state_dict"]).
- Returns:
New state dict compatible with model.load_state_dict().
- Return type:
dict
- Raises:
AssertionError – If the numbers of keys in model and pretrained_model do not match.
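Positional pairing relies on both state dicts listing their parameters in the same order. The sketch below shows the pairing logic with the reshaping step passed in as a callable; remap_by_position is a hypothetical name, and it only requires values to expose a .shape attribute, as tensors do. The namedtuple stands in for a tensor so the example runs without PyTorch.

```python
from collections import namedtuple
from typing import Any, Callable


def remap_by_position(
    model_state: dict, pretrained_state: dict, reshape: Callable[[Any, tuple], Any]
) -> dict:
    """Hypothetical sketch: rename pretrained entries to model keys by position."""
    assert len(model_state) == len(pretrained_state), "state dicts differ in length"
    new_state = {}
    # dicts preserve insertion order, so zip pairs the i-th entries of each
    for (key, target), value in zip(model_state.items(), pretrained_state.values()):
        if tuple(value.shape) != tuple(target.shape):
            value = reshape(value, tuple(target.shape))
        new_state[key] = value
    return new_state


Fake = namedtuple("Fake", "shape")  # stand-in for a tensor
model_state = {"conv.weight": Fake((8, 1, 3)), "conv.bias": Fake((8,))}
pretrained = {"encoder.0.weight": Fake((8, 3)), "encoder.0.bias": Fake((8,))}
print(list(remap_by_position(model_state, pretrained, lambda v, s: Fake(s))))
# ['conv.weight', 'conv.bias']
```

With real tensors, the reshape argument would play the role of adjust_tensor_dimensions, whose (source_tensor, target_tensor_shape) signature matches the callable used here.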