Source code for thelper.transforms.wrappers
"""Transformations wrappers module.
The wrapper classes herein are used to either support inline operations on odd sample types (e.g. lists
of images) or for external libraries (e.g. Augmentor).
"""
import functools
import logging
import random
import numpy as np
import PIL.Image
import torch
import thelper.data
import thelper.utils
# module-level logger, named after this module via ``__name__``
logger = logging.getLogger(__name__)
class AlbumentationsWrapper:
    """Albumentations pipeline wrapper that allows dictionary unpacking.

    See https://github.com/albu/albumentations for more information.

    Attributes:
        pipeline: the albumentations ``Compose`` pipeline instance to apply to images.
        transforms: the transforms that were given to the pipeline.
        bbox_params: the bounding box parameters given to the pipeline (defaults to the 'coco' format).
        add_targets: the additional targets given to the pipeline.
        probability: the probability given to the pipeline (``p`` argument of ``Compose``).
        image_key: the key to fetch images from (when dictionaries are passed in).
        bboxes_key: the key to fetch bounding boxes from (when dictionaries are passed in).
        mask_key: the key to fetch masks from (when dictionaries are passed in).
        keypoints_key: the key to fetch keypoints from (when dictionaries are passed in).
        cvt_kpts_to_bboxes: specifies whether keypoints should be converted to bboxes for compatibility.
        linked_fate: specifies whether input list samples should all have the same fate or not.

    .. seealso::
        | :func:`thelper.transforms.utils.load_transforms`
    """

    def __init__(self, transforms, bbox_params=None, add_targets=None, image_key="image",
                 bboxes_key="bboxes", mask_key="mask", keypoints_key="keypoints", probability=1.0,
                 cvt_kpts_to_bboxes=False, linked_fate=False):
        """Receives albumentations transforms and builds the pipeline that will apply them.

        The transforms themselves are instantiated in :func:`thelper.transforms.utils.load_transforms`.

        Args:
            transforms: the list of albumentations transforms to compose.
            bbox_params: the bounding box parameters dictionary (defaults to ``{"format": "coco"}``).
            add_targets: the additional targets dictionary given to the pipeline (defaults to empty).
            image_key: the key (or single-element list of keys) to fetch images from in sample dicts.
            bboxes_key: the key to fetch bounding boxes from in sample dicts.
            mask_key: the key to fetch masks from in sample dicts.
            keypoints_key: the key to fetch keypoints from in sample dicts.
            probability: the probability of applying the whole pipeline.
            cvt_kpts_to_bboxes: whether keypoints should be converted to (fake) bboxes for compatibility.
            linked_fate: specifies whether input list samples should all have the same fate or not.

        Raises:
            AssertionError: if albumentations is not installed, if the image key list does not hold exactly
                one key, if the bboxes/keypoints/mask keys are lists, or if keypoints must be converted to
                bboxes while the bbox format is not 'coco'.
        """
        assert thelper.utils.check_installed("albumentations"), \
            "could not import optional 3rd-party dependency 'albumentations'; make sure you install it first!"
        if not bbox_params:
            bbox_params = {"format": "coco"}  # i.e. opencv format (X,Y,W,H)
        if add_targets is None:
            add_targets = {}
        if isinstance(image_key, (list, tuple)):
            assert len(image_key) > 0, "image key list should not be empty"
            assert len(image_key) <= 1, "current implementation cannot handle more than one input image key per packet"
            image_key = image_key[0]
        self.image_key = image_key
        assert not (isinstance(bboxes_key, (list, tuple)) or
                    isinstance(keypoints_key, (list, tuple)) or
                    isinstance(mask_key, (list, tuple))), \
            "bboxes/keypoints/masks keys should never be passed as lists"
        self.bboxes_key = bboxes_key
        self.mask_key = mask_key
        self.keypoints_key = keypoints_key
        self.cvt_kpts_to_bboxes = cvt_kpts_to_bboxes
        # the coco (x,y,w,h) format is only required when keypoints are converted into fake bboxes; other
        # formats (e.g. pascal_voc, which __call__ supports for BoundingBox objects) must remain usable
        assert not cvt_kpts_to_bboxes or bbox_params.get("format") == "coco", \
            "if converting kpts to bboxes, must use coco format"
        self.bbox_params = bbox_params
        self.linked_fate = linked_fate
        import albumentations
        self.transforms = transforms
        self.add_targets = add_targets
        self.probability = probability
        self.pipeline = albumentations.Compose(transforms, bbox_params=self.bbox_params,
                                               additional_targets=add_targets, p=probability)

    def __call__(self, sample, force_linked_fate=False, op_seed=None):
        """Transforms a (dict) sample or a single image using the albumentations pipeline.

        Args:
            sample: the sample dictionary or the image (numpy array) to transform; ``None`` is returned as-is.
            force_linked_fate: override flag for recursive use allowing forced linking of arrays (unused for now).
            op_seed: seed to set before calling the wrapped operation (unused for now).

        Returns:
            The transformed sample dictionary (updated in place) or the transformed image.

        Raises:
            NotImplementedError: if a list/tuple of images is given (directly or under the image key).
        """
        # todo: add list unwrapping/interlacing support like in other wrappers?
        if isinstance(sample, dict):
            return self._transform_sample(sample)
        if isinstance(sample, (list, tuple)):
            raise NotImplementedError  # impl should use linked_fate and force_linked_fate
        assert sample is None or isinstance(sample, np.ndarray)
        if sample is None:
            return None
        # the pipeline is always built with bbox params (see constructor) and may require a 'bboxes'
        # target; give it an empty list, exactly like the dict path does when a sample has no boxes
        output = self.pipeline(image=sample, bboxes=[])
        return output["image"]

    def _transform_sample(self, sample):
        """Transforms the image/keypoints/bboxes/mask of a sample dictionary in place and returns it."""
        assert self.image_key in sample, \
            f"image is missing from sample (key={self.image_key}) but it is mandatory"
        image = sample[self.image_key]
        if isinstance(image, (list, tuple)):
            raise NotImplementedError  # impl should use linked_fate and force_linked_fate
        params = {"image": image}
        has_keypoints = sample.get(self.keypoints_key) is not None
        has_bboxes = sample.get(self.bboxes_key) is not None
        bbox_format = self.bbox_params.get("format")
        unpack_bboxes, decode_bboxes = False, False
        if has_keypoints:
            keypoints = sample[self.keypoints_key]
            if self.cvt_kpts_to_bboxes:
                assert self.bboxes_key not in sample, \
                    "trying to override bboxes w/ keypoints while bboxes already exist"
                # fake x,y,w,h,c format (w/ labels), with each keypoint clamped inside the image
                height, width = image.shape[0], image.shape[1]
                params["bboxes"] = [[min(max(kp[0], 0), width - 1),
                                     min(max(kp[1], 0), height - 1), 1, 1, 0] for kp in keypoints]
            else:
                params["keypoints"] = keypoints
        if has_bboxes:
            bboxes = sample[self.bboxes_key]
            if isinstance(bboxes, thelper.data.BoundingBox):
                bboxes = [bboxes]
                unpack_bboxes = True
            if isinstance(bboxes, list) and all(isinstance(bbox, thelper.data.BoundingBox) for bbox in bboxes):
                assert bbox_format in ["coco", "pascal_voc"], "unsupported/unknown bbox format"
                bboxes = [bbox.encode(format=bbox_format) for bbox in bboxes]
                decode_bboxes = True
            params["bboxes"] = bboxes
        elif "bboxes" not in params:
            # the pipeline is built w/ bbox params and gets an empty bbox list, but fake bboxes created
            # from keypoints above must not be overwritten
            params["bboxes"] = []
        if sample.get(self.mask_key) is not None:
            params["mask"] = sample[self.mask_key]
        output = self.pipeline(**params)
        sample[self.image_key] = output["image"]
        if has_keypoints and "keypoints" in output:
            sample[self.keypoints_key] = output["keypoints"]
        if "bboxes" in output:
            if self.cvt_kpts_to_bboxes and has_keypoints:
                sample[self.keypoints_key] = [[bbox[0], bbox[1]] for bbox in output["bboxes"]]
            elif has_bboxes:  # only write boxes back into samples that actually provided some
                bboxes = output["bboxes"]
                if decode_bboxes:
                    bboxes = [thelper.data.BoundingBox.decode(bbox, bbox_format) for bbox in bboxes]
                if unpack_bboxes:
                    # the pipeline may drop boxes; a single dropped box is returned as None instead of crashing
                    bboxes = bboxes[0] if bboxes else None
                sample[self.bboxes_key] = bboxes
        if "mask" in output:
            sample[self.mask_key] = output["mask"]
        return sample

    def __repr__(self):
        """Create a print-friendly representation of inner augmentation stages."""
        # for debug purposes only, transforms probably cannot be expressed as a string
        return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
            f"(transforms={repr(self.transforms)}, " + \
            f"bbox_params={repr(self.bbox_params)}, add_targets={repr(self.add_targets)}, " + \
            f"image_key={repr(self.image_key)}, bboxes_key={repr(self.bboxes_key)}, " + \
            f"mask_key={repr(self.mask_key)}, keypoints_key={repr(self.keypoints_key)}, " + \
            f"probability={repr(self.probability)}, cvt_kpts_to_bboxes={repr(self.cvt_kpts_to_bboxes)}, " + \
            f"linked_fate={repr(self.linked_fate)})"

    # noinspection PyMethodMayBeStatic
    def set_seed(self, seed):
        """Seeds python's and numpy's global RNGs, and forwards the seed to transforms that accept one."""
        random.seed(seed)  # albumentations rolls its application probabilities with python's `random`
        np.random.seed(seed)
        transforms = self.pipeline.transforms
        if transforms is not None:
            for t in transforms:
                if hasattr(t, "set_seed") and callable(t.set_seed):
                    t.set_seed(seed)

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to change the behavior of some suboperations."""
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch value"
        if self.pipeline.transforms is not None:
            for t in self.pipeline.transforms:
                if hasattr(t, "set_epoch") and callable(t.set_epoch):
                    t.set_epoch(epoch)
class AugmentorWrapper:
    """Augmentor pipeline wrapper that allows pickling and multi-threading.

    See https://github.com/mdbloice/Augmentor for more information. This wrapper was last updated to work
    with version 0.2.2 --- more recent versions introduced yet unfixed (as of 2018/08) issues on some platforms.
    All original transforms are supported here. This wrapper also fixes the list output bug for single-image
    samples when using operations individually.

    Attributes:
        pipeline: the augmentor pipeline instance to apply to images.
        target_keys: the sample keys to apply the pipeline to (when dictionaries are passed in).
        linked_fate: specifies whether input list samples should all have the same fate or not.

    .. seealso::
        | :func:`thelper.transforms.utils.load_transforms`
    """

    def __init__(self, pipeline, target_keys=None, linked_fate=True):
        """Receives and stores an augmentor pipeline for later use.

        The pipeline itself is instantiated in :func:`thelper.transforms.utils.load_transforms`.

        Args:
            pipeline: the Augmentor pipeline whose ``operations`` will be applied.
            target_keys: the sample keys to transform; if ``None``, all non-scalar (per numpy) values are used.
            linked_fate: specifies whether images given in a list/tuple should all have the same fate or not.
        """
        self.pipeline = pipeline
        self.target_keys = target_keys
        self.linked_fate = linked_fate

    def __call__(self, sample, force_linked_fate=False, op_seed=None, in_cvts=None):
        """Transforms a (dict) sample, a single image, or a list of images using the augmentor pipeline.

        Args:
            sample: the sample or image(s) to transform (can also contain embedded lists/tuples of images).
            force_linked_fate: override flag for recursive use allowing forced linking of arrays.
            op_seed: seed to set before calling the wrapped operation.
            in_cvts: holds the input conversion flag array (for recursive usage).

        Returns:
            The transformed image(s), with the same list formatting as the input (tuples come back as lists,
            and the input list itself is left untouched). Dictionaries are returned with their targeted values
            replaced. When ``in_cvts`` is given, a ``(sample, cvts)`` pair is returned instead.
        """
        if isinstance(sample, dict):
            # recursive call for unpacking sample content w/ target keys
            assert in_cvts is None, "top-level call should never provide in_cvts"
            # capture non-scalar objects (according to numpy) if no keys are provided
            key_vals = [(k, v) for k, v in sample.items() if (
                (self.target_keys is None and not np.isscalar(v)) or
                (self.target_keys is not None and k in self.target_keys))]
            if not key_vals:
                return sample  # nothing to transform (and zip-unpacking an empty list below would fail)
            keys, vals = map(list, zip(*key_vals))
            lengths = [len(v) if isinstance(v, (list, tuple)) else -1 for v in vals]
            if lengths[0] > 0 and all(n == lengths[0] for n in lengths):
                # interlace input lists for internal linked fate (if needed; otherwise, it won't change anything)
                vals = [[v[idx] if isinstance(v, (list, tuple)) else v[idx, ...] for v in vals]
                        for idx in range(lengths[0])]
                vals = self(vals, force_linked_fate=force_linked_fate, op_seed=op_seed, in_cvts=in_cvts)
                assert isinstance(vals, list) and len(vals) == lengths[0], "messed up something internally"
                # de-interlace the outputs back into one list per targeted key
                out_vals = [[v] for v in vals[0]] if isinstance(vals[0], list) else [[vals[0]]]
                for idx1 in range(1, lengths[0]):
                    for idx2, out_val in enumerate(out_vals):
                        out_val.append(vals[idx1][idx2] if isinstance(vals[idx1], list) else vals[idx1])
                vals = out_vals
            else:
                vals = self(vals, force_linked_fate=force_linked_fate, op_seed=op_seed, in_cvts=in_cvts)
            new_vals = dict(zip(keys, vals))
            return {k: new_vals[k] if k in new_vals else v for k, v in sample.items()}
        out_cvts = in_cvts is not None
        out_list = isinstance(sample, (list, tuple))
        if sample is None or (out_list and not sample):
            return ([], []) if out_cvts else []
        # work on a copy: tuples cannot be updated in place, and the caller's list must not be modified
        sample = list(sample) if out_list else [sample]
        assert not any(isinstance(v, dict) for v in sample), \
            "augmentor wrapper cannot handle sample-in-sample (or dict-in-list) inputs"
        skip_unpack = in_cvts is not None and isinstance(in_cvts, bool) and in_cvts
        if self.linked_fate or force_linked_fate:  # process all content with the same operations below
            if not skip_unpack:
                # noinspection PyProtectedMember
                sample, cvts = TransformWrapper._unpack(sample, convert_pil=True)
                if not isinstance(sample, (list, tuple)):
                    sample, cvts = [sample], [cvts]
            else:
                cvts = in_cvts
            if op_seed is None:
                op_seed = np.random.randint(np.iinfo(np.int32).max)
            np.random.seed(op_seed)
            prev_state = np.random.get_state()
            for idx, item in enumerate(sample):
                if item is None:
                    continue  # keep placeholders as-is (recursing would turn them into empty lists)
                if not isinstance(item, PIL.Image.Image):
                    sample[idx], cvts[idx] = self(item, force_linked_fate=True,
                                                  op_seed=op_seed, in_cvts=cvts[idx])
                else:
                    # restore the same RNG state for every image so that they all get the same dice rolls
                    np.random.set_state(prev_state)
                    sample[idx] = self._apply_operations(item)
        else:  # each element of the top array will be processed independently below (current seeds are kept)
            cvts = [False] * len(sample)
            for idx, item in enumerate(sample):
                if item is None:
                    continue  # keep placeholders as-is (they cannot be converted to PIL images)
                # noinspection PyProtectedMember
                item, cvts[idx] = TransformWrapper._unpack(item, convert_pil=True)
                if not isinstance(item, PIL.Image.Image):
                    item, cvts[idx] = self(item, force_linked_fate=True, op_seed=op_seed, in_cvts=cvts[idx])
                else:
                    item = self._apply_operations(item)
                sample[idx] = item
        # noinspection PyProtectedMember
        sample, cvts = TransformWrapper._pack(sample, cvts, convert_pil=True)
        assert len(sample) == len(cvts), "messed up packing/unpacking logic"
        if (skip_unpack or not out_list) and len(sample) == 1:
            sample, cvts = sample[0], cvts[0]
        return (sample, cvts) if out_cvts else sample

    def _apply_operations(self, image):
        """Applies the pipeline operations to a single PIL image, each one based on its own dice roll."""
        # python's RNG is re-seeded from numpy's state since the operations may draw from `random`
        random.seed(np.random.randint(np.iinfo(np.int32).max))
        for operation in self.pipeline.operations:
            # one-decimal rounded roll kept for parity with the sampling done by Augmentor's own pipeline;
            # the roll is always drawn first so the RNG sequence does not depend on the image content
            if round(np.random.uniform(0, 1), 1) <= operation.probability and image is not None:
                image = operation.perform_operation([image])[0]
        return image

    def __repr__(self):
        """Create a print-friendly representation of inner augmentation stages."""
        # for debug purposes only, pipeline probably cannot be expressed as a string
        return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
            f"(pipeline={repr(self.pipeline)}, target_keys={repr(self.target_keys)}, linked_fate={repr(self.linked_fate)})"

    # noinspection PyMethodMayBeStatic
    def set_seed(self, seed):
        """Sets the internal seed to use for stochastic ops (numpy's global RNG drives all dice rolls)."""
        np.random.seed(seed)

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to change the behavior of some suboperations."""
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch value"
        if self.pipeline.operations is not None:
            for op in self.pipeline.operations:
                if hasattr(op, "set_epoch") and callable(op.set_epoch):
                    op.set_epoch(epoch)
class TransformWrapper:
    """Transform wrapper that allows operations on samples, lists, tuples, and single elements.

    Can be used to wrap the operations in ``thelper.transforms`` or in ``torchvision.transforms``
    that only accept array-like objects as input. Will optionally force-convert content to PIL images.

    Can also be used to transform a list/tuple of images uniformly based on a shared dice roll, or
    to ensure that each image is transformed independently.

    .. warning::
        Stochastic transforms (e.g. ``torchvision.transforms.RandomCrop``) will always treat each
        image in a list differently unless they expose a ``set_seed`` method. If the same operations
        are to be applied to all images, you should consider using a series of non-stochastic operations
        wrapped inside an instance of ``torchvision.transforms.RandomApply``, or simply provide the
        probability of applying the transforms to this wrapper's constructor.

    Attributes:
        operation: the wrapped operation (callable object or class name string to import).
        params: the parameters that are passed to the operation when init'd or called.
        probability: the probability that the wrapped operation will be applied.
        convert_pil: specifies whether images should be converted into PIL format or not.
        target_keys: the sample keys to apply the transform to (when dictionaries are passed in).
        linked_fate: specifies whether images given in a list/tuple should have the same fate or not.
    """

    def __init__(self, operation, params=None, probability=1, convert_pil=False, target_keys=None, linked_fate=True):
        """Receives and stores a torchvision transform operation for later use.

        If the operation is given as a string, it is assumed to be a class name and it will
        be imported. The parameters (if any) will then be given to the constructor of that
        class. Otherwise, the operation is assumed to be a callable object, and its parameters
        (if any) will be provided at call-time.

        Args:
            operation: the wrapped operation (callable object or class name string to import).
            params: the parameters that are passed to the operation when init'd or called.
            probability: the probability that the wrapped operation will be applied.
            convert_pil: specifies whether images should be forced into PIL format or not.
            target_keys: the sample keys to apply the pipeline to (when dictionaries are passed in).
            linked_fate: specifies whether images given in a list/tuple should have the same fate or not.

        Raises:
            AssertionError: if ``params`` is not a dictionary or if ``probability`` is outside [0,1].
        """
        assert params is None or isinstance(params, dict), "expected params to be passed in as a dictionary"
        assert 0 <= probability <= 1, "invalid probability value (range is [0,1])"
        self.params = {} if params is None else params
        self.operation = operation
        if isinstance(self.operation, str):
            operation_type = thelper.utils.import_class(operation)
            self.opcall = operation_type(**self.params)
        else:
            self.opcall = functools.partial(operation, **self.params)
        self.probability = probability
        self.convert_pil = convert_pil
        self.target_keys = target_keys
        self.linked_fate = linked_fate

    def _get_op_hook(self, name):
        """Returns the callable attribute ``name`` of the wrapped operation, or ``None`` if it has none."""
        # callable operations are wrapped in a partial above, which would otherwise hide their attributes
        target = self.opcall.func if isinstance(self.opcall, functools.partial) else self.opcall
        hook = getattr(target, name, None)
        return hook if callable(hook) else None

    def _roll_dice(self):
        """Returns whether the wrapped operation should be applied, based on the wrapper's probability."""
        # a strict comparison on an unrounded draw gives exactly `probability` (0 never applies the op)
        return self.probability >= 1 or np.random.uniform(0, 1) < self.probability

    @staticmethod
    def _unpack(sample, force_flatten=False, convert_pil=False):
        """Unpacks a sample (or list of samples) before processing, optionally converting it to PIL format.

        Args:
            sample: the element or list/tuple of elements to unpack.
            force_flatten: whether nested lists should be flattened into a single list.
            convert_pil: whether a single element should be converted to a PIL image (or, for multi-channel
                non-byte arrays, to a list of single-channel PIL images).

        Returns:
            A ``(sample, cvts)`` pair, where ``cvts`` holds the conversion flag(s) needed by :meth:`_pack`.
            Lists with other than one element are returned as-is with one ``False`` flag per element.
        """
        if isinstance(sample, (list, tuple)):
            if len(sample) != 1:  # empty lists are returned as-is instead of raising an IndexError
                if not force_flatten:
                    return sample, [False] * len(sample)
                flat_samples, cvts = [], []
                for s in sample:
                    # NOTE(review): convert_pil is not forwarded here, so flattened elements stay unconverted
                    out, cvt = TransformWrapper._unpack(s, force_flatten=force_flatten)
                    if isinstance(cvt, (list, tuple)):
                        assert isinstance(out, (list, tuple)), "unexpected out/cvt types"
                        flat_samples += list(out)
                        cvts += list(cvt)
                    else:
                        flat_samples.append(out)
                        cvts.append(cvt)
                return flat_samples, cvts
            sample = sample[0]
        if convert_pil and sample is not None:  # empty placeholders cannot be converted
            if isinstance(sample, torch.Tensor):
                sample = sample.numpy()
            if isinstance(sample, np.ndarray) and sample.ndim > 2 and \
                    sample.shape[-1] > 1 and sample.dtype != np.uint8:
                # PIL images cannot handle multi-channel non-byte arrays; we handle these manually
                channels = [PIL.Image.fromarray(sample[..., c]) for c in range(sample.shape[-1])]
                return channels, True  # this is the only case where an array can be paired with a single cvt flag
            return PIL.Image.fromarray(np.squeeze(sample)), True
        return sample, False

    @staticmethod
    def _pack(samples, cvts, convert_pil=False):
        """Reverts the PIL conversions done by :meth:`_unpack`, turning converted images back into arrays.

        Returns:
            A ``(samples, cvts)`` pair in which every converted element is a numpy array again and its
            conversion flag is reset to ``False``.
        """
        if not isinstance(samples, (list, tuple)) or not isinstance(cvts, (list, tuple)) or len(samples) != len(cvts):
            # channel-split case: a single flag for a list of single-channel PIL images to restack
            assert convert_pil and isinstance(cvts, bool) and cvts, \
                "unexpected cvts len w/ pil conversion (bad logic somewhere)"
            assert all(isinstance(s, PIL.Image.Image) for s in samples), "unexpected packed list sample types"
            channels = [np.asarray(s) for s in samples]
            assert all(c.ndim == 2 for c in channels), "unexpected packed list sample depths"
            return [np.stack(channels, axis=2)], [False]
        for idx, cvt in enumerate(cvts):
            if not isinstance(cvt, (list, tuple)):
                assert isinstance(cvt, bool), "unexpected cvt type"
                if cvt:
                    assert not isinstance(samples[idx], (list, tuple)), "unexpected packed sample type"
                    samples[idx] = np.asarray(samples[idx])
                    cvts[idx] = False
        return samples, cvts

    def __call__(self, sample, force_linked_fate=False, op_seed=None, in_cvts=None):
        """Transforms a (dict) sample, a single image, or a list of images using a wrapped operation.

        Args:
            sample: the sample or image(s) to transform (can also contain embedded lists/tuples of images).
            force_linked_fate: override flag for recursive use allowing forced linking of arrays.
            op_seed: seed to set before calling the wrapped operation.
            in_cvts: holds the input conversion flag array (for recursive usage only; recursive calls
                inherit the dice roll of their caller).

        Returns:
            The transformed image(s), with the same list formatting as the input (tuples come back as lists,
            and the input list itself is left untouched). Dictionaries are returned with their targeted values
            replaced. When ``in_cvts`` is given, a ``(sample, cvts)`` pair is returned instead.
        """
        if isinstance(sample, dict):
            # recursive call for unpacking sample content w/ target keys
            assert in_cvts is None, "top-level call should never provide in_cvts"
            # capture non-scalar objects (according to numpy) if no keys are provided
            key_vals = [(k, v) for k, v in sample.items() if (
                (self.target_keys is None and not thelper.utils.is_scalar(v)) or
                (self.target_keys is not None and k in self.target_keys))]
            if not key_vals:
                return sample  # nothing to transform (and zip-unpacking an empty list below would fail)
            keys, vals = map(list, zip(*key_vals))
            lengths = [len(v) if isinstance(v, (list, tuple)) else -1 for v in vals]
            if lengths[0] > 0 and all(n == lengths[0] for n in lengths):
                # interlace input lists for internal linked fate (if needed; otherwise, it won't change anything)
                vals = [[v[idx] if isinstance(v, (list, tuple)) else v[idx, ...] for v in vals]
                        for idx in range(lengths[0])]
                vals = self(vals, force_linked_fate=force_linked_fate, op_seed=op_seed, in_cvts=in_cvts)
                assert isinstance(vals, list) and len(vals) == lengths[0], "messed up something internally"
                # de-interlace the outputs back into one list per targeted key
                out_vals = [[v] for v in vals[0]] if isinstance(vals[0], list) else [[vals[0]]]
                for idx1 in range(1, lengths[0]):
                    for idx2, out_val in enumerate(out_vals):
                        out_val.append(vals[idx1][idx2] if isinstance(vals[idx1], list) else vals[idx1])
                vals = out_vals
            else:
                vals = self(vals, force_linked_fate=force_linked_fate, op_seed=op_seed, in_cvts=in_cvts)
            new_vals = dict(zip(keys, vals))
            return {k: new_vals[k] if k in new_vals else v for k, v in sample.items()}
        out_cvts = in_cvts is not None
        out_list = isinstance(sample, (list, tuple))
        if sample is None or (out_list and not sample):
            return ([], []) if out_cvts else []
        # work on a copy: tuples cannot be updated in place, and the caller's list must not be modified
        sample = list(sample) if out_list else [sample]
        assert not any(isinstance(v, dict) for v in sample), \
            "sample transform wrapper cannot handle sample-in-sample (or dict-in-list) inputs"
        skip_unpack = in_cvts is not None and isinstance(in_cvts, bool) and in_cvts
        if self.linked_fate or force_linked_fate:  # process all content with the same operations below
            if not skip_unpack:
                sample, cvts = self._unpack(sample, convert_pil=self.convert_pil)
                if not isinstance(sample, (list, tuple)):
                    sample, cvts = [sample], [cvts]
            else:
                cvts = in_cvts
            # recursive calls (in_cvts given) were already granted by their caller's dice roll; rolling
            # again would break the shared fate of nested content and lower the effective probability
            if out_cvts or self._roll_dice():
                if op_seed is None:
                    op_seed = np.random.randint(np.iinfo(np.int32).max)
                seed_hook = self._get_op_hook("set_seed")
                for idx, item in enumerate(sample):
                    # nested lists, and arrays still needing a PIL conversion (which _unpack skips for lists
                    # with several elements), are handled recursively with the same seed
                    if isinstance(item, (list, tuple)) or \
                            (self.convert_pil and isinstance(item, (np.ndarray, torch.Tensor))):
                        sample[idx], cvts[idx] = self(item, force_linked_fate=True,
                                                      op_seed=op_seed, in_cvts=cvts[idx])
                    elif item is not None:
                        # watch out: if operation is stochastic and we cannot seed here, then there is no
                        # guarantee that the content will truly have a 'linked fate' (this might cause issues!)
                        if seed_hook is not None:
                            seed_hook(op_seed)
                        sample[idx] = self.opcall(item)
        else:  # each element of the top array will be processed independently below (current seeds are kept)
            cvts = [False] * len(sample)
            for idx, item in enumerate(sample):
                item, cvts[idx] = self._unpack(item, convert_pil=self.convert_pil)
                if self._roll_dice():
                    if isinstance(item, (list, tuple)):
                        # we will now force fate linkage for all sub-elements of this array
                        item, cvts[idx] = self(item, force_linked_fate=True, op_seed=op_seed, in_cvts=cvts[idx])
                    elif item is not None:
                        item = self.opcall(item)
                sample[idx] = item
        sample, cvts = self._pack(sample, cvts, convert_pil=self.convert_pil)
        assert len(sample) == len(cvts), "messed up packing/unpacking logic"
        if (skip_unpack or not out_list) and len(sample) == 1:
            sample, cvts = sample[0], cvts[0]
        return (sample, cvts) if out_cvts else sample

    def __repr__(self):
        """Create a print-friendly representation of inner augmentation stages."""
        return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
            f"(operation={repr(self.operation)}, params={repr(self.params)}, probability={repr(self.probability)}, " + \
            f"convert_pil={repr(self.convert_pil)}, target_keys={repr(self.target_keys)}, linked_fate={repr(self.linked_fate)})"

    # noinspection PyMethodMayBeStatic
    def set_seed(self, seed):
        """Sets the internal seed to use for stochastic ops (numpy's global RNG drives the dice rolls)."""
        np.random.seed(seed)

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to change the behavior of some suboperations."""
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch value"
        epoch_hook = self._get_op_hook("set_epoch")
        if epoch_hook is not None:
            epoch_hook(epoch)