Source code for thelper.nn.unet

import math

import torch
import torch.nn

import thelper.nn.coordconv
import thelper.nn.srm

# Module-wide latch read and set by UNet.forward: once the "input size should
# be power of 2" warning has been logged, this flips to True so the warning is
# not repeated on later forward passes.
warned_bad_input_size_power2 = False


class BasicBlock(torch.nn.Module):
    """Default (double-conv) block used in U-Net layers.

    The block chains two identical stages, each made of a 2D convolution
    (built through :func:`thelper.nn.coordconv.make_conv2d`), a batch
    normalization and an in-place ReLU. The first stage maps ``in_channels``
    to ``out_channels``; the second keeps ``out_channels``.

    Attributes:
        in_channels: channel count expected by the first convolution.
        out_channels: channel count produced by both convolutions.
        coordconv: flag forwarded to ``make_conv2d`` for both convolutions.
        kernel_size: kernel size forwarded to both convolutions.
        padding: padding forwarded to both convolutions.
        layer: the ``torch.nn.Sequential`` holding the six sub-modules.
    """

    def __init__(self, in_channels, out_channels, coordconv=False, kernel_size=3, padding=1):
        """Store the configuration and assemble the two conv/bn/relu stages."""
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.coordconv = coordconv
        self.kernel_size = kernel_size
        self.padding = padding
        # Stage 1 consumes the block input, stage 2 consumes stage 1's output;
        # modules are created in the same order they appear in the sequence.
        stages = []
        for stage_in_channels in (in_channels, out_channels):
            stages.append(
                thelper.nn.coordconv.make_conv2d(
                    in_channels=stage_in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    coordconv=coordconv,
                )
            )
            stages.append(torch.nn.BatchNorm2d(out_channels))
            stages.append(torch.nn.ReLU(inplace=True))
        self.layer = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the double-conv sequence and return the result."""
        return self.layer(x)
class UNet(thelper.nn.Module):
    """U-Net implementation.

    Not identical to the original. This version includes batchnorm and
    transposed conv2d layers for upsampling. Coordinate Convolutions
    (CoordConv) can also be toggled on if requested (see
    :mod:`thelper.nn.coordconv` for more information).

    The network has four encoder blocks separated by 2x2 max-pooling, a middle
    block, and four transposed-convolution upsampling steps whose outputs are
    concatenated with the matching encoder output (skip connections). The
    last concatenation feeds ``final_block``, which is (re)built by
    :meth:`set_task` to output one channel per class of the task.
    """

    def __init__(self, task, in_channels=3, mid_channels=512, coordconv=False, srm=False):
        """Build the encoder/decoder layers and the task-specific output head.

        Args:
            task: task object forwarded to the base class and to
                :meth:`set_task`.
            in_channels: number of channels of the input tensor.
            mid_channels: number of channels produced by the middle block;
                the other blocks use ``mid_channels`` divided by 2, 4, 8
                and 16. Must be a positive multiple of 16.
            coordconv: forwarded to every :class:`BasicBlock`.
            srm: when true, an SRM layer is created and its output is
                concatenated to the input before the first encoder block.

        Raises:
            ValueError: if ``mid_channels`` is not a positive multiple of 16.
        """
        # NOTE: this call must stay the first statement. vars() snapshots the
        # function's locals, so any local defined before it would be forwarded
        # to the base class as an extra keyword argument.
        super().__init__(task, **{k: v for k, v in vars().items() if k not in ["self", "task", "__class__"]})
        # Each skip connection concatenates two tensors of (mid_channels // 2k)
        # channels into a block expecting (mid_channels // k) channels, for
        # k = 1, 2, 4, 8. Those only agree when mid_channels is a multiple of
        # 16; fail here instead of at the first forward pass.
        if mid_channels < 16 or mid_channels % 16 != 0:
            raise ValueError(
                "mid_channels must be a positive multiple of 16 (got {!r})".format(mid_channels)
            )
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.coordconv = coordconv
        self.srm = srm
        self.pool = torch.nn.MaxPool2d(2)
        self.srm_conv = thelper.nn.srm.setup_srm_layer(in_channels) if srm else None
        # The first block is sized for 3 extra channels when SRM is enabled
        # (assumes the SRM layer outputs 3 channels -- confirm in thelper.nn.srm).
        self.encoder_block1 = BasicBlock(in_channels=(in_channels + 3) if srm else in_channels,
                                         out_channels=mid_channels // 16,
                                         coordconv=coordconv)
        self.encoder_block2 = BasicBlock(in_channels=mid_channels // 16,
                                         out_channels=mid_channels // 8,
                                         coordconv=coordconv)
        self.encoder_block3 = BasicBlock(in_channels=mid_channels // 8,
                                         out_channels=mid_channels // 4,
                                         coordconv=coordconv)
        self.encoder_block4 = BasicBlock(in_channels=mid_channels // 4,
                                         out_channels=mid_channels // 2,
                                         coordconv=coordconv)
        self.mid_block = BasicBlock(in_channels=mid_channels // 2,
                                    out_channels=mid_channels,
                                    coordconv=coordconv)
        # Decoder: each upsampling halves the channel count, and the following
        # block receives that output concatenated with the encoder output of
        # the same channel count (hence in_channels is twice out_channels).
        self.upsampling_block1 = torch.nn.ConvTranspose2d(in_channels=mid_channels,
                                                          out_channels=mid_channels // 2,
                                                          kernel_size=2, stride=2)
        self.decoder_block1 = BasicBlock(in_channels=mid_channels,
                                         out_channels=mid_channels // 2,
                                         coordconv=coordconv)
        self.upsampling_block2 = torch.nn.ConvTranspose2d(in_channels=mid_channels // 2,
                                                          out_channels=mid_channels // 4,
                                                          kernel_size=2, stride=2)
        self.decoder_block2 = BasicBlock(in_channels=mid_channels // 2,
                                         out_channels=mid_channels // 4,
                                         coordconv=coordconv)
        self.upsampling_block3 = torch.nn.ConvTranspose2d(in_channels=mid_channels // 4,
                                                          out_channels=mid_channels // 8,
                                                          kernel_size=2, stride=2)
        self.decoder_block3 = BasicBlock(in_channels=mid_channels // 4,
                                         out_channels=mid_channels // 8,
                                         coordconv=coordconv)
        self.upsampling_block4 = torch.nn.ConvTranspose2d(in_channels=mid_channels // 8,
                                                          out_channels=mid_channels // 16,
                                                          kernel_size=2, stride=2)
        # The output head depends on the number of classes; set_task builds it.
        self.final_block = None
        self.num_classes = None
        self.set_task(task)

    @staticmethod
    def _is_power_of_two(value):
        """Return True if ``value`` is a positive integer power of two."""
        value = int(value)
        # Exact integer test; avoids the rounding of math.log(value, 2) and
        # its ValueError on zero.
        return value > 0 and (value & (value - 1)) == 0

    def forward(self, x):
        """Run the U-Net on ``x`` and return the output of ``final_block``.

        A warning is logged (once, tracked by the module-level
        ``warned_bad_input_size_power2`` flag) when a 4-dim input has a last
        or second-to-last dimension that is not a power of two.
        """
        global warned_bad_input_size_power2
        if not warned_bad_input_size_power2 and len(x.shape) == 4:
            if not all(self._is_power_of_two(size) for size in x.shape[-2:]):
                warned_bad_input_size_power2 = True
                thelper.nn.logger.warning("unet input size should be power of 2 (e.g. 256x256, 512x512, ...)")
        if self.srm_conv is not None:
            # Append the SRM layer's response to the input along the channel axis.
            noise = self.srm_conv(x)
            x = torch.cat([x, noise], dim=1)
        # Size comments below are for an example 512x512 input.
        encoded1 = self.encoder_block1(x)  # 512x512
        encoded2 = self.encoder_block2(self.pool(encoded1))  # 256x256
        encoded3 = self.encoder_block3(self.pool(encoded2))  # 128x128
        encoded4 = self.encoder_block4(self.pool(encoded3))  # 64x64
        embedding = self.mid_block(self.pool(encoded4))  # 32x32
        # Skip connections: encoder output first, upsampled decoder output second.
        decoded1 = self.decoder_block1(torch.cat([encoded4, self.upsampling_block1(embedding)], dim=1))
        decoded2 = self.decoder_block2(torch.cat([encoded3, self.upsampling_block2(decoded1)], dim=1))
        decoded3 = self.decoder_block3(torch.cat([encoded2, self.upsampling_block3(decoded2)], dim=1))
        out = self.final_block(torch.cat([encoded1, self.upsampling_block4(decoded3)], dim=1))
        return out

    def set_task(self, task):
        """Attach ``task`` and rebuild the output head if the class count changed.

        The head is only recreated when it does not exist yet or when
        ``len(task.class_names)`` differs from the current ``num_classes``;
        otherwise the existing ``final_block`` is kept as-is.

        Args:
            task: a ``thelper.tasks.Segmentation`` instance.

        Raises:
            AssertionError: if ``task`` is not a segmentation task.
        """
        assert isinstance(task, thelper.tasks.Segmentation), "missing impl for non-segm task type"
        if self.final_block is None or self.num_classes != len(task.class_names):
            self.num_classes = len(task.class_names)
            # NOTE(review): a rebuilt head is created with torch's default
            # device/dtype; if the model was already moved (e.g. to GPU) the
            # caller may need to move it again -- confirm against callers.
            self.final_block = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=self.mid_channels // 8,
                                out_channels=self.mid_channels // 16,
                                kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(in_channels=self.mid_channels // 16,
                                out_channels=self.num_classes,
                                kernel_size=1),
            )
        self.task = task