Source code for thelper.tasks.utils

"""Task utility functions & base interface module.

This module contains utility functions used to instantiate tasks and check their compatibility,
and the base interface used to define new tasks.
"""

import logging
import re
import typing

import thelper.typedefs
import thelper.utils

logger = logging.getLogger(__name__)


def create_task(config: typing.Union[thelper.typedefs.ConfigDict, typing.AnyStr]) -> "Task":
    """Parses a configuration dictionary or repr string and instantiates a task from it.

    If a string is provided, it will first be parsed to get the task type, and then the object
    will be instantiated by forwarding the parameters contained in the string to the constructor
    of that type. Note that it is important for this function to work that the constructor
    argument names match the names of parameters printed in the task's ``__repr__`` function.

    If a dict is provided, it should contain a 'type' field (the task class name, as a string)
    and a 'params' field (also looked up under the name 'parameters') holding a dictionary of
    constructor keyword arguments.

    If a :class:`Task` instance was specified, it is directly returned.

    Args:
        config: a task instance, a configuration dictionary, or a task repr string.

    Returns:
        The task instance described by ``config``.

    Raises:
        AssertionError: if ``config`` is neither a task, a string, nor a dictionary, if the
            dictionary fields are missing or have an unexpected type, or if the instantiated
            object is not derived from ``thelper.tasks.Task``.

    .. warning::
        String configurations are passed to :func:`eval`; they must only come from trusted
        sources (e.g. checkpoints or configuration files produced by the framework itself).

    .. seealso::
        | :class:`thelper.tasks.utils.Task`
    """
    if isinstance(config, Task):
        return config
    # explicit raise instead of ``assert`` so that the check survives ``python -O``; without
    # it, an unsupported config would fall through both branches and silently return ``None``
    if not isinstance(config, (str, dict)):
        raise AssertionError("unexpected config type (should be str or dict)")
    if isinstance(config, dict):
        if "type" not in config or not isinstance(config["type"], str):
            raise AssertionError("invalid field 'type' in task config")
        task_type = thelper.utils.import_class(config["type"])
        task_params = thelper.utils.get_key(["params", "parameters"], config)
        if not isinstance(task_params, dict):
            raise AssertionError("invalid field 'params' in task config")
        task = task_type(**task_params)
    elif re.search(r"^[\w\.]+: ", config) is not None:  # for backwards compat (pre v0.3.0)
        # legacy format is "<type name>: <params dict repr>"; split on the first ": " only
        task_type_name, _, task_params_repr = config.partition(": ")
        if "." not in task_type_name:  # dirty hotfix
            task_type_name = "thelper.tasks." + task_type_name
        task_type = thelper.utils.import_class(task_type_name)
        # NOTE(security): ``eval`` executes arbitrary code; the string must be trusted
        task_params = eval(task_params_repr)
        task = task_type(**task_params)
    else:
        # NOTE(security): ``eval`` executes arbitrary code; the string must be trusted
        task = eval(config)
    if not isinstance(task, thelper.tasks.Task):
        raise AssertionError("the task must be derived from 'thelper.tasks.Task'")
    return task
def create_global_task(tasks: typing.Optional[typing.Iterable["Task"]]) -> typing.Optional["Task"]:
    """Returns a new task object that is compatible with a list of subtasks.

    When different datasets must be combined in a session, the tasks they define must also be
    merged. This function allows us to do so as long as the tasks all share a common objective.
    If creating a globally-compatible task is impossible, this function will raise an exception.
    Otherwise, the returned task object can be used to replace the subtasks of all used datasets.

    Args:
        tasks: list of tasks to merge; ``None`` entries are skipped. ``None`` may also be given
            instead of a list, in which case ``None`` is returned.

    Returns:
        The merged task, or ``None`` if no task was defined.

    .. seealso::
        | :class:`thelper.tasks.utils.Task`
        | :func:`thelper.tasks.utils.create_task`
        | :func:`thelper.data.utils.create_parsers`
    """
    if tasks is None:
        return None
    if not isinstance(tasks, list):
        raise AssertionError("tasks should be provided as list")
    merged = None
    for candidate in tasks:
        if candidate is None:
            # undefined tasks do not constrain the result
            continue
        if not isinstance(candidate, thelper.tasks.Task):
            raise AssertionError("all tasks should derive from thelper.tasks.Task")
        if merged is None:
            # first defined task becomes the initial reference
            merged = candidate
        elif type(merged) is Task:
            # reference is still the base interface; let the new task drive the merge
            merged = candidate.get_compat(merged)
        else:
            # reference is already specialized; ask it for a compatible instance
            merged = merged.get_compat(candidate)
    return merged
class Task:
    """Basic task interface that defines a training objective and that holds sample i/o keys.

    Since the framework's data loaders expect samples to be passed in as dictionaries, keys are
    required to obtain the input that should be forwarded to a model, and to obtain the groundtruth
    required for the evaluation of model predictions. Other keys might also be kept by this
    interface for reference (these are considered meta keys).

    Note that while this interface can be instantiated directly, trainers and models might not be
    provided enough information about their goal to be correctly instantiated. Thus, specialized
    task objects derived from this base class should be used if possible.

    Attributes:
        input_key: the key used to fetch input tensors from a sample dictionary.
        gt_key: the key used to fetch gt tensors from a sample dictionary.
        meta_keys: the list of extra keys provided by the data parser inside each sample.

    .. seealso::
        | :class:`thelper.tasks.classif.Classification`
        | :class:`thelper.tasks.segm.Segmentation`
        | :class:`thelper.tasks.regr.Regression`
        | :class:`thelper.tasks.detect.Detection`
    """

    def __init__(
        self,
        input_key: typing.Hashable,
        gt_key: typing.Optional[typing.Hashable] = None,
        meta_keys: typing.Optional[typing.Iterable[typing.Hashable]] = None,
    ):
        """Receives and stores the keys used to index dataset sample contents."""
        # all three assignments go through the validating property setters below
        self.input_key = input_key
        self.gt_key = gt_key
        self.meta_keys = meta_keys

    @property
    def input_key(self) -> typing.Hashable:
        """Returns the key used to fetch input data tensors from a sample dictionary."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: typing.Hashable) -> None:
        """Sets the input key used to fetch input data tensors from a sample dictionary.

        The key can be of any type, as long as it can be used to index a dictionary. Print-
        friendly types (e.g. string) are recommended for debugging. This key can never be
        ``None``, as input tensors should always be available in loaded samples.

        Raises:
            AssertionError: if the key is ``None`` or is not hashable.
        """
        if value is None:
            raise AssertionError("input key cannot be `None` (input data should always be available)")
        if not isinstance(value, typing.Hashable):
            raise AssertionError("key type must be hashable")
        self._input_key = value

    @property
    def gt_key(self) -> typing.Optional[typing.Hashable]:
        """Returns the key used to fetch groundtruth data tensors from a sample dictionary."""
        return self._gt_key

    @gt_key.setter
    def gt_key(self, value: typing.Optional[typing.Hashable]) -> None:
        """Sets the key used to fetch groundtruth data tensors from a sample dictionary.

        The key can be of any type, as long as it can be used to index a dictionary. Print-
        friendly types (e.g. string) are recommended for debugging. If groundtruth is not
        available through the dataset parsers, this key can be set to ``None``.

        Raises:
            AssertionError: if the key is not ``None`` and is not hashable.
        """
        if value is not None and not isinstance(value, typing.Hashable):
            raise AssertionError("key type must be hashable")
        self._gt_key = value

    @property
    def meta_keys(self) -> typing.List[typing.Hashable]:
        """Returns the list of keys used to carry meta/auxiliary data in samples."""
        return self._meta_keys

    @meta_keys.setter
    def meta_keys(self, value: typing.Optional[typing.Iterable[typing.Hashable]]) -> None:
        """Sets the list of keys used to carry meta/auxiliary data in samples.

        The keys can be of any type, as long as they can be used to index a dictionary.
        Print-friendly types (e.g. string) are recommended for debugging. This list can
        be empty if no extra data is available; ``None`` is treated as an empty list.

        The provided iterable is copied into a new list, so one-shot iterators and tuples
        are supported and later stored as a regular list.

        Raises:
            AssertionError: if the value is not an iterable, is a string, or contains a
                key that is ``None`` or not hashable.
        """
        if value is not None and not isinstance(value, typing.Iterable):
            raise AssertionError("meta keys should be an iterable")
        if isinstance(value, str):
            # a lone string would otherwise be split into single-character keys
            raise AssertionError("meta keys should be list/tuple/array/... of strings")
        # materialize once: validating a generator in place would exhaust it, and storing
        # a tuple would break list concatenation when merging tasks
        value = [] if value is None else list(value)
        if not all(v is not None and isinstance(v, typing.Hashable) for v in value):
            raise AssertionError("all meta key types must be hashable")
        self._meta_keys = value

    @property
    def keys(self) -> typing.List[typing.Hashable]:
        """Returns a list of all keys used to carry tensors and metadata in samples.

        Duplicates and ``None`` are removed; the order is input key, groundtruth key, then
        meta keys in their stored order.
        """
        candidates = (self.input_key, self.gt_key, *self.meta_keys)
        # dict.fromkeys deduplicates while keeping first-seen order (deterministic output)
        return list(dict.fromkeys(k for k in candidates if k is not None))

    def check_compat(
        self,
        task: "Task",
        exact: bool = False,
    ) -> bool:
        """Returns whether the current task is compatible with the provided one or not.

        This is useful for sanity-checking, and to see if the inputs/outputs of two models
        are compatible. It should be overridden in derived classes to specialize the compatibility
        verification. If ``exact = True``, all fields will be checked for exact compatibility.

        Args:
            task: the task to compare against; only instances of the base ``Task`` type itself
                (not of derived types) can be reported as compatible by this implementation.
            exact: if ``True``, also require identical groundtruth keys and meta key sets.

        Returns:
            ``True`` if the tasks are compatible, ``False`` otherwise.
        """
        if type(task) is not Task:
            return False
        if self.input_key != task.input_key:
            return False
        # an undefined groundtruth key on either side is compatible with anything
        if self.gt_key is not None and task.gt_key is not None and self.gt_key != task.gt_key:
            return False
        if exact:
            return set(self.meta_keys) == set(task.meta_keys) and self.gt_key == task.gt_key
        return True

    def get_compat(self, task: "Task") -> "Task":
        """Returns a task instance compatible with the current task and the given one.

        The returned task uses the input and groundtruth keys of the current task, and the
        union of both tasks' meta keys (duplicates removed, first-seen order kept).

        Raises:
            AssertionError: if ``task`` is not an instance of the base ``Task`` type itself,
                or if it is not compatible with the current task.
        """
        if type(task) is not Task:
            raise AssertionError(f"cannot create compatible task from types '{type(task)}' and '{type(self)}'")
        if not self.check_compat(task):
            raise AssertionError(f"cannot create compatible task between:\n\t{str(self)}\n\t{str(task)}")
        merged_meta_keys = list(dict.fromkeys([*self.meta_keys, *task.meta_keys]))
        return Task(input_key=self.input_key, gt_key=self.gt_key, meta_keys=merged_meta_keys)

    def __repr__(self) -> str:
        """Creates a print-friendly representation of an abstract task.

        Note that this representation might also be used to check the compatibility of tasks
        without importing the whole framework. Therefore, it should contain all the necessary
        information about the task. The name of the parameters herein should also match the
        argument names given to the constructor in case we need to recreate a task object from
        this string.
        """
        return self.__class__.__module__ + "." + self.__class__.__qualname__ + \
            f"(input_key={repr(self.input_key)}, gt_key={repr(self.gt_key)}, meta_keys={repr(self.meta_keys)})"