Source code for thelper.nn.coordconv

import collections

import torch
import torch.nn


def get_coords_map(height, width, centered=True, normalized=True, noise=None, dtype=torch.float32):
    """Returns a HxW intrinsic coordinates map tensor (shape=2xHxW).

    Channel 0 of the result holds the x (column) coordinate of every pixel and
    channel 1 holds the y (row) coordinate.

    Args:
        height: number of rows (H) of the map.
        width: number of columns (W) of the map.
        centered: if True, the integer offsets ``(width - 1) // 2`` and
            ``(height - 1) // 2`` are subtracted from the x and y coordinates.
        normalized: if True, coordinates are divided by the axis extent
            (``width - 1`` / ``height - 1``). An axis with a single element has
            an extent of zero; it is left undivided so the result stays finite.
        noise: optional standard deviation (non-negative ``float``) of gaussian
            noise added to every coordinate; ``None`` disables noise.
        dtype: dtype of the generated tensor.

    Returns:
        A tensor of shape ``(2, height, width)`` stacking the x and y maps.

    Raises:
        AssertionError: if ``noise`` is given but is not a non-negative float.
    """
    x = torch.arange(width, dtype=dtype).unsqueeze(0)
    y = torch.arange(height, dtype=dtype).unsqueeze(0)
    if centered:
        x -= (width - 1) // 2
        y -= (height - 1) // 2
    if normalized:
        # clamp the divisor: for a 1-pixel axis, `size - 1` is zero and the
        # division would turn the (zero) coordinate into NaN
        x /= max(width - 1, 1)
        y /= max(height - 1, 1)
    # tile the 1xW row over all rows, and the Hx1 column over all columns
    x = x.repeat(height, 1)
    y = y.t().repeat(1, width)
    if noise is not None:
        assert isinstance(noise, float) and noise >= 0, "invalid noise stddev value"
        x = torch.normal(mean=x, std=noise)
        y = torch.normal(mean=y, std=noise)
    return torch.stack([x, y])
class AddCoords(torch.nn.Module):
    """Creates a torch-compatible layer that adds intrinsic coordinate layers to input tensors.

    The layer concatenates an x-coordinate channel and a y-coordinate channel
    (and optionally a radius channel) to the channel dimension of its input.

    Args:
        centered: forwarded to ``get_coords_map``; centers the coordinates.
        normalized: forwarded to ``get_coords_map``; normalizes the coordinates.
        noise: forwarded to ``get_coords_map``; stddev of the coordinate noise, or None.
        radius_channel: if True, a third channel holding the euclidean distance of
            each pixel to the middle pixel of the map is appended.
        scale: optional multiplier applied to the coordinate maps (before the
            radius channel is computed); ``None`` disables scaling.
    """

    def __init__(self, centered=True, normalized=True, noise=None, radius_channel=False, scale=None):
        super().__init__()
        self.centered = centered
        self.normalized = normalized
        self.noise = noise
        self.radius_channel = radius_channel
        # keep the user-provided scale (it used to be discarded, which silently
        # disabled the scaling step of `forward`)
        self.scale = scale

    def forward(self, in_tensor):
        """Returns ``in_tensor`` with 2 (or 3, with the radius) extra coordinate channels.

        Args:
            in_tensor: a 4-dim tensor; its shape is unpacked as (batch, channels, height, width).

        Returns:
            The concatenation, along dim 1, of ``in_tensor`` and the coordinate
            maps, located on the device of ``in_tensor``.
        """
        batch_size, _, height, width = in_tensor.shape
        coords_map = get_coords_map(height, width, self.centered, self.normalized, self.noise)
        if self.scale is not None:
            coords_map *= self.scale
        if self.radius_channel:
            # the radius is measured from the coordinates of the middle pixel
            middle_slice = coords_map[:, (height - 1) // 2, (width - 1) // 2]
            radius = torch.sqrt(torch.pow(coords_map[0, :, :] - middle_slice[0], 2) +
                                torch.pow(coords_map[1, :, :] - middle_slice[1], 2))
            coords_map = torch.cat([coords_map, radius.unsqueeze(0)], dim=0)
        # replicate the (C x H x W) maps for every sample of the batch
        coords_map = coords_map.repeat(batch_size, 1, 1, 1)
        dev = in_tensor.device
        out = torch.cat([in_tensor, coords_map.to(dev)], dim=1)
        return out


class CoordConv2d(torch.nn.Module):
    """CoordConv-equivalent of torch's default Conv2d model layer.

    The input first goes through an :class:`AddCoords` layer, and the result is
    given to a regular ``torch.nn.Conv2d`` whose input channel count is increased
    by 2 (or by 3 when ``radius_channel`` is True). Remaining positional and
    keyword arguments are forwarded to ``torch.nn.Conv2d``.

    .. seealso::
        | Liu et al., `An Intriguing Failing of Convolutional Neural Networks and the CoordConv
          Solution <https://arxiv.org/abs/1807.03247>`_ [arXiv], 2018.
    """

    def __init__(self, in_channels, *args, centered=True, normalized=True,
                 noise=None, radius_channel=False, scale=None, **kwargs):
        super().__init__()
        self.addcoord = AddCoords(centered=centered, normalized=normalized, noise=noise,
                                  radius_channel=radius_channel, scale=scale)
        # x and y channels, plus the optional radius channel
        extra_ch = 3 if radius_channel else 2
        self.conv = torch.nn.Conv2d(in_channels + extra_ch, *args, **kwargs)

    def forward(self, in_tensor):
        """Adds the coordinate channels to ``in_tensor`` and convolves the result."""
        out = self.addcoord(in_tensor)
        out = self.conv(out)
        return out


class CoordConvTranspose2d(torch.nn.Module):
    """CoordConv-equivalent of torch's default ConvTranspose2d model layer.

    The input first goes through an :class:`AddCoords` layer, and the result is
    given to a regular ``torch.nn.ConvTranspose2d`` whose input channel count is
    increased by 2 (or by 3 when ``radius_channel`` is True). Remaining positional
    and keyword arguments are forwarded to ``torch.nn.ConvTranspose2d``.

    .. seealso::
        | Liu et al., `An Intriguing Failing of Convolutional Neural Networks and the CoordConv
          Solution <https://arxiv.org/abs/1807.03247>`_ [arXiv], 2018.
    """

    def __init__(self, in_channels, *args, centered=True, normalized=True,
                 noise=None, radius_channel=False, scale=None, **kwargs):
        super().__init__()
        self.addcoord = AddCoords(centered=centered, normalized=normalized, noise=noise,
                                  radius_channel=radius_channel, scale=scale)
        # x and y channels, plus the optional radius channel
        extra_ch = 3 if radius_channel else 2
        self.conv = torch.nn.ConvTranspose2d(in_channels + extra_ch, *args, **kwargs)

    def forward(self, in_tensor):
        """Adds the coordinate channels to ``in_tensor`` and applies the transposed convolution."""
        out = self.addcoord(in_tensor)
        out = self.conv(out)
        return out
def swap_coordconv_layers(module, centered=True, normalized=True, noise=None, radius_channel=False, scale=None):
    """Modifies the provided module by swapping Conv2d layers for CoordConv-equivalent layers.

    ``Conv2d``/``ConvTranspose2d`` layers are replaced by newly constructed
    ``CoordConv2d``/``CoordConvTranspose2d`` layers built from the hyperparameters
    of the original layer only (the weights of the original layer are not
    transferred). Containers are traversed recursively and updated in place.

    Args:
        module: a ``torch.nn.Module``, or a ``collections.OrderedDict`` whose values
            are modules. Any other object is returned untouched.
        centered: forwarded to the created CoordConv layers.
        normalized: forwarded to the created CoordConv layers.
        noise: forwarded to the created CoordConv layers.
        radius_channel: forwarded to the created CoordConv layers.
        scale: forwarded to the created CoordConv layers.

    Returns:
        The replacement layer when ``module`` itself is a convolution layer;
        otherwise ``module`` itself, with its nested convolution layers swapped.
    """
    # note: this is a pretty 'dumb' way to add coord maps to a model, as it will add them everywhere, even
    # in a potential output (1x1) conv layer; manually designing the network with these would be more adequate!
    # NOTE(review): grouped convolutions are rebuilt with `in_channels + 2/3` input channels, which torch
    # presumably rejects when that count is not divisible by `groups` -- confirm before swapping such models
    coordconv_params = {
        "centered": centered,
        "normalized": normalized,
        "noise": noise,
        "radius_channel": radius_channel,
        "scale": scale,
    }
    if isinstance(module, (CoordConv2d, CoordConvTranspose2d)):
        # already a CoordConv layer: do not wrap its inner convolution a second time
        return module
    if isinstance(module, torch.nn.Conv2d):
        return CoordConv2d(in_channels=module.in_channels, out_channels=module.out_channels,
                           kernel_size=module.kernel_size, stride=module.stride, padding=module.padding,
                           dilation=module.dilation, groups=module.groups, bias=module.bias is not None,
                           padding_mode=module.padding_mode, **coordconv_params)
    if isinstance(module, torch.nn.ConvTranspose2d):
        return CoordConvTranspose2d(in_channels=module.in_channels, out_channels=module.out_channels,
                                    kernel_size=module.kernel_size, stride=module.stride,
                                    padding=module.padding, output_padding=module.output_padding,
                                    groups=module.groups, bias=module.bias is not None,
                                    dilation=module.dilation, padding_mode=module.padding_mode,
                                    **coordconv_params)
    if isinstance(module, collections.OrderedDict):
        # only values of existing keys are reassigned, which is safe while iterating
        for mname, m in module.items():
            module[mname] = swap_coordconv_layers(m, **coordconv_params)
        return module
    if isinstance(module, torch.nn.Module):
        # registered submodules: rely on the public API instead of the type of the private
        # `_modules` storage; this also keeps the names (and class) of Sequential containers
        for child_name, child in list(module.named_children()):
            setattr(module, child_name, swap_coordconv_layers(child, **coordconv_params))
        # ordered dicts of layers stored as plain attributes are also traversed
        for attrib, value in list(vars(module).items()):
            if attrib != "_modules" and isinstance(value, collections.OrderedDict):
                swap_coordconv_layers(value, **coordconv_params)
    return module
def make_conv2d(*args, coordconv=False, centered=True, normalized=True, noise=None, radius_channel=False, scale=None, **kwargs):
    """Creates a 2D convolution layer with optional CoordConv support.

    Args:
        *args: positional arguments forwarded to the created layer.
        coordconv: if True, a ``CoordConv2d`` layer is created; otherwise a
            regular ``torch.nn.Conv2d`` layer is created.
        centered: CoordConv-only option, ignored when ``coordconv`` is False.
        normalized: CoordConv-only option, ignored when ``coordconv`` is False.
        noise: CoordConv-only option, ignored when ``coordconv`` is False.
        radius_channel: CoordConv-only option, ignored when ``coordconv`` is False.
        scale: CoordConv-only option, ignored when ``coordconv`` is False.
        **kwargs: keyword arguments forwarded to the created layer.

    Returns:
        The newly created convolution layer.
    """
    if not coordconv:
        # plain convolution: none of the coordinate-related options apply
        return torch.nn.Conv2d(*args, **kwargs)
    coord_options = {
        "centered": centered,
        "normalized": normalized,
        "noise": noise,
        "radius_channel": radius_channel,
        "scale": scale,
    }
    return CoordConv2d(*args, **coord_options, **kwargs)