Source code for thelper.tasks.detect

"""Detection task interface module.

This module contains classes that define object detection utilities and task interfaces.
"""
import logging
from typing import List, Optional, Tuple, Union  # noqa: F401

import numpy as np
import torch
import tqdm

import thelper.concepts
import thelper.utils
from thelper.ifaces import ClassNamesHandler
from thelper.tasks.regr import Regression
from thelper.tasks.utils import Task

logger = logging.getLogger(__name__)


@thelper.concepts.detection
class BoundingBox:
    """Interface used to hold instance metadata for object detection tasks.

    Object detection trainers and display utilities in the framework will expect this interface to be
    used when parsing a predicted detection or an annotation. The base contents are based on the
    PASCALVOC metadata structure, and this class can be derived if necessary to contain more metadata.

    Attributes:
        class_id: type identifier for the underlying object instance.
        bbox: four-element tuple holding the (xmin,ymin,xmax,ymax) bounding box parameters.
        include_margin: defines whether xmax/ymax is included in the bounding box area or not.
        difficult: defines whether this instance is considered "difficult" (false by default).
        occluded: defines whether this instance is considered "occluded" (false by default).
        truncated: defines whether this instance is considered "truncated" (false by default).
        iscrowd: defines whether this instance covers a "crowd" of objects or not (false by default).
        confidence: scalar or array of prediction confidence values tied to class types (empty by default).
        image_id: string used to identify the image containing this bounding box (i.e. file path or uuid).
        task: reference to the task object that holds extra metadata regarding the content of the bbox (None by default).

    .. seealso::
        | :class:`thelper.tasks.utils.Task`
        | :class:`thelper.tasks.detect.Detection`
    """

    def __init__(self, class_id, bbox, include_margin=True, difficult=False, occluded=False,
                 truncated=False, iscrowd=False, confidence=None, image_id=None, task=None):
        """Receives and stores low-level input detection metadata for later access.

        Every argument is routed through the matching property setter, so the type/range assertions
        of those setters apply here as well.
        """
        self.class_id = class_id  # should be string or int to allow batching in data loaders
        # note: the input bbox is expected to be a 4 element array (xmin,ymin,xmax,ymax)
        # 'include_margin' must be assigned before 'bbox' since the bbox setter reads it
        self.include_margin = include_margin
        self.bbox = bbox
        self.difficult = difficult
        self.occluded = occluded
        self.truncated = truncated
        self.iscrowd = iscrowd
        self.confidence = confidence
        self.image_id = image_id  # should be string that identifies the associated image (file path or uuid)
        # the task setter checks 'class_id' against the task's class indices, so it is assigned last
        self.task = task

    @property
    def class_id(self):
        # type: () -> Union[int, str]
        """Returns the object class type identifier."""
        return self._class_id

    @class_id.setter
    def class_id(self, value):
        # type: (Union[int, str]) -> None
        """Sets the object class type identifier (should be string/int)."""
        assert isinstance(value, (int, str)), "class should be defined as integer (index) or string (name)"
        self._class_id = value

    @property
    def bbox(self):
        """Returns the bounding box tuple :math:`(x_min,y_min,x_max,y_max)`."""
        return self._bbox

    @bbox.setter
    def bbox(self, value):
        """Sets the bounding box tuple :math:`(x_min,y_min,x_max,y_max)`.

        The value is stored as-is (no copy, no conversion); it must be a 4-element list, tuple,
        array or tensor with ``min <= max`` on both axes.
        """
        assert isinstance(value, (list, tuple, np.ndarray, torch.Tensor)) and len(value) == 4, "invalid input type/len"
        assert not isinstance(value, (list, tuple)) or all([isinstance(v, (int, float)) for v in value]), \
            "input bbox values must be integer/float"
        assert value[0] <= value[2] and value[1] <= value[3], "invalid min/max values for bbox coordinates"
        assert not self.include_margin or not any([isinstance(v, float) for v in value]), \
            "it makes no sense to include xmax/ymax margin if using floating point coordinates"
        self._bbox = value

    # note: the edge/corner setters below assign into the stored bbox container in-place, which
    # only works if that container supports item assignment (i.e. it is not a tuple)

    @property
    def left(self):
        """Returns the left bounding box edge origin offset value."""
        return self._bbox[0]

    @left.setter
    def left(self, value):
        """Sets the left bounding box edge origin offset value."""
        self._bbox[0] = value

    @property
    def top(self):
        """Returns the top bounding box edge origin offset value."""
        return self._bbox[1]

    @top.setter
    def top(self, value):
        """Sets the top bounding box edge origin offset value."""
        self._bbox[1] = value

    @property
    def top_left(self):
        """Returns the top left bounding box corner coordinates :math:`(x,y)`."""
        return self._bbox[0], self._bbox[1]

    @top_left.setter
    def top_left(self, value):
        """Sets the top left bounding box corner coordinates :math:`(x,y)`."""
        assert isinstance(value, (list, tuple, np.ndarray, torch.Tensor)) and len(value) == 2, "invalid input type/len"
        self._bbox[0], self._bbox[1] = value[0], value[1]

    @property
    def right(self):
        """Returns the right bounding box edge origin offset value."""
        return self._bbox[2]

    @right.setter
    def right(self, value):
        """Sets the right bounding box edge origin offset value."""
        self._bbox[2] = value

    @property
    def bottom(self):
        """Returns the bottom bounding box edge origin offset value."""
        return self._bbox[3]

    @bottom.setter
    def bottom(self, value):
        """Sets the bottom bounding box edge origin offset value."""
        self._bbox[3] = value

    @property
    def bottom_right(self):
        """Returns the bottom right bounding box corner coordinates :math:`(x,y)`."""
        return self._bbox[2], self._bbox[3]

    @bottom_right.setter
    def bottom_right(self, value):
        """Sets the bottom right bounding box corner coordinates :math:`(x,y)`.

        The new corner must not lie above or to the left of the current top-left corner.
        """
        assert isinstance(value, (list, tuple, np.ndarray, torch.Tensor)) and len(value) == 2, "invalid input type/len"
        assert value[0] >= self._bbox[0] and value[1] >= self._bbox[1]
        self._bbox[2], self._bbox[3] = value[0], value[1]

    @property
    def width(self):
        """Returns the width of the bounding box (one unit larger when xmax is included in the box)."""
        # the parentheses matter: without them the conditional expression swallows the whole
        # subtraction and the width collapses to 0 whenever the margin is excluded
        return (self._bbox[2] - self._bbox[0]) + (1 if self.include_margin else 0)

    @property
    def height(self):
        """Returns the height of the bounding box (one unit larger when ymax is included in the box)."""
        # see the note in 'width' regarding the parentheses around the conditional expression
        return (self._bbox[3] - self._bbox[1]) + (1 if self.include_margin else 0)

    @property
    def centroid(self, floor=False):
        """Returns the bounding box centroid coordinates :math:`(x,y)`.

        .. note::
            Since this is a property, ``floor`` cannot be provided through regular attribute access
            and always keeps its default value; the coordinates are thus obtained via true division.
            The argument is only kept for backward compatibility.
        """
        if self.include_margin:
            if floor:
                return (self._bbox[0] + self._bbox[2] + 1) // 2, (self._bbox[1] + self._bbox[3] + 1) // 2
            return (self._bbox[0] + self._bbox[2] + 1) / 2, (self._bbox[1] + self._bbox[3] + 1) / 2
        else:
            if floor:
                return (self._bbox[0] + self._bbox[2]) // 2, (self._bbox[1] + self._bbox[3]) // 2
            return (self._bbox[0] + self._bbox[2]) / 2, (self._bbox[1] + self._bbox[3]) / 2

    @property
    def include_margin(self):
        """Returns whether :math:`x_max` and :math:`y_max` are included in the bounding box area or not."""
        return self._include_margin

    @include_margin.setter
    def include_margin(self, value):
        """Sets whether :math:`x_max` and :math:`y_max` are included in the bounding box area or not."""
        assert isinstance(value, (int, bool)), "flag type must be integer or boolean"
        self._include_margin = value

    @property
    def difficult(self):
        """Returns whether this bounding box is considered *difficult* by the dataset (false by default)."""
        return self._difficult

    @difficult.setter
    def difficult(self, value):
        """Sets whether this bounding box is considered *difficult* by the dataset."""
        assert isinstance(value, (int, bool)), "flag type must be integer or boolean"
        self._difficult = value

    @property
    def occluded(self):
        """Returns whether this bounding box is considered *occluded* by the dataset (false by default)."""
        return self._occluded

    @occluded.setter
    def occluded(self, value):
        """Sets whether this bounding box is considered *occluded* by the dataset."""
        assert isinstance(value, (int, bool)), "flag type must be integer or boolean"
        self._occluded = value

    @property
    def truncated(self):
        """Returns whether this bounding box is considered *truncated* by the dataset (false by default)."""
        return self._truncated

    @truncated.setter
    def truncated(self, value):
        """Sets whether this bounding box is considered *truncated* by the dataset."""
        assert isinstance(value, (int, bool)), "flag type must be integer or boolean"
        self._truncated = value

    @property
    def iscrowd(self):
        """Returns whether this instance covers a *crowd* of objects or not (false by default)."""
        return self._iscrowd

    @iscrowd.setter
    def iscrowd(self, value):
        """Sets whether this instance covers a *crowd* of objects or not."""
        assert isinstance(value, (int, bool)), "flag type must be integer or boolean"
        self._iscrowd = value

    @property
    def area(self):
        """Returns a scalar indicating the total surface of the bounding box (i.e. width times height)."""
        return self.width * self.height

    @property
    def confidence(self):
        """Returns the confidence value (or array of confidence values) associated to the predicted class types."""
        return self._confidence

    @confidence.setter
    def confidence(self, value):
        """Sets the confidence value (or array of confidence values) associated to the predicted class types."""
        assert value is None or isinstance(value, (float, list, np.ndarray, torch.Tensor)), \
            "value should be float/list/ndarray/tensor"
        self._confidence = value

    @property
    def image_id(self):
        """Returns the image string identifier."""
        return self._image_id

    @image_id.setter
    def image_id(self, value):
        """Sets the image string identifier."""
        assert value is None or isinstance(value, (str, int)), \
            "image identifier should be a string/int (file path or uuid)"
        self._image_id = value

    @property
    def task(self):
        """Returns the reference to the task object that holds extra metadata regarding the content of the bbox."""
        return self._task

    @task.setter
    def task(self, value):
        """Sets the reference to the task object that holds extra metadata regarding the content of the bbox.

        When a task is provided, it must be a detection task whose class indices contain this
        bounding box's ``class_id``.
        """
        if value is not None:
            assert isinstance(value, Detection), "task should be detection-related"
            assert self.class_id in value.class_indices.values(), \
                f"cannot find class_id '{self.class_id}' in task indices"
        self._task = value

    def encode(self, format=None):
        """Returns a vectorizable representation of this bounding box in a specified format.

        Supported formats:
            * ``"coco"``: ``[x_min, y_min, width, height, class_id]``
            * ``"pascal_voc"``: ``[x_min, y_min, x_max, y_max, class_id, difficult, occluded, truncated]``
            * ``None``: ``[x_min, y_min, x_max, y_max, class_id, include_margin, difficult, occluded,
              truncated, iscrowd, image_id]`` followed by the confidence value(s), if any.

        WARNING: Encoding might cause information loss (e.g. task reference is discarded).
        """
        if format == "coco":
            return [*self.top_left, self.width, self.height, self.class_id]
        elif format == "pascal_voc":
            return [*self.top_left, *self.bottom_right, self.class_id,
                    self.difficult, self.occluded, self.truncated]
        else:
            assert format is None, "unrecognized/unknown encoding format"
            vec = [*self.bbox, self.class_id, self.include_margin, self.difficult,
                   self.occluded, self.truncated, self.iscrowd, self.image_id]
            if self.confidence is not None:
                # a scalar confidence is appended as-is, an array of confidences is flattened in
                vec += [self.confidence] if isinstance(self.confidence, float) else [*self.confidence]
            return vec

    @staticmethod
    def decode(vec, format=None):
        """Returns a BoundingBox object from a vectorized representation in a specified format.

        The expected layouts are those produced by :meth:`encode`. In the default format, the first
        11 values are mandatory and any extra trailing values are used as the confidence array.

        .. note::
            The input bbox is expected to be a 4 element array :math:`(x_min,y_min,x_max,y_max)`.
        """
        if format == "coco":
            assert len(vec) == 5, "unexpected vector length (should contain 5 values)"
            return BoundingBox(class_id=vec[4], bbox=[vec[0], vec[1], vec[0] + vec[2], vec[1] + vec[3]])
        elif format == "pascal_voc":
            assert len(vec) == 8, "unexpected vector length (should contain 8 values)"
            return BoundingBox(class_id=vec[4], bbox=vec[0:4], difficult=vec[5], occluded=vec[6], truncated=vec[7])
        else:
            assert format is None, "unrecognized/unknown encoding format"
            # an encoded bbox without confidence holds exactly 11 values, and must remain decodable
            assert len(vec) >= 11, "unexpected vector length (should contain 11 values or more)"
            confidence = None
            if len(vec) > 11:
                confidence = vec[11:]
                if isinstance(confidence, tuple):
                    # the confidence setter does not accept tuples
                    confidence = list(confidence)
            return BoundingBox(class_id=vec[4], bbox=vec[0:4], include_margin=vec[5], difficult=vec[6],
                               occluded=vec[7], truncated=vec[8], iscrowd=vec[9], confidence=confidence,
                               image_id=vec[10], task=None)

    def intersects(self, geom):
        """Returns whether the bounding box intersects a geometry (i.e. a 2D point or another bbox).

        Edges are treated as part of the box: a point lying on an edge, or two boxes sharing only
        an edge, are reported as intersecting.
        """
        assert isinstance(geom, (tuple, list, np.ndarray, BoundingBox)), "unexpected input geometry type"
        if isinstance(geom, (tuple, list, np.ndarray)):
            # check intersection with point
            assert len(geom) == 2, "point should be given as list of coordinates (x,y)"
            return self._bbox[0] <= geom[0] <= self._bbox[2] and self._bbox[1] <= geom[1] <= self._bbox[3]
        else:
            # two boxes intersect unless one lies entirely to the side of/above/below the other
            return not (self._bbox[0] > geom._bbox[2] or geom._bbox[0] > self._bbox[2] or
                        self._bbox[3] < geom._bbox[1] or geom._bbox[3] < self._bbox[1])

    def totuple(self):
        # type: () -> Tuple[int, int, int, int]
        """Gets a ``tuple`` representation of the underlying bounding box tuple :math:`(x_min,y_min,x_max,y_max)`.

        This ensures that ``Tensor`` and ``ndarray`` objects are converted to native *Python* types.
        """
        return tuple(self.tolist())

    def tolist(self):
        # type: () -> List[int]
        """Gets a ``list`` representation of the underlying bounding box tuple :math:`(x_min,y_min,x_max,y_max)`.

        This ensures that ``Tensor`` and ``ndarray`` objects are converted to native *Python* types.
        """
        if isinstance(self._bbox, (np.ndarray, torch.Tensor)):
            # 'tolist' converts the elements to python scalars, unlike 'list(...)'
            return self._bbox.tolist()
        return list(self._bbox)

    def json(self):
        # type: () -> thelper.typedefs.JSON
        """Gets a JSON-serializable representation of the bounding box parameters.

        Array/tensor confidence values are converted to (lists of) native *Python* scalars; the task
        reference is not exported.
        """
        confidence = self.confidence
        if isinstance(confidence, (np.ndarray, torch.Tensor)):
            confidence = confidence.tolist()
        return {
            "class_id": self.class_id,
            "image_id": self.image_id,
            "bbox": self.tolist(),
            "confidence": confidence,
            "include_margin": self.include_margin,
            "difficult": self.difficult,
            "occluded": self.occluded,
            "truncated": self.truncated,
            "is_crowd": self.iscrowd,
        }

    def __repr__(self):
        """Creates a print-friendly representation of the object detection bbox instance."""
        # note: we do not export the task reference here (it might be too heavy for logs)
        return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
            f"(class_id={repr(self.class_id)}, bbox={repr(self.bbox)}, include_margin={repr(self.include_margin)}, " + \
            f"difficult={repr(self.difficult)}, occluded={repr(self.occluded)}, truncated={repr(self.truncated)}, " + \
            f"iscrowd={repr(self.iscrowd)}, confidence={repr(self.confidence)}, image_id={repr(self.image_id)})"
@thelper.concepts.detection
class Detection(Regression, ClassNamesHandler):
    """Interface for object detection tasks.

    This specialization requests that when given an input image, the trained model should
    provide a list of bounding box (bbox) proposals that correspond to probable objects detected
    in the image.

    This specialized regression interface is currently used to help display functions.

    Attributes:
        class_names: map of class name-value pairs for object types to detect.
        input_key: the key used to fetch input tensors from a sample dictionary.
        bboxes_key: the key used to fetch target (groundtruth) bboxes from a sample dictionary.
        meta_keys: the list of extra keys provided by the data parser inside each sample.
        input_shape: a numpy-compatible shape to expect input images to possess.
        target_shape: a numpy-compatible shape to expect the predictions to be in.
        target_min: a 2-dim tensor containing minimum (x,y) bounding box corner values.
        target_max: a 2-dim tensor containing maximum (x,y) bounding box corner values.
        background: value of the 'background' label (if any) used in the class map.
        color_map: map of class name-color pairs to use when displaying results.

    .. seealso::
        | :class:`thelper.tasks.utils.Task`
        | :class:`thelper.tasks.regr.Regression`
        | :class:`thelper.train.detect.ObjDetectTrainer`
    """

    def __init__(self, class_names, input_key, bboxes_key, meta_keys=None, input_shape=None,
                 target_shape=None, target_min=None, target_max=None, background=None, color_map=None):
        """Receives and stores the bbox types to detect, the input tensor key, the groundtruth
        bboxes list key, the extra (meta) keys produced by the dataset parser(s), and the color
        map used to color bboxes when displaying results.

        The class names can be provided as a list of strings, as a path to a json file that
        contains such a list, or as a map of predefined name-value pairs to use in gt maps.
        This list/map must contain at least two elements (background and one class). All
        other arguments are used as-is to index dictionaries, and must therefore be key-
        compatible types.
        """
        super(Detection, self).__init__(input_key, bboxes_key, meta_keys,
                                        input_shape=input_shape, target_shape=target_shape,
                                        target_min=target_min, target_max=target_max)
        ClassNamesHandler.__init__(self, class_names=class_names)
        if background is not None:
            # NOTE(review): a provided background value is replaced by the index of the "background"
            # class (or dropped if no such class exists) -- confirm that this is intended
            background = None if "background" not in self.class_indices else self.class_indices["background"]
        self.background = background
        self.color_map = color_map

    @property
    def background(self):
        """Returns the 'background' label value used in loss functions (can be ``None``)."""
        return self._background

    @background.setter
    def background(self, background):
        """Sets the 'background' label value for this detection task (can be ``None``).

        If the value is already used as a class index, it must be the index of the class that is
        named "background".
        """
        if background is not None:
            assert isinstance(background, int), "'background' value should be integer (index)"
            # check for the key first so that a missing "background" class triggers the assertion
            # (with its message) instead of a KeyError
            assert background not in self.class_indices.values() or \
                ("background" in self.class_indices and self.class_indices["background"] == background), \
                "found 'background' value tied to another class label"
        self._background = background

    @property
    def color_map(self):
        """Returns the color map used to swap label indices for colors when displaying results."""
        return self._color_map

    @color_map.setter
    def color_map(self, color_map):
        """Sets the color map used to swap label indices for colors when displaying results.

        The map can be keyed by class names only or by class indices only; it is stored keyed by
        class indices, with color triplets converted to numpy arrays. If a background value is set
        and has no color, black is assigned to it. Passing ``None`` resets the map to an empty one.
        """
        if color_map is not None:
            assert isinstance(color_map, dict), "color map should be given as dictionary"
            self._color_map = {}
            assert all([isinstance(k, int) for k in color_map]) or all([isinstance(k, str) for k in color_map]), \
                "color map keys should be only class names or only class indices"
            for key, val in color_map.items():
                if isinstance(key, str):
                    # translate class names into class indices
                    if key == "background" and self.background is not None:
                        key = self.background
                    else:
                        assert key in self.class_indices, f"could not find color map key '{key}' in class names"
                        key = self.class_indices[key]
                assert key in self.class_indices.values() or key == self.background, \
                    f"unrecognized class index '{key}'"
                if isinstance(val, (list, tuple)):
                    val = np.asarray(val)
                assert isinstance(val, np.ndarray) and val.size == 3, "color values should be given as triplets"
                self._color_map[key] = val
            if self.background is not None and self.background not in self._color_map:
                self._color_map[self.background] = np.asarray([0, 0, 0])  # use black as default 'background' color
        else:
            self._color_map = {}

    def get_class_sizes(self, samples, bbox_format=None):
        """Given a list of samples, returns a map of element counts for each object type.

        Samples that do not contain the groundtruth key are skipped. Bounding boxes given in
        vectorized form are decoded using ``bbox_format`` (see :meth:`BoundingBox.decode`).
        """
        assert samples is not None and samples, "provided invalid sample list"
        elem_counts = {class_name: 0 for class_name in self.class_names}
        for sample in tqdm.tqdm(samples, desc="cumulating bbox counts", total=len(samples)):
            if self.gt_key is None or self.gt_key not in sample:
                continue
            bboxes = sample[self.gt_key]
            if isinstance(bboxes, torch.Tensor):
                bboxes = bboxes.cpu().numpy()
            if isinstance(bboxes, (np.ndarray, list, tuple)):
                bboxes = [BoundingBox.decode(bbox, format=bbox_format)
                          if isinstance(bbox, (np.ndarray, list, tuple)) else bbox for bbox in bboxes]
            assert all([isinstance(bbox, BoundingBox) for bbox in bboxes]), "unrecognized sample bbox format"
            assert all([bbox.class_id in self.class_indices.values() for bbox in bboxes]), \
                "bboxes contain unknown class ids"
            # tally the class ids in a single pass instead of rescanning the bboxes for each class
            id_counts = {}
            for bbox in bboxes:
                id_counts[bbox.class_id] = id_counts.get(bbox.class_id, 0) + 1
            for cname, cval in self.class_indices.items():
                elem_counts[cname] += id_counts.get(cval, 0)
        return elem_counts

    def check_compat(self, task, exact=False):
        # type: (Detection, Optional[bool]) -> bool
        """Returns whether the current task is compatible with the provided one or not.

        This is useful for sanity-checking, and to see if the inputs/outputs of two models
        are compatible. If ``exact = True``, all fields will be checked for exact (perfect)
        compatibility (in this case, matching meta keys and class maps).
        """
        if isinstance(task, Detection):
            if not Regression.check_compat(self, task, exact=exact):
                return False
            # the other task's classes must all be known here; an exact match also requires
            # identical class names and color maps
            return self.background == task.background and \
                all([cls in self.class_names for cls in task.class_names]) and \
                (not exact or (self.class_names == task.class_names and
                               self.color_map.keys() == task.color_map.keys() and
                               all([np.array_equal(self.color_map[k], task.color_map[k]) for k in self.color_map])))
        elif type(task) == Task:
            # if 'task' simply has no gt, compatibility rests on input key only
            # (only the exact base task type is accepted here, subclasses do not match)
            return not exact and self.input_key == task.input_key and task.gt_key is None
        return False

    def get_compat(self, task):
        """Returns a task instance compatible with the current task and the given one.

        With another detection task, the result merges the meta keys, class indices and color maps
        of both tasks (entries of the current task take precedence). With a base task, the result
        replicates the current task with the merged meta keys.
        """
        assert isinstance(task, Detection) or type(task) == Task, \
            f"cannot create compatible task from types '{type(task)}' and '{type(self)}'"
        if isinstance(task, Detection):
            assert self.input_key == task.input_key, "input key mismatch, cannot create compatible task"
            assert self.gt_key is None or task.gt_key is None or self.gt_key == task.gt_key, \
                "gt key mismatch, cannot create compatible task"
            assert self.background == task.background, "background value mismatch, cannot create compatible task"
            meta_keys = list(set(self.meta_keys + task.meta_keys))
            # cannot use set for class names, order needs to stay intact!
            class_indices = {cname: cval for cname, cval in task.class_indices.items()
                             if cname not in self.class_indices}
            class_indices = {**self.class_indices, **class_indices}
            color_map = {cname: cval for cname, cval in task.color_map.items() if cname not in self.color_map}
            color_map = {**self.color_map, **color_map}
            return Detection(class_names=class_indices, input_key=self.input_key, bboxes_key=self.gt_key,
                             meta_keys=meta_keys,
                             input_shape=self.input_shape if self.input_shape is not None else task.input_shape,
                             target_shape=self.target_shape if self.target_shape is not None else task.target_shape,
                             target_min=self.target_min if self.target_min is not None else task.target_min,
                             target_max=self.target_max if self.target_max is not None else task.target_max,
                             background=self.background, color_map=color_map)
        elif type(task) == Task:
            assert self.check_compat(task), f"cannot create compatible task between:\n\t{str(self)}\n\t{str(task)}"
            meta_keys = list(set(self.meta_keys + task.meta_keys))
            return Detection(class_names=self.class_indices, input_key=self.input_key, bboxes_key=self.gt_key,
                             meta_keys=meta_keys, input_shape=self.input_shape, target_shape=self.target_shape,
                             target_min=self.target_min, target_max=self.target_max,
                             background=self.background, color_map=self.color_map)

    def __repr__(self):
        """Creates a print-friendly representation of a detection task."""
        # color values are stored as numpy arrays; convert them to lists for a compact output
        color_map = {k: v.tolist() for k, v in self.color_map.items()}
        return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
            f"(class_names={repr(self.class_indices)}, input_key={repr(self.input_key)}, " + \
            f"bboxes_key={repr(self.gt_key)}, meta_keys={repr(self.meta_keys)}, " + \
            f"input_shape={repr(self.input_shape)}, target_shape={repr(self.target_shape)}, " + \
            f"target_min={repr(self.target_min)}, target_max={repr(self.target_max)}, " + \
            f"background={repr(self.background)}, color_map={repr(color_map)})"