# Source code for thelper.infer.utils

from typing import AnyStr, Optional

import thelper.infer.base
import thelper.tasks


def create_tester(session_name,  # type: AnyStr
                  save_dir,  # type: AnyStr
                  config,  # type: thelper.typedefs.ConfigDict
                  model,  # type: thelper.typedefs.ModelType
                  task,  # type: thelper.tasks.Task
                  loaders,  # type: thelper.typedefs.MultiLoaderType
                  ckptdata=None  # type: Optional[thelper.typedefs.CheckpointContentType]
                  ):
    # type: (...) -> thelper.infer.base.Tester
    """Instantiates the tester object based on the type contained in the config dictionary.

    The tester type is expected to be in the configuration dictionary's `tester` field, under the `type` key.
    For backward compatibility, the fields `runner` and `trainer` will also be looked for, in that order. A
    fallback field is only consulted when the preceding key is absent from the configuration: a key that is
    present always takes precedence, even when its value is empty. For more information on the configuration,
    refer to :class:`thelper.train.base.Trainer`.

    The instantiated type must be compatible with the constructor signature of
    :class:`thelper.train.base.Trainer`. The object's constructor will be given, positionally, the session
    name, the save directory, the model, the task, the loaders and the full config dictionary, followed by the
    checkpoint data as the ``ckptdata`` keyword (``None`` if unavailable).

    If the tester type is missing from the runner configuration, it will be automatically deduced based on
    the task object (classification, detection, regression or segmentation).

    Args:
        session_name: name of the session used for printing and to create internal tensorboardX directories.
        save_dir: path to the session directory where logs and checkpoints will be saved.
        config: full configuration dictionary that will be parsed for tester parameters and forwarded to the
            instantiated object.
        model: model to evaluate; should be compatible with :class:`thelper.nn.utils.Module`.
        task: global task interface defining the type of model and goal for the session.
        loaders: a tuple containing the training/validation/test data loaders (a loader can be ``None`` if
            empty).
        ckptdata: raw checkpoint to parse data from when resuming a session (if ``None``, will start from
            scratch).

    Returns:
        The fully-constructed tester object, ready to begin model evaluation.

    Raises:
        AssertionError: if no non-empty dictionary can be found under the `tester`, `runner` or `trainer`
            fields, or if the `type` key is missing and the tester type cannot be deduced from the task.

    .. seealso::
        | :class:`thelper.infer.base.Tester`
    """
    # NOTE:
    #   counter intuitive name 'trainer', but nothing will actually be trained, only to match other thelper modes
    # the nested defaults implement the lookup priority: 'tester' first, then 'runner', then 'trainer'
    runner_config = config.get("tester", config.get("runner", config.get("trainer")))
    if not runner_config or not isinstance(runner_config, dict):
        raise AssertionError("Could not retrieve any session runner definition from configuration")
    if "type" in runner_config:
        # explicit type given by the user: resolve it through the project's class importer
        runner_type = thelper.utils.import_class(runner_config["type"])
    # no explicit type: deduce the tester class from the kind of task; the order of the checks below is
    # kept as-is since it decides the outcome if a task class derives from another one
    # NOTE(review): presumably Detection must be tested before Regression -- confirm against thelper.tasks
    elif isinstance(task, thelper.tasks.Classification):
        runner_type = thelper.infer.ImageClassifTester
    elif isinstance(task, thelper.tasks.Detection):
        runner_type = thelper.infer.ObjDetectTester
    elif isinstance(task, thelper.tasks.Regression):
        runner_type = thelper.infer.RegressionTester
    elif isinstance(task, thelper.tasks.Segmentation):
        runner_type = thelper.infer.ImageSegmTester
    else:
        raise AssertionError(f"unknown trainer type required for task '{str(task)}'")
    return runner_type(session_name, save_dir, model, task, loaders, config, ckptdata=ckptdata)