Source code for towbintools.deep_learning.architectures.archs

import torch
from torch import nn


# 1D UNet implementation
class VGGBlock1D(nn.Module):
    """
    Two stacked Conv1d -> BatchNorm1d -> ReLU stages for 1D feature maps.

    The 1D analogue of :class:`VGGBlock`. Both convolutions use kernel size 3
    with padding 1, so the sequence length of the input is preserved. This is
    the basic building block of :class:`Unet1D`, :class:`AttentionUnet1D`,
    and :class:`UnetPlusPlus1D`.

    Parameters:
        in_channels (int): Number of channels of the incoming feature map.
        middle_channels (int): Number of channels produced by the first
            convolution.
        out_channels (int): Number of channels produced by the second
            convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # Attribute names determine the state_dict keys; the convolutions are
        # created in stage order (conv1 before conv2).
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv1d(
            in_channels, middle_channels, kernel_size=3, padding=1
        )
        self.bn1 = nn.BatchNorm1d(middle_channels)
        self.conv2 = nn.Conv1d(
            middle_channels, out_channels, kernel_size=3, padding=1
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply both conv/norm/ReLU stages to ``x`` and return the result."""
        features = x
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            # The shared ReLU is in-place; it overwrites the fresh tensor
            # returned by the batch norm, never the caller's input.
            features = self.relu(norm(conv(features)))
        return features
class AttentionBlock1D(nn.Module):
    """
    Additive attention gate for 1D feature maps, used by :class:`AttentionUnet1D`.

    A gating signal ``g`` (from the decoder) and a skip-connection feature map
    ``x`` (from the encoder) are each projected to ``F_int`` channels with a
    1x1 convolution and batch norm, summed, passed through ReLU, and reduced
    to a single-channel sigmoid map. That map multiplies ``x`` element-wise
    (broadcast over channels). ``g`` and ``x`` must have the same sequence
    length because their projections are added.

    Parameters:
        F_g (int): Number of channels in the gating signal ``g``.
        F_l (int): Number of channels in the skip-connection feature map ``x``.
        F_int (int): Number of intermediate channels used to compute the
            attention map.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch, *tail):
            # 1x1 convolution + batch norm, optionally followed by extra layers.
            return nn.Sequential(
                nn.Conv1d(
                    in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True
                ),
                nn.BatchNorm1d(out_ch),
                *tail,
            )

        self.W_g = pointwise(F_g, F_int)
        self.W_x = pointwise(F_l, F_int)
        self.psi = pointwise(F_int, 1, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention coefficients derived from ``g`` and ``x``."""
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        # The in-place ReLU acts on the freshly allocated sum.
        coefficients = self.psi(self.relu(gate_proj + skip_proj))
        return x * coefficients
class Unet1D(nn.Module):
    """
    1D U-Net for sequence segmentation and keypoint detection.

    The 1D analogue of :class:`Unet`, operating on 1D sequences with MaxPool1d
    downsampling and linear upsampling. Filter counts follow
    [64, 128, 256, 512, 1024].

    The encoder halves the sequence length four times (flooring odd lengths).
    Decoder features are upsampled to the exact length of the matching
    encoder feature map, so the input length does not have to be a multiple
    of 16; it only has to be large enough (at least 16) for four poolings to
    leave a non-empty sequence. The output has the same length as the input.

    Parameters:
        num_classes (int): Number of output classes (output channels).
        input_channels (int, optional): Number of input sequence channels.
            (default: 1)
        **kwargs: Ignored; accepted for API compatibility.
    """

    def __init__(self, num_classes, input_channels=1, **kwargs):
        super().__init__()

        nb_filter = [64, 128, 256, 512, 1024]

        self.pool = nn.MaxPool1d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

        # Encoder path.
        self.conv0_0 = VGGBlock1D(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock1D(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock1D(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock1D(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock1D(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder path: each block consumes the skip connection concatenated
        # with the upsampled features from the level below.
        self.conv3_1 = VGGBlock1D(
            nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3]
        )
        self.conv2_2 = VGGBlock1D(
            nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]
        )
        self.conv1_3 = VGGBlock1D(
            nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv0_4 = VGGBlock1D(
            nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        self.final = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)

    def _upsample_like(self, x, ref):
        """Linearly upsample ``x`` so its last dimension matches that of ``ref``."""
        target_len = ref.shape[-1]
        if x.shape[-1] * 2 == target_len:
            # Exact factor-2 case: identical to the fixed-scale upsampler.
            return self.up(x)
        # Pooling floored an odd length; resample to the skip length instead
        # of letting torch.cat fail on mismatched sizes.
        return nn.functional.interpolate(
            x, size=target_len, mode="linear", align_corners=True
        )

    def forward(self, input):
        """Return per-position class scores with ``num_classes`` channels."""
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self._upsample_like(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self._upsample_like(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self._upsample_like(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self._upsample_like(x1_3, x0_0)], 1))

        return self.final(x0_4)
class AttentionUnet1D(nn.Module):
    """
    1D U-Net with attention gates for sequence segmentation and keypoint detection.

    Extends :class:`Unet1D` by inserting an :class:`AttentionBlock1D` at each
    decoder stage. The attention gates re-weight the encoder skip connections
    before concatenation, using the upsampled decoder features as the gating
    signal. Filter counts follow [64, 128, 256, 512, 1024].

    The encoder halves the sequence length four times (flooring odd lengths).
    Decoder features are upsampled to the exact length of the matching
    encoder feature map, so the input length does not have to be a multiple
    of 16; it only has to be large enough (at least 16) for four poolings to
    leave a non-empty sequence. The output has the same length as the input.

    Parameters:
        num_classes (int): Number of output classes (output channels).
        input_channels (int, optional): Number of input sequence channels.
            (default: 1)
        **kwargs: Ignored; accepted for API compatibility.
    """

    def __init__(self, num_classes, input_channels=1, **kwargs):
        super().__init__()

        nb_filter = [64, 128, 256, 512, 1024]

        self.pool = nn.MaxPool1d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

        # Encoder path.
        self.conv0_0 = VGGBlock1D(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock1D(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock1D(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock1D(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock1D(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder path: each block consumes the gated skip connection
        # concatenated with the upsampled features from the level below.
        self.conv3_1 = VGGBlock1D(
            nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3]
        )
        self.conv2_2 = VGGBlock1D(
            nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]
        )
        self.conv1_3 = VGGBlock1D(
            nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv0_4 = VGGBlock1D(
            nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        # One attention gate per decoder stage.
        self.att4 = AttentionBlock1D(
            F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[2]
        )
        self.att3 = AttentionBlock1D(
            F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[1]
        )
        self.att2 = AttentionBlock1D(
            F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[0]
        )
        self.att1 = AttentionBlock1D(
            F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0] // 2
        )

        self.final = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)

    def _upsample_like(self, x, ref):
        """Linearly upsample ``x`` so its last dimension matches that of ``ref``."""
        target_len = ref.shape[-1]
        if x.shape[-1] * 2 == target_len:
            # Exact factor-2 case: identical to the fixed-scale upsampler.
            return self.up(x)
        # Pooling floored an odd length; resample to the skip length instead
        # of letting the attention gate / torch.cat fail on mismatched sizes.
        return nn.functional.interpolate(
            x, size=target_len, mode="linear", align_corners=True
        )

    def _decode(self, attention, conv, deeper, skip):
        """Gate ``skip`` with the upsampled ``deeper`` features, then fuse both."""
        # Upsample once and reuse the result as gating signal and as the
        # decoder half of the concatenation.
        gate = self._upsample_like(deeper, skip)
        gated_skip = attention(g=gate, x=skip)
        return conv(torch.cat([gated_skip, gate], 1))

    def forward(self, input):
        """Return per-position class scores with ``num_classes`` channels."""
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self._decode(self.att4, self.conv3_1, x4_0, x3_0)
        x2_2 = self._decode(self.att3, self.conv2_2, x3_1, x2_0)
        x1_3 = self._decode(self.att2, self.conv1_3, x2_2, x1_0)
        x0_4 = self._decode(self.att1, self.conv0_4, x1_3, x0_0)

        return self.final(x0_4)
class UnetPlusPlus1D(nn.Module):
    """
    1D UNet++ for sequence segmentation with optional deep supervision.

    The 1D analogue of :class:`UnetPlusPlus`, with dense nested skip
    connections between all encoder and decoder nodes at the same resolution.
    Uses Conv1d, MaxPool1d, and linear upsampling. Filter counts follow
    [64, 128, 256, 512, 1024].

    When ``deep_supervision=True``, ``forward`` returns a list of four outputs
    from the intermediate decoder nodes ``x0_1`` .. ``x0_4``; otherwise it
    returns a single output computed from ``x0_4``.

    The encoder halves the sequence length four times (flooring odd lengths).
    Upsampled features are resampled to the exact length of the nodes they are
    concatenated with, so the input length does not have to be a multiple of
    16; it only has to be large enough (at least 16) for four poolings to
    leave a non-empty sequence. Outputs have the same length as the input.

    Parameters:
        num_classes (int): Number of output classes (output channels).
        input_channels (int, optional): Number of input sequence channels.
            (default: 1)
        deep_supervision (bool, optional): If ``True``, return outputs from
            all intermediate decoder stages. (default: False)
        **kwargs: Ignored; accepted for API compatibility.
    """

    def __init__(self, num_classes, input_channels=1, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [64, 128, 256, 512, 1024]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool1d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

        # Backbone (column 0 of the nested grid).
        self.conv0_0 = VGGBlock1D(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock1D(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock1D(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock1D(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock1D(nb_filter[3], nb_filter[4], nb_filter[4])

        # Node (i, j) receives j same-level feature maps plus one upsampled
        # map from level i + 1, hence nb_filter[i] * j + nb_filter[i + 1].
        self.conv0_1 = VGGBlock1D(
            nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_1 = VGGBlock1D(
            nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv2_1 = VGGBlock1D(
            nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]
        )
        self.conv3_1 = VGGBlock1D(
            nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3]
        )

        self.conv0_2 = VGGBlock1D(
            nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_2 = VGGBlock1D(
            nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv2_2 = VGGBlock1D(
            nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2]
        )

        self.conv0_3 = VGGBlock1D(
            nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_3 = VGGBlock1D(
            nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1]
        )

        self.conv0_4 = VGGBlock1D(
            nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        if self.deep_supervision:
            self.final1 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)

    def _upsample_like(self, x, ref):
        """Linearly upsample ``x`` so its last dimension matches that of ``ref``."""
        target_len = ref.shape[-1]
        if x.shape[-1] * 2 == target_len:
            # Exact factor-2 case: identical to the fixed-scale upsampler.
            return self.up(x)
        # Pooling floored an odd length; resample to the skip length instead
        # of letting torch.cat fail on mismatched sizes.
        return nn.functional.interpolate(
            x, size=target_len, mode="linear", align_corners=True
        )

    def forward(self, input):
        """Return class scores, or a list of four when deep supervision is on."""
        up = self._upsample_like

        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, up(x1_0, x0_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, up(x2_0, x1_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, up(x1_1, x0_0)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, up(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, up(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, up(x1_2, x0_0)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, up(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, up(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, up(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(
            torch.cat([x0_0, x0_1, x0_2, x0_3, up(x1_3, x0_0)], 1)
        )

        if self.deep_supervision:
            return [
                self.final1(x0_1),
                self.final2(x0_2),
                self.final3(x0_3),
                self.final4(x0_4),
            ]
        return self.final(x0_4)