Source code for towbintools.deep_learning.deep_learning_tools

from typing import Any

from towbintools.deep_learning.architectures import ClassificationModel
from towbintools.deep_learning.architectures import KeypointDetection1DModel
from towbintools.deep_learning.architectures import SegmentationModel
from towbintools.deep_learning.utils.util import (
    get_input_channels_from_checkpoint,
)


[docs] def create_classification_model( architecture: str, input_channels: int, classes: list[str], learning_rate: float = 1e-4, checkpoint_path: str | None = None, normalization: dict = {"type": "percentile", "lo": 1, "hi": 99}, ) -> ClassificationModel: """ Create a classification model. Parameters: architecture (str): ``timm`` model name (e.g. ``"efficientnet_b0"``). input_channels (int): Number of input image channels. classes (list[str]): List of class label strings. learning_rate (float, optional): Learning rate for the Adam optimizer. (default: 1e-4) checkpoint_path (str, optional): Path to a ``.ckpt`` checkpoint; if provided, the model is loaded from the checkpoint and all other arguments are ignored. (default: None) normalization (dict, optional): Normalization config stored as a hyperparameter for inference. (default: percentile 1–99) Returns: ClassificationModel: Constructed or loaded classification model. """ if checkpoint_path is not None: model = ClassificationModel.load_from_checkpoint( checkpoint_path, weights_only=False ) return model model = ClassificationModel( architecture, input_channels, classes, learning_rate, normalization, ) return model
[docs] def create_segmentation_model( input_channels: int = 1, n_classes: int = 1, architecture: str = "UnetPlusPlus", encoder: str = "efficientnet-b4", pretrained_weights: str = "image-micronet", normalization: dict = {"type": "percentile", "lo": 1, "hi": 99}, learning_rate: float = 1e-5, checkpoint_path: str | None = None, reset_optimizer: bool = True, criterion: Any | None = None, ) -> SegmentationModel: """ Create a segmentation model with a pretrained encoder. Parameters: input_channels (int, optional): Number of input image channels. (default: 1) n_classes (int, optional): Number of foreground segmentation classes. (default: 1) architecture (str, optional): ``smp`` architecture name. (default: ``"UnetPlusPlus"``) encoder (str, optional): Encoder backbone name. (default: ``"efficientnet-b4"``) pretrained_weights (str, optional): Dataset the encoder was pretrained on. (default: ``"image-micronet"``) normalization (dict, optional): Normalization config. (default: percentile 1–99) learning_rate (float, optional): Learning rate for the Adam optimizer. (default: 1e-5) checkpoint_path (str, optional): Path to a ``.ckpt`` checkpoint. If provided, the model is loaded from the checkpoint, then ``learning_rate`` and ``normalization`` are updated. (default: None) reset_optimizer (bool, optional): If ``True``, discard the optimizer state from the checkpoint. (default: True) criterion (nn.Module, optional): Loss function. If ``None``, a default is chosen based on ``n_classes``. (default: None) Returns: SegmentationModel: Constructed or loaded segmentation model. Raises: ValueError: If the checkpoint architecture or encoder does not match the requested one. """ model = SegmentationModel( input_channels=input_channels, n_classes=n_classes, learning_rate=learning_rate, architecture=architecture, encoder=encoder, pretrained_weights=pretrained_weights, normalization=normalization, criterion=criterion, ) if checkpoint_path is not None: loaded_model = SegmentationModel.load_from_checkpoint( checkpoint_path, map_location="cpu", weights_only=False ) if reset_optimizer: loaded_model.optimizer = None loaded_model.lr_scheduler = None loaded_model._optimizer = None loaded_model._lr_scheduler = None # change the learning rate and normalization loaded_model.learning_rate = learning_rate loaded_model.normalization = normalization # check if the architecture matches if not isinstance(loaded_model.model, model.model.__class__): raise ValueError( f"Checkpoint architecture {loaded_model.model.__class__} does not match the requested architecture {architecture}" ) # check if the encoder matches if loaded_model.model.encoder.__class__ != model.model.encoder.__class__: raise ValueError( f"Checkpoint encoder architecture {type(loaded_model.model.encoder).__name__} does not match the requested encoder architecture {type(model.model.encoder).__name__}" ) loaded_model.configure_optimizers() return loaded_model return model
[docs] def create_keypoint_detection_model( architecture: str, input_channels: int, n_classes: int, learning_rate: float = 1e-4, checkpoint_path: str | None = None, criterion: Any | None = None, activation: str = "relu", ) -> KeypointDetection1DModel: """ Create a 1D keypoint detection model. Parameters: architecture (str): Architecture name; one of ``"Unet"``, ``"AttentionUnet"``, or ``"UnetPlusPlus"``. input_channels (int): Number of input sequence channels. n_classes (int): Number of keypoint classes (output channels). learning_rate (float, optional): Learning rate for the Adam optimizer. (default: 1e-4) checkpoint_path (str, optional): Path to a ``.ckpt`` checkpoint; if provided, the model is loaded from the checkpoint and all other arguments are ignored. (default: None) criterion (nn.Module, optional): Loss function. If ``None``, ``PeakWeightedMSELoss`` is used. (default: None) activation (str, optional): Output activation; one of ``"relu"``, ``"leaky_relu"``, ``"sigmoid"``, or ``"none"``. (default: ``"relu"``) Returns: KeypointDetection1DModel: Constructed or loaded keypoint detection model. """ if checkpoint_path is not None: model = KeypointDetection1DModel.load_from_checkpoint( checkpoint_path, weights_only=False ) return model model = KeypointDetection1DModel( input_channels, n_classes, learning_rate, architecture=architecture, activation=activation, criterion=criterion, ) return model
[docs] def load_segmentation_model_from_checkpoint( checkpoint_path: str, ) -> SegmentationModel: """ Load a segmentation model from a checkpoint. First tries a direct load; if that fails (e.g. mismatched ``input_channels`` in the checkpoint metadata), infers the channel count from the checkpoint weights and retries with ``pretrained_weights=None``. Parameters: checkpoint_path (str): Path to a ``.ckpt`` checkpoint file. Returns: SegmentationModel: Loaded segmentation model. Raises: ValueError: If both loading attempts fail. """ try: return SegmentationModel.load_from_checkpoint( checkpoint_path, weights_only=False ) except Exception as e: try: return SegmentationModel.load_from_checkpoint( checkpoint_path, input_channels=get_input_channels_from_checkpoint(checkpoint_path), pretrained_weights=None, weights_only=False, ) except Exception as e2: raise ValueError( f"Could not load model from checkpoint {checkpoint_path}. Error: {e} and {e2}" )
[docs] def load_keypoint_detection_model_from_checkpoint( checkpoint_path: str, ) -> KeypointDetection1DModel: """ Load a 1D keypoint detection model from a checkpoint. Parameters: checkpoint_path (str): Path to a ``.ckpt`` checkpoint file. Returns: KeypointDetection1DModel: Loaded keypoint detection model. Raises: ValueError: If the model cannot be loaded from the checkpoint. """ try: return KeypointDetection1DModel.load_from_checkpoint( checkpoint_path, weights_only=False ) except Exception as e: raise ValueError( f"Could not load keypoint detection model from checkpoint {checkpoint_path}. Error: {e}" )