Source code for thelper.infer.base

from abc import abstractmethod
from typing import TYPE_CHECKING

import thelper.utils
from thelper.train.base import Trainer

if TYPE_CHECKING:
    from typing import AnyStr, Callable, Optional, Type  # noqa: F401
    import thelper.typedefs  # noqa: F401


class Tester(Trainer):
    """Base interface of a session runner dedicated to testing (inference).

    This class mostly delegates to the existing :class:`Trainer` implementation, but restricts
    usage to the 'eval' methods so that 'train' operations cannot be invoked by mistake.

    .. seealso::
        | :class:`thelper.train.base.Trainer`
    """

    def __init__(self,
                 session_name,  # type: AnyStr
                 session_dir,  # type: AnyStr
                 model,  # type: thelper.typedefs.ModelType
                 task,  # type: thelper.tasks.Task
                 loaders,  # type: thelper.typedefs.MultiLoaderType
                 config,  # type: thelper.typedefs.ConfigDict
                 ckptdata=None  # type: Optional[thelper.typedefs.CheckpointContentType]
                 ):
        """Prepares the tester configuration and forwards everything to the parent constructor.

        The runner settings are looked up in ``config`` through ``thelper.utils.get_key_def`` using
        the keys ``"runner"`` and ``"tester"``; an empty dictionary is used when that lookup yields
        nothing (or a falsy value). Note that ``config`` is modified in place: its ``"trainer"``
        entry is always replaced by the runner settings, and its ``"tester"`` entry is filled with
        the same settings when it is missing.

        Args:
            session_name: forwarded untouched to the parent constructor.
            session_dir: forwarded untouched to the parent constructor.
            model: forwarded untouched to the parent constructor.
            task: forwarded untouched to the parent constructor.
            loaders: forwarded untouched to the parent constructor.
            config: session configuration dictionary (updated in place, see above).
            ckptdata: optional checkpoint content, forwarded by keyword to the parent constructor.
        """
        # NOTE(review): presumably `get_key_def` returns the value of the first listed key it finds
        # in `config` (and None otherwise) -- confirm against `thelper.utils`
        tester_cfg = thelper.utils.get_key_def(["runner", "tester"], config) or {}
        # a single pass is all that inference requires, so fall back to one epoch when unspecified
        tester_cfg.setdefault("epochs", 1)
        # presumably the parent Trainer reads its settings from the "trainer" entry -- confirm
        config["trainer"] = tester_cfg
        # keep an already-present "tester" entry, otherwise expose the same settings under that key
        config.setdefault("tester", tester_cfg)
        super().__init__(session_name, session_dir, model, task, loaders, config, ckptdata=ckptdata)

    def train(self):
        """Always fails: training is not a valid operation for a tester.

        Raises:
            RuntimeError: on every call.
        """
        message = f"Invalid call to 'train' using '{type(self).__name__}' (Tester)"
        raise RuntimeError(message)

    def train_epoch(self, model, epoch, dev, loss, optimizer, loader, metrics, output_path):
        """Always fails: training an epoch is not a valid operation for a tester.

        All arguments are ignored.

        Raises:
            RuntimeError: on every call.
        """
        message = f"Invalid call to 'train_epoch' using '{type(self).__name__}' (Tester)"
        raise RuntimeError(message)

    def test(self):
        """Runs the test session by delegating to ``self.eval()`` and returns its result."""
        outcome = self.eval()
        return outcome

    def test_epoch(self, *args, **kwargs):
        """Alias of :meth:`eval_epoch`; every argument is passed through unchanged."""
        outcome = self.eval_epoch(*args, **kwargs)
        return outcome

    # NOTE(review): `abstractmethod` only prevents instantiation when the class hierarchy uses an
    # ABC metaclass; whether `Trainer` does is not visible from this file -- confirm
    @abstractmethod
    def eval_epoch(self, model, epoch, dev, loader, metrics, output_path):
        """Evaluates the model using the provided objects.

        Must be overridden by derived classes; this base version only raises.

        Args:
            model: the model with which to run inference, already uploaded to the target device(s).
            epoch: the epoch index being evaluated (0-based, and should normally only be 0 for a
                single test pass).
            dev: the target device that tensors should be uploaded to (corresponding to the
                model's device(s)).
            loader: the data loader used to get transformed test samples.
            metrics: the dictionary of metrics/consumers to report inference results (mostly
                loggers and basic report generators in this case, since there should not be ground
                truth labels to validate against).
            output_path: directory where output files should be written, if necessary.

        Raises:
            NotImplementedError: always, when not overridden.
        """
        raise NotImplementedError