import torch
from torch import nn
# 1D UNet implementation
# Shared two-convolution building block.
class VGGBlock1D(nn.Module):
    """
    Two stacked 1D convolutions, each followed by BatchNorm and ReLU.

    The 1D counterpart of :class:`VGGBlock`, built on ``nn.Conv1d``. It is the
    basic building block of :class:`Unet1D`, :class:`AttentionUnet1D`, and
    :class:`UnetPlusPlus1D`. Both convolutions use kernel size 3 with
    padding 1, so the sequence length is left unchanged.

    Parameters:
        in_channels (int): Number of input channels.
        middle_channels (int): Number of channels after the first convolution.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # One stateless ReLU module is shared by both conv stages.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv1d(
            in_channels, middle_channels, kernel_size=3, padding=1
        )
        self.bn1 = nn.BatchNorm1d(middle_channels)
        self.conv2 = nn.Conv1d(
            middle_channels, out_channels, kernel_size=3, padding=1
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """
        Run both conv -> BatchNorm -> ReLU stages.

        Parameters:
            x (torch.Tensor): Input of shape (batch, in_channels, length).

        Returns:
            torch.Tensor: Output of shape (batch, out_channels, length).
        """
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))
# Attention gate used by AttentionUnet1D.
class AttentionBlock1D(nn.Module):
    """
    1D attention gate for use in :class:`AttentionUnet1D`.

    A gating signal ``g`` (from the decoder) and a skip-connection feature map
    ``x`` (from the encoder) are each projected to ``F_int`` channels by a 1x1
    convolution with BatchNorm, summed, and passed through ReLU. A further 1x1
    convolution, BatchNorm, and sigmoid reduce the sum to a single-channel
    attention map, which scales ``x`` element-wise.

    Parameters:
        F_g (int): Number of channels in the gating signal ``g``.
        F_l (int): Number of channels in the skip-connection feature map ``x``.
        F_int (int): Number of intermediate channels used to compute the
            attention map.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(channels_in, channels_out):
            # 1x1 convolution followed by BatchNorm; shared by all three paths.
            return [
                nn.Conv1d(
                    channels_in,
                    channels_out,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True,
                ),
                nn.BatchNorm1d(channels_out),
            ]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """
        Weight ``x`` by attention coefficients derived from ``g`` and ``x``.

        Parameters:
            g (torch.Tensor): Gating signal of shape (batch, F_g, length).
            x (torch.Tensor): Skip features of shape (batch, F_l, length).
                The projections of ``g`` and ``x`` are added together, so
                their lengths have to be compatible.

        Returns:
            torch.Tensor: ``x`` multiplied by the single-channel attention
            map (broadcast over channels); same shape as ``x``.
        """
        # The in-place ReLU acts on the freshly created sum, never on g or x.
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention
# Plain 1D U-Net.
class Unet1D(nn.Module):
    """
    1D U-Net for sequence segmentation and keypoint detection.

    The 1D analogue of :class:`Unet`, operating on 1D sequences with
    MaxPool1d downsampling and linear upsampling. Filter counts follow
    [64, 128, 256, 512, 1024].

    Every decoder stage upsamples to the exact length of the encoder feature
    map it is concatenated with, so the input length does not need to be a
    multiple of 16. The encoder halves the length four times, so the input
    needs at least 16 samples for the deepest level to be non-empty.

    Parameters:
        num_classes (int): Number of output classes (output channels).
        input_channels (int, optional): Number of input sequence channels. (default: 1)
        **kwargs: Ignored; accepted for API compatibility.
    """

    def __init__(self, num_classes, input_channels=1, **kwargs):
        super().__init__()

        nb_filter = [64, 128, 256, 512, 1024]

        self.pool = nn.MaxPool1d(2, 2)
        # No longer called by forward(); retained so code that accesses
        # ``model.up`` keeps working. It owns no parameters, so checkpoints
        # are unaffected.
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

        # Encoder path.
        self.conv0_0 = VGGBlock1D(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock1D(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock1D(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock1D(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock1D(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder path: each block takes [skip features, upsampled deeper features].
        self.conv3_1 = VGGBlock1D(
            nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3]
        )
        self.conv2_2 = VGGBlock1D(
            nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]
        )
        self.conv1_3 = VGGBlock1D(
            nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv0_4 = VGGBlock1D(
            nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        self.final = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)

    @staticmethod
    def _upsample_to(x, ref):
        """
        Linearly resample ``x`` along its last dimension to the length of ``ref``.

        MaxPool1d floors odd lengths, so a fixed x2 upsample can come out one
        sample short of the matching encoder feature map. Resizing to the
        reference length keeps ``torch.cat`` valid for any input length. With
        ``align_corners=True`` the result equals the fixed x2 upsample
        whenever the reference is exactly twice as long.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-1], mode="linear", align_corners=True
        )

    def forward(self, input):
        """
        Run the encoder-decoder and project to ``num_classes`` channels.

        Parameters:
            input (torch.Tensor): Sequence batch of shape
                (batch, input_channels, length).

        Returns:
            torch.Tensor: Output of shape (batch, num_classes, length).
        """
        # Encoder: halve the length before every block after the first.
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        # Decoder: upsample to the skip's length, concatenate on channels.
        x3_1 = self.conv3_1(torch.cat([x3_0, self._upsample_to(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self._upsample_to(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self._upsample_to(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self._upsample_to(x1_3, x0_0)], 1))

        return self.final(x0_4)
# 1D U-Net with attention-gated skip connections.
class AttentionUnet1D(nn.Module):
    """
    1D U-Net with attention gates for sequence segmentation and keypoint detection.

    Extends :class:`Unet1D` by inserting an :class:`AttentionBlock1D` at each
    decoder stage. The attention gates suppress irrelevant activations in the
    encoder skip connections before concatenation. Filter counts follow
    [64, 128, 256, 512, 1024].

    Every decoder stage upsamples to the exact length of the encoder feature
    map it is gated against, so the input length does not need to be a
    multiple of 16. The encoder halves the length four times, so the input
    needs at least 16 samples for the deepest level to be non-empty.

    Parameters:
        num_classes (int): Number of output classes (output channels).
        input_channels (int, optional): Number of input sequence channels. (default: 1)
        **kwargs: Ignored; accepted for API compatibility.
    """

    def __init__(self, num_classes, input_channels=1, **kwargs):
        super().__init__()

        nb_filter = [64, 128, 256, 512, 1024]

        self.pool = nn.MaxPool1d(2, 2)
        # No longer called by forward(); retained so code that accesses
        # ``model.up`` keeps working. It owns no parameters, so checkpoints
        # are unaffected.
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

        # Encoder path.
        self.conv0_0 = VGGBlock1D(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock1D(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock1D(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock1D(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock1D(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder path: each block takes [gated skip, upsampled deeper features].
        self.conv3_1 = VGGBlock1D(
            nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3]
        )
        self.conv2_2 = VGGBlock1D(
            nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]
        )
        self.conv1_3 = VGGBlock1D(
            nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv0_4 = VGGBlock1D(
            nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        # Attention gates: F_g is the channel count of the upsampled decoder
        # signal, F_l that of the encoder skip being gated.
        self.att4 = AttentionBlock1D(
            F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[2]
        )
        self.att3 = AttentionBlock1D(
            F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[1]
        )
        self.att2 = AttentionBlock1D(
            F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[0]
        )
        self.att1 = AttentionBlock1D(
            F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0] // 2
        )

        self.final = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)

    @staticmethod
    def _upsample_to(x, ref):
        """
        Linearly resample ``x`` along its last dimension to the length of ``ref``.

        MaxPool1d floors odd lengths, so a fixed x2 upsample can come out one
        sample short of the matching encoder feature map. Resizing to the
        reference length keeps the attention gate and ``torch.cat`` valid for
        any input length. With ``align_corners=True`` the result equals the
        fixed x2 upsample whenever the reference is exactly twice as long.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-1], mode="linear", align_corners=True
        )

    def forward(self, input):
        """
        Run the attention-gated encoder-decoder.

        Parameters:
            input (torch.Tensor): Sequence batch of shape
                (batch, input_channels, length).

        Returns:
            torch.Tensor: Output of shape (batch, num_classes, length).
        """
        # Encoder: halve the length before every block after the first.
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        # Decoder: the upsampled tensor is computed once per stage and used
        # both as the gating signal and as the concatenated decoder features.
        up4 = self._upsample_to(x4_0, x3_0)
        x3_1 = self.conv3_1(torch.cat([self.att4(g=up4, x=x3_0), up4], 1))

        up3 = self._upsample_to(x3_1, x2_0)
        x2_2 = self.conv2_2(torch.cat([self.att3(g=up3, x=x2_0), up3], 1))

        up2 = self._upsample_to(x2_2, x1_0)
        x1_3 = self.conv1_3(torch.cat([self.att2(g=up2, x=x1_0), up2], 1))

        up1 = self._upsample_to(x1_3, x0_0)
        x0_4 = self.conv0_4(torch.cat([self.att1(g=up1, x=x0_0), up1], 1))

        return self.final(x0_4)
# 1D UNet++ with nested skip connections.
class UnetPlusPlus1D(nn.Module):
    """
    1D UNet++ for sequence segmentation with optional deep supervision.

    The 1D analogue of :class:`UnetPlusPlus`, with dense nested skip connections
    between all encoder and decoder nodes at the same resolution. Uses Conv1d,
    MaxPool1d, and linear upsampling. Filter counts follow [64, 128, 256, 512,
    1024]. When ``deep_supervision=True``, returns a list of four outputs from
    intermediate decoder nodes; otherwise returns a single output.

    Every nested node upsamples to the exact length of the same-resolution
    feature maps it is concatenated with, so the input length does not need
    to be a multiple of 16. The encoder halves the length four times, so the
    input needs at least 16 samples for the deepest level to be non-empty.

    Parameters:
        num_classes (int): Number of output classes (output channels).
        input_channels (int, optional): Number of input sequence channels. (default: 1)
        deep_supervision (bool, optional): If ``True``, return outputs from all
            intermediate decoder stages. (default: False)
        **kwargs: Ignored; accepted for API compatibility.
    """

    def __init__(self, num_classes, input_channels=1, deep_supervision=False, **kwargs):
        super().__init__()

        nb_filter = [64, 128, 256, 512, 1024]

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool1d(2, 2)
        # No longer called by forward(); retained so code that accesses
        # ``model.up`` keeps working. It owns no parameters, so checkpoints
        # are unaffected.
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)

        # Backbone (column 0): plain encoder.
        self.conv0_0 = VGGBlock1D(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock1D(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock1D(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock1D(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock1D(nb_filter[3], nb_filter[4], nb_filter[4])

        # Node conv{i}_{j} consumes the j earlier nodes of row i plus the
        # upsampled node from row i + 1, hence nb_filter[i] * j + nb_filter[i + 1].
        self.conv0_1 = VGGBlock1D(
            nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_1 = VGGBlock1D(
            nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv2_1 = VGGBlock1D(
            nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2]
        )
        self.conv3_1 = VGGBlock1D(
            nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3]
        )

        self.conv0_2 = VGGBlock1D(
            nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_2 = VGGBlock1D(
            nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1]
        )
        self.conv2_2 = VGGBlock1D(
            nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2]
        )

        self.conv0_3 = VGGBlock1D(
            nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0]
        )
        self.conv1_3 = VGGBlock1D(
            nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1]
        )

        self.conv0_4 = VGGBlock1D(
            nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0]
        )

        if self.deep_supervision:
            # One 1x1 head per top-row decoder node.
            self.final1 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv1d(nb_filter[0], num_classes, kernel_size=1)

    @staticmethod
    def _upsample_to(x, ref):
        """
        Linearly resample ``x`` along its last dimension to the length of ``ref``.

        MaxPool1d floors odd lengths, so a fixed x2 upsample can come out one
        sample short of the same-resolution feature maps. Resizing to the
        reference length keeps ``torch.cat`` valid for any input length. With
        ``align_corners=True`` the result equals the fixed x2 upsample
        whenever the reference is exactly twice as long.
        """
        return nn.functional.interpolate(
            x, size=ref.shape[-1], mode="linear", align_corners=True
        )

    def forward(self, input):
        """
        Run the nested encoder-decoder.

        Parameters:
            input (torch.Tensor): Sequence batch of shape
                (batch, input_channels, length).

        Returns:
            torch.Tensor or list[torch.Tensor]: A single tensor of shape
            (batch, num_classes, length), or, when ``deep_supervision`` is
            enabled, a list of four such tensors computed from nodes
            ``x0_1``, ``x0_2``, ``x0_3``, and ``x0_4`` in that order.
        """
        up = self._upsample_to

        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, up(x1_0, x0_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, up(x2_0, x1_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, up(x1_1, x0_0)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, up(x3_0, x2_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, up(x2_1, x1_0)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, up(x1_2, x0_0)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, up(x4_0, x3_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, up(x3_1, x2_0)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, up(x2_2, x1_0)], 1))
        x0_4 = self.conv0_4(
            torch.cat([x0_0, x0_1, x0_2, x0_3, up(x1_3, x0_0)], 1)
        )

        if self.deep_supervision:
            return [
                self.final1(x0_1),
                self.final2(x0_2),
                self.final3(x0_3),
                self.final4(x0_4),
            ]
        return self.final(x0_4)