import functools
import logging
import typing
import cv2 as cv
import kornia
import numpy as np
import torch
import torch.optim
import thelper.concepts
import thelper.typedefs as typ # noqa: F401
import thelper.utils
from thelper.train.base import Trainer
# Module-level logger, named after this module; AutoEncoderTrainer methods log through ``self.logger``.
logger = logging.getLogger(name=__name__)
@thelper.concepts.classification
@thelper.concepts.segmentation
class AutoEncoderTrainer(Trainer):
    """Trainer for models that jointly predict task outputs and reconstruct their input.

    The model is called on each input batch and must return a ``(class_logits, reconstr)`` pair.
    The optimized loss is the task loss (built from the ``classif_loss`` trainer config) plus
    ``reconstr_scale`` times a reconstruction loss: an L2 (MSE) loss between the reconstruction
    and the input, optionally augmented with an L1 loss between their spatial gradients when
    ``reconstr_edges`` is enabled. During evaluation, input/reconstruction pairs are also
    rendered as images and written out when ``self.use_tbx`` is set.
    """

    def __init__(self,
                 session_name,    # type: typing.AnyStr
                 session_dir,     # type: typing.AnyStr
                 model,           # type: thelper.typedefs.ModelType
                 task,            # type: thelper.tasks.Task
                 loaders,         # type: thelper.typedefs.MultiLoaderType
                 config,          # type: thelper.typedefs.ConfigDict
                 ckptdata=None    # type: typing.Optional[thelper.typedefs.CheckpointContentType]
                 ):
        """Receives session parameters, parses trainer-specific settings, and sets up losses.

        The following keys are read from ``config["trainer"]``:
          - ``reconstr_display_count``: number of input/reconstruction pairs displayed per eval epoch.
          - ``reconstr_display_mean`` / ``reconstr_display_stddev``: normalization values used to
            de-normalize images for display (a sequence/number, or a string evaluated as Python).
          - ``reconstr_scale``: weight of the reconstruction loss relative to the task loss.
          - ``reconstr_edges``: if truthy, an L1 loss on spatial image gradients is added to the
            reconstruction loss.
          - ``classif_loss``: config passed to ``thelper.optim.utils.create_loss_fn`` for the task loss.
        """
        super().__init__(session_name, session_dir, model, task, loaders, config, ckptdata=ckptdata)
        assert isinstance(self.task, (thelper.tasks.Classification, thelper.tasks.Segmentation)), \
            "expected task to be classification/segmentation only"
        trainer_config = config["trainer"]
        self.warned_no_shuffling_augments = False
        self.reconstr_display_count = thelper.utils.get_key("reconstr_display_count", trainer_config)
        self.reconstr_display_mean = thelper.utils.get_key("reconstr_display_mean", trainer_config)
        self.reconstr_display_stddev = thelper.utils.get_key("reconstr_display_stddev", trainer_config)
        self.reconstr_scale = thelper.utils.get_key("reconstr_scale", trainer_config)
        # the edge layer is None when gradient-based reconstruction loss is disabled
        self.reconstr_edges_layer = None
        if thelper.utils.get_key("reconstr_edges", trainer_config):
            self.reconstr_edges_layer = kornia.filters.SpatialGradient()
        self.reconstr_l2_loss = torch.nn.MSELoss()
        self.reconstr_l1_loss = torch.nn.L1Loss()
        classif_loss_config = thelper.utils.get_key("classif_loss", trainer_config)
        uploader = functools.partial(self._move_tensor, dev=self.devices)
        self.classif_loss = thelper.optim.utils.create_loss_fn(classif_loss_config, model, uploader=uploader)

    def _to_tensor(self, sample):
        """Fetches and returns tensors of input images and class labels from a batched sample dictionary.

        Returns:
            An ``(input_val, target_val)`` pair, where ``target_val`` is None if the sample holds no
            groundtruth under the task's gt key.
        """
        assert isinstance(sample, dict), "trainer expects samples to come in dicts for key-based usage"
        assert self.task.input_key in sample, f"could not find input key '{self.task.input_key}' in sample dict"
        input_val, target_val = sample[self.task.input_key].float(), None
        if self.task.gt_key not in sample or sample[self.task.gt_key] is None:
            return input_val, target_val
        gt_tensor = sample[self.task.gt_key]
        assert len(gt_tensor) == len(input_val), \
            "target tensor should be an array of the same length as input (== batch size)"
        if isinstance(gt_tensor, torch.Tensor) and gt_tensor.dtype == torch.int64:
            target_val = gt_tensor  # shortcut with less checks (dataset is already using tensor'd indices)
        elif isinstance(self.task, thelper.tasks.Classification):
            if self.task.multi_label:
                assert isinstance(gt_tensor, torch.Tensor) and \
                    gt_tensor.shape == (len(input_val), len(self.task.class_names)), \
                    "gt tensor for multi-label classification should be 2d array (batch size x nbclasses)"
                target_val = gt_tensor.float()
            else:
                target_val = []
                for class_name in gt_tensor:
                    assert isinstance(class_name, (int, np.integer, torch.Tensor, str)), \
                        "expected gt tensor to be an array of names (string) or indices (int)"
                    if isinstance(class_name, str):
                        assert class_name in self.task.class_names, \
                            "got unexpected label '%s' for a sample (unknown class)" % class_name
                        target_val.append(self.task.class_indices[class_name])
                        continue
                    if isinstance(class_name, torch.Tensor):
                        assert torch.numel(class_name) == 1, "unexpected scalar label, got vector"
                        class_name = class_name.item()
                    elif isinstance(class_name, np.integer):
                        class_name = int(class_name)
                    # dataset must already be using indices, we will forgive this...
                    assert 0 <= class_name < len(self.task.class_names), \
                        "class name given as out-of-range index (%d) for class list" % class_name
                    target_val.append(class_name)
                target_val = torch.LongTensor(target_val)
        elif isinstance(self.task, thelper.tasks.Segmentation):
            assert not isinstance(gt_tensor, list), "unexpected label map type"
            if gt_tensor.ndim == 4:
                assert gt_tensor.shape[1] == 1, "unexpected channel count (should be index map)"
                gt_tensor = gt_tensor.squeeze(1)
            target_val = gt_tensor.long()  # long instead of bytes to support large/negative values for dontcare
        return input_val, target_val

    def _compute_ae_loss(self, class_logits, reconstr, input_val_dev, target_val_dev):
        """Returns the total loss tensor for a batch: task loss + ``reconstr_scale`` * reconstruction loss.

        The task loss term is skipped when ``target_val_dev`` is None (no groundtruth available).
        """
        reconstr_loss = self.reconstr_l2_loss(reconstr, input_val_dev)
        if self.reconstr_edges_layer is not None:
            # flatten the per-channel gradX/gradY outputs into the channel dimension
            edges_shape = (reconstr.shape[0], reconstr.shape[1] * 2, reconstr.shape[2], reconstr.shape[3])
            reconstr_edges = self.reconstr_edges_layer(reconstr).view(edges_shape)
            input_edges = self.reconstr_edges_layer(input_val_dev).view(edges_shape)
            # out-of-place add: avoids mutating a tensor that is part of the autograd graph
            reconstr_loss = reconstr_loss + self.reconstr_l1_loss(reconstr_edges, input_edges)
        reconstr_term = self.reconstr_scale * reconstr_loss
        if target_val_dev is None:
            return reconstr_term
        return self.classif_loss(class_logits, target_val_dev) + reconstr_term

    @staticmethod
    def _parse_display_norm(value):
        """Converts a display mean/stddev config value into a numpy array usable in arithmetic."""
        if isinstance(value, str):
            # NOTE(review): eval() executes arbitrary code from the session config; only use trusted configs
            value = eval(value)
        # asarray also guards against list values, where '+'/'*' would concatenate/repeat instead of add/scale
        return np.asarray(value)

    @staticmethod
    def _make_reconstr_display(input_img, reconstr_img, display_mean, display_stddev):
        """Returns a displayable image with the input stacked vertically above its reconstruction."""
        clip_min = display_mean - display_stddev * 3
        clip_max = display_mean + display_stddev * 3
        display = []
        for img in (input_img, reconstr_img):
            img = np.transpose(img.cpu().numpy(), (1, 2, 0))  # CxHxW -> HxWxC
            img = img * display_stddev + display_mean  # de-normalize w/ provided vals
            # map the [mean - 3*std, mean + 3*std] range to [0, 1], clipping outliers
            img = np.clip((img - clip_min) / (clip_max - clip_min), 0, 1)
            display.append(thelper.draw.get_displayable_image(img))
        return cv.vconcat(display)

    def train_epoch(self, model, epoch, dev, classif_loss, optimizer, loader, metrics, output_path):
        """Trains the model for a single epoch using the provided objects.

        Args:
            model: the model to train that is already uploaded to the target device(s).
            epoch: the epoch index we are training for (0-based).
            dev: the target device that tensors should be uploaded to.
            classif_loss: must be None; this trainer uses the losses it created in its constructor.
            optimizer: the optimizer used for back propagation.
            loader: the data loader used to get transformed training samples.
            metrics: the dictionary of metrics/consumers to update every iteration.
            output_path: directory where output files should be written, if necessary.

        Returns:
            The mean of the per-iteration losses over the epoch.
        """
        assert classif_loss is None, "loss function defined by trainer"
        assert optimizer is not None, "missing optimizer"
        assert loader, "no available data to load"
        assert isinstance(metrics, dict), "expect metrics as dict object"
        epoch_loss = 0
        epoch_size = len(loader)
        self.logger.debug("fetching data loader samples...")
        for idx, sample in enumerate(loader):
            input_val, target_val = self._to_tensor(sample)
            assert target_val is not None, "groundtruth required when training a model"
            input_val_dev = self._move_tensor(input_val, dev)
            target_val_dev = self._move_tensor(target_val, dev)
            optimizer.zero_grad()
            class_logits, reconstr = model(input_val_dev)
            iter_loss = self._compute_ae_loss(class_logits, reconstr, input_val_dev, target_val_dev)
            iter_loss.backward()
            optimizer.step()
            iter_loss = iter_loss.item()
            for metric in metrics.values():
                metric.update(task=self.task, input=input_val, pred=class_logits,
                              target=target_val, sample=sample, loss=iter_loss, iter_idx=idx,
                              max_iters=epoch_size, epoch_idx=epoch, max_epochs=self.epochs,
                              output_path=output_path)
            epoch_loss += iter_loss
        epoch_loss /= epoch_size
        return epoch_loss

    def eval_epoch(self, model, epoch, dev, loader, metrics, output_path):
        """Evaluates the model using the provided objects.

        When ``self.use_tbx`` is set, up to ``reconstr_display_count`` input/reconstruction pairs are
        collected across batches and written as a single image at the end of the epoch.

        Args:
            model: the model to evaluate that is already uploaded to the target device(s).
            epoch: the epoch index we are training for (0-based).
            dev: the target device that tensors should be uploaded to.
            loader: the data loader used to get transformed valid/test samples.
            metrics: the dictionary of metrics/consumers to update every iteration.
            output_path: directory where output files should be written, if necessary.
        """
        assert loader, "no available data to load"
        assert isinstance(metrics, dict), "expect metrics as dict object"
        epoch_size = len(loader)
        display_array = []
        display_mean = display_stddev = None
        if self.use_tbx:
            # loop-invariant: parse the de-normalization values once per epoch
            display_mean = self._parse_display_norm(self.reconstr_display_mean)
            display_stddev = self._parse_display_norm(self.reconstr_display_stddev)
        with torch.no_grad():
            self.logger.debug("fetching data loader samples...")
            for idx, sample in enumerate(loader):
                if idx < self.skip_eval_iter:
                    continue  # skip until previous iter count (if set externally; no effect otherwise)
                input_val, target_val = self._to_tensor(sample)
                input_val_dev = self._move_tensor(input_val, dev)
                target_val_dev = self._move_tensor(target_val, dev) if target_val is not None else None
                class_logits, reconstr = model(input_val_dev)
                iter_loss = self._compute_ae_loss(class_logits, reconstr, input_val_dev, target_val_dev).item()
                for metric in metrics.values():
                    metric.update(task=self.task, input=input_val, pred=class_logits,
                                  target=target_val, sample=sample, loss=iter_loss, iter_idx=idx,
                                  max_iters=epoch_size, epoch_idx=epoch, max_epochs=self.epochs,
                                  output_path=output_path)
                # collect display pairs across batches until enough are gathered
                # (make sure not to shuffle if you want to get the same images each epoch...)
                if self.use_tbx and len(display_array) < self.reconstr_display_count:
                    for input_img, reconstr_img in zip(input_val_dev, reconstr):
                        if len(display_array) >= self.reconstr_display_count:
                            break
                        display_array.append(self._make_reconstr_display(
                            input_img, reconstr_img, display_mean, display_stddev))
        if self.use_tbx and display_array:
            # NOTE(review): always writes to the "valid" writer/output path, even if this method
            # is called to evaluate another split -- confirm this is intended.
            output = {"reconstr/image": cv.hconcat(display_array)}
            self._write_data(output, "epoch/", f"-{epoch:04d}",
                             self.writers["valid"], self.output_paths["valid"], epoch)