import torch
class DenseBlock(torch.nn.Module):
    """Fully connected layer with optional normalization and activation.

    Computes ``act(bn(fc(x)))``. The ``bn`` step is skipped when ``norm`` is
    ``None``, and the ``act`` step is skipped when ``activation`` is ``None``.

    Args:
        input_size: ``in_features`` of the ``torch.nn.Linear`` layer.
        output_size: ``out_features`` of the linear layer, also used as the
            feature count of the normalization layer.
        bias: Whether the linear layer has a learnable bias.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'`` (negative slope 0.2),
            ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'`` (``BatchNorm1d``), ``'instance'``
            (``InstanceNorm1d``) or ``None``.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    def __init__(self, input_size, output_size, bias=True, activation='relu', norm='batch'):
        super(DenseBlock, self).__init__()
        self.fc = torch.nn.Linear(input_size, output_size, bias=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm1d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm1d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # Reject unknown values here rather than leaving ``self.bn``
            # undefined and failing later inside forward().
            raise ValueError('Bad normalization selection')

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)  # in-place
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)  # in-place
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Apply linear -> [norm] -> [activation] to ``x``."""
        out = self.fc(x)
        if self.norm is not None:
            out = self.bn(out)
        if self.activation is not None:
            out = self.act(out)
        return out
class ConvBlock(torch.nn.Module):
    """2-D convolution with optional normalization and activation.

    ``forward`` applies conv -> [norm] -> [activation] (see
    :meth:`forward_bn_act`). :meth:`forward_act_bn` provides the alternative
    conv -> [activation] -> [norm] ordering.

    Args:
        input_size: Input channels of the ``torch.nn.Conv2d``.
        output_size: Output channels of the convolution. This is also the
            channel count seen by the normalization and activation layers.
        kernel_size, stride, padding, bias, groups: Forwarded to ``Conv2d``.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'`` (negative slope 0.2),
            ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'`` (``BatchNorm2d``), ``'instance'``
            (``InstanceNorm2d``) or ``None``.
        prelu_params: Only used when ``activation == 'prelu'``. ``1`` gives a
            single shared slope. Any other value gives one slope per output
            channel; the specific value is otherwise ignored.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1,
                 bias=True, activation='relu', norm='batch', groups=1, prelu_params=1):
        super(ConvBlock, self).__init__()
        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size,
                                    stride, padding, bias=bias, groups=groups)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            raise ValueError('Bad normalization selection')

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)  # in-place
        elif self.activation == 'prelu':
            # The activation always sees the convolution *output*, so a
            # per-channel slope needs one parameter per output channel.
            # Sizing it with input_size crashed whenever input_size != output_size.
            if prelu_params != 1:
                prelu_params = output_size
            self.act = torch.nn.PReLU(num_parameters=prelu_params)
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)  # in-place
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Default forward pass: conv -> [norm] -> [activation]."""
        # A real method (instead of assigning ``self.forward`` in __init__)
        # avoids a self-referencing bound method in the instance __dict__.
        return self.forward_bn_act(x)

    def forward_bn_act(self, x):
        """Apply conv, then normalization (if any), then activation (if any)."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x

    def forward_act_bn(self, x):
        """Apply conv, then activation (if any), then normalization (if any)."""
        x = self.conv(x)
        if self.act is not None:
            x = self.act(x)
        if self.bn is not None:
            x = self.bn(x)
        return x
class DeconvBlock(torch.nn.Module):
    """Transposed 2-D convolution with optional normalization and activation.

    Computes ``act(bn(deconv(x)))``, skipping the optional steps when
    ``norm`` / ``activation`` is ``None``. With the default
    ``kernel_size=4, stride=2, padding=1`` the spatial size is doubled.

    Args:
        input_size: Input channels of the ``torch.nn.ConvTranspose2d``.
        output_size: Output channels, also used by the normalization layer.
        kernel_size, stride, padding, bias: Forwarded to ``ConvTranspose2d``.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'`` (negative slope 0.2),
            ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'`` (``BatchNorm2d``), ``'instance'``
            (``InstanceNorm2d``) or ``None``.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1,
                 bias=True, activation='relu', norm='batch'):
        super(DeconvBlock, self).__init__()
        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size,
                                               stride, padding, bias=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError('Bad normalization selection')

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)  # in-place
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)  # in-place
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Apply transposed conv -> [norm] -> [activation] to ``x``."""
        out = self.deconv(x)
        if self.norm is not None:
            out = self.bn(out)
        if self.activation is not None:
            out = self.act(out)
        return out
class ResNetBlock(torch.nn.Module):
    """Residual block: ``x + [norm](conv2([act]([norm](conv1(x)))))``.

    No activation is applied after the residual addition. The addition
    requires the convolutions to preserve the input shape, which holds for
    the defaults (``kernel_size=3, stride=1, padding=1``).

    NOTE(review): a single ``self.bn`` module is applied after *both*
    convolutions, so they share affine parameters and running statistics.
    It is kept as-is because splitting it would change the ``state_dict``
    layout of existing checkpoints.

    Args:
        num_filter: Input and output channels of both convolutions.
        kernel_size, stride, padding, bias: Forwarded to both ``Conv2d`` layers.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'`` (negative slope 0.2),
            ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'`` (``BatchNorm2d``), ``'instance'``
            (``InstanceNorm2d``) or ``None``.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    def __init__(self, num_filter, kernel_size=3, stride=1, padding=1,
                 bias=True, activation='relu', norm='batch'):
        super(ResNetBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size,
                                     stride, padding, bias=bias)
        self.conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size,
                                     stride, padding, bias=bias)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(num_filter)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(num_filter)
        elif self.norm is None:
            self.bn = None
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError('Bad normalization selection')

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)  # in-place
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)  # in-place
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Return ``x`` plus the conv/norm/act/conv/norm residual branch."""
        residual = x

        out = self.conv1(x)
        if self.norm is not None:
            out = self.bn(out)
        if self.activation is not None:
            out = self.act(out)

        out = self.conv2(out)
        if self.norm is not None:
            out = self.bn(out)

        return torch.add(out, residual)
class PSBlock(torch.nn.Module):
    """Sub-pixel (pixel-shuffle) upsampling block.

    Convolves to ``output_size * scale_factor**2`` channels, then applies
    ``PixelShuffle(scale_factor)``. The result has ``output_size`` channels at
    ``scale_factor`` times the spatial resolution, optionally followed by
    normalization and activation.

    Args:
        input_size: Input channels of the convolution.
        output_size: Channels after the pixel shuffle, also used by the
            normalization layer.
        scale_factor: Upscaling factor passed to ``torch.nn.PixelShuffle``.
        kernel_size, stride, padding, bias: Forwarded to ``Conv2d``.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'`` (negative slope 0.2),
            ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'`` (``BatchNorm2d``), ``'instance'``
            (``InstanceNorm2d``) or ``None``.

    Raises:
        ValueError: If ``norm`` or ``activation`` is not one of the values
            listed above.
    """

    def __init__(self, input_size, output_size, scale_factor, kernel_size=3,
                 stride=1, padding=1, bias=True, activation='relu', norm='batch'):
        super(PSBlock, self).__init__()
        self.conv = torch.nn.Conv2d(input_size, output_size * scale_factor**2,
                                    kernel_size, stride, padding, bias=bias)
        self.ps = torch.nn.PixelShuffle(scale_factor)

        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError('Bad normalization selection')

        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)  # in-place
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)  # in-place
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Apply conv -> pixel shuffle -> [norm] -> [activation] to ``x``."""
        out = self.ps(self.conv(x))
        if self.norm is not None:
            out = self.bn(out)
        if self.activation is not None:
            out = self.act(out)
        return out
class Upsample2xBlock(torch.nn.Module):
    """Upsample feature maps by a factor of 2 using a selectable strategy.

    Args:
        input_size: Number of input channels.
        output_size: Number of output channels.
        bias: Forwarded to the convolution of the chosen block.
        upsample: ``'deconv'`` uses a transposed convolution
            (:class:`DeconvBlock`, kernel 4, stride 2, padding 1). ``'ps'``
            uses sub-pixel convolution (:class:`PSBlock`). ``'rnc'`` uses a
            nearest-neighbour resize followed by a 3x3 stride-1
            :class:`ConvBlock`.
        activation: Activation name forwarded to the chosen block.
        norm: Normalization name forwarded to the chosen block.

    Raises:
        ValueError: If ``upsample`` is not one of the strategies above.
    """

    def __init__(self, input_size, output_size, bias=True, upsample='deconv', activation='relu', norm='batch'):
        super(Upsample2xBlock, self).__init__()
        scale_factor = 2

        if upsample == 'deconv':
            # 1. Deconvolution (transposed convolution)
            self.upsample = DeconvBlock(input_size, output_size,
                                        kernel_size=4, stride=2, padding=1,
                                        bias=bias, activation=activation, norm=norm)
        elif upsample == 'ps':
            # 2. Sub-pixel convolution (pixel shuffler)
            self.upsample = PSBlock(input_size, output_size, scale_factor=scale_factor,
                                    bias=bias, activation=activation, norm=norm)
        elif upsample == 'rnc':
            # 3. Resize, then convolve
            self.upsample = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=scale_factor, mode='nearest'),
                ConvBlock(input_size, output_size,
                          kernel_size=3, stride=1, padding=1,
                          bias=bias, activation=activation, norm=norm)
            )
        else:
            # Otherwise self.upsample would be missing and forward() would
            # fail with an unhelpful AttributeError.
            raise ValueError('Bad upsample selection')

    def forward(self, x):
        """Upsample ``x`` with the strategy chosen at construction."""
        return self.upsample(x)
def weights_init_kaiming(m):
    """Initialize one module with Kaiming (He) normal weights.

    Intended for use as ``model.apply(weights_init_kaiming)``. Modules are
    matched by substrings of their class *name*:

    * ``Linear``, ``Conv2d``, ``ConvTranspose2d``: the weight is drawn with
      ``torch.nn.init.kaiming_normal_`` (default arguments) and the bias is
      zeroed.
    * ``Norm`` (e.g. ``BatchNorm2d``): the weight is drawn from N(1.0, 0.02)
      and the bias is zeroed.

    Parameters that are ``None`` are skipped, e.g. ``InstanceNorm2d`` without
    ``affine=True`` or layers built with ``bias=False``. Any other module is
    left unchanged.

    Args:
        m: The module to initialize in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if any(key in classname for key in ('Linear', 'Conv2d', 'ConvTranspose2d')):
        if weight is not None:
            # kaiming_normal (no underscore) is a deprecated alias of this call.
            torch.nn.init.kaiming_normal_(weight)
        if bias is not None:
            torch.nn.init.zeros_(bias)
    elif 'Norm' in classname:
        # Non-affine norm layers have weight=None; touching .data crashed here.
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)
def weights_init_xavier(m):
    """Initialize one module with Xavier (Glorot) uniform weights.

    Intended for use as ``model.apply(weights_init_xavier)``. Modules are
    matched by substrings of their class *name*:

    * ``Linear``, ``Conv2d``, ``ConvTranspose2d``: the weight is drawn with
      ``torch.nn.init.xavier_uniform_`` (default gain) and the bias is zeroed.
    * ``Norm`` (e.g. ``BatchNorm2d``): the weight is drawn from N(1.0, 0.02)
      and the bias is zeroed.

    Parameters that are ``None`` are skipped, e.g. ``InstanceNorm2d`` without
    ``affine=True`` or layers built with ``bias=False``. Any other module is
    left unchanged.

    Args:
        m: The module to initialize in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if any(key in classname for key in ('Linear', 'Conv2d', 'ConvTranspose2d')):
        if weight is not None:
            # xavier_uniform (no underscore) is a deprecated alias of this call.
            torch.nn.init.xavier_uniform_(weight)
        if bias is not None:
            torch.nn.init.zeros_(bias)
    elif 'Norm' in classname:
        # Non-affine norm layers have weight=None; touching .data crashed here.
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)
def shave(imgs, border_size=0):
    """Crop ``border_size`` pixels from both ends of the two spatial axes.

    * 4-D input (presumably ``(N, C, H, W)``; confirm with callers): returns a
      new float32 CPU tensor equal to ``imgs[:, :, b:H-b, b:W-b]``. The result
      never shares memory with ``imgs``.
    * Any other rank: returns ``imgs[:, b:D1-b, b:D2-b]``, i.e. axes 1 and 2
      are cropped (so a 3-D input is presumably ``(C, H, W)``). For tensors
      this is a view of ``imgs``.

    ``border_size=0`` returns the input uncropped; in the 4-D case it is still
    converted to a float32 CPU copy.

    Args:
        imgs: A tensor with at least three dimensions.
        border_size: Number of pixels to remove from each side.

    Returns:
        The cropped tensor, as described above.
    """
    border = border_size
    # Explicit end indices: with border == 0 the old ``b:-b`` slice was
    # ``0:-0``, which is empty.
    if len(imgs.shape) == 4:
        height, width = imgs.shape[2], imgs.shape[3]
        cropped = imgs[:, :, border:height - border, border:width - border]
        # Keep the historical result type: a fresh float32 tensor on the CPU.
        return cropped.to(device='cpu', dtype=torch.float32, copy=True)
    height, width = imgs.shape[1], imgs.shape[2]
    return imgs[:, border:height - border, border:width - border]