Source code for thelper.cli

#!/usr/bin/env python
"""
Command-line module, for use with a ``__main__`` entrypoint.

This module contains the primary functions used to create or resume a training session, to start a
visualization session, or to start an annotation session. The basic argument that needs to be provided
by the user to create any kind of session is a configuration dictionary. For sessions that produce
outputs, the path to a directory where to save the data is also needed.
"""

import argparse
import logging
import os
from typing import Any, Union

import torch
import tqdm

import thelper

# Accepted strategies for resolving a mismatch between the task stored in a
# checkpoint and the task built from the session configuration.
TASK_COMPAT_CHOICES = frozenset(("old", "new", "compat"))


def create_session(config, save_dir):
    """Creates a session to train a model.

    All generated outputs (model checkpoints and logs) will be saved in a directory named after the
    session (the name itself is specified in ``config``), and located in ``save_dir``.

    Args:
        config: a dictionary that provides all required data configuration and trainer parameters; see
            :class:`thelper.train.base.Trainer` and :func:`thelper.data.utils.create_loaders` for more
            information. Here, it is only expected to contain a ``name`` field that specifies the name
            of the session.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories.

    Returns:
        The ``outputs`` attribute of the trainer once training (or evaluation) has completed.

    Raises:
        AssertionError: if no session name could be obtained from ``config``.

    .. seealso::
        | :class:`thelper.train.base.Trainer`
    """
    logger = thelper.utils.get_func_logger()
    session_name = thelper.utils.get_config_session_name(config)
    if session_name is None:
        # explicit raise (rather than ``assert``) so the check is not stripped under ``python -O``
        raise AssertionError("config missing 'name' field required for output directory")
    logger.info("creating new training session '%s'...", session_name)
    thelper.utils.setup_globals(config)
    save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
    logger.debug("session will be saved at '%s'", os.path.abspath(save_dir).replace("\\", "/"))
    task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config, save_dir)
    model = thelper.nn.create_model(config, task, save_dir=save_dir)
    loaders = (train_loader, valid_loader, test_loader)
    trainer = thelper.train.create_trainer(session_name, save_dir, config, model, task, loaders)
    logger.debug("starting trainer")
    if train_loader:
        trainer.train()
    else:
        # without a training loader, the session can only evaluate the model
        trainer.eval()
    logger.debug("all done")
    thelper.utils.report_orion_results(trainer)
    return trainer.outputs
def resume_session(ckptdata, save_dir, config=None, eval_only=False, task_compat=None):
    """Resumes a previously created training session.

    Since the saved checkpoints contain the original session's configuration, the ``config`` argument
    can be set to ``None`` if the session should simply pick up where it was interrupted. Otherwise,
    the ``config`` argument can be set to a new configuration that will override the older one. This is
    useful when fine-tuning a model, or when testing on a new dataset.

    .. warning::
        If a session is resumed with an overriding configuration, the user must make sure that the
        inputs/outputs of the older model are compatible with the new parameters. For example, with
        classifiers, this means that the number of classes parsed by the dataset (and thus to be
        predicted by the model) should remain the same. This is a limitation of the framework that
        should be addressed in a future update.

    .. warning::
        A resumed session will not be compatible with its original RNG states if the number of workers
        used is changed. To get 100% reproducible results, make sure you run with the same worker count.

    Args:
        ckptdata: raw checkpoint data loaded via ``torch.load()``; it will be parsed by the various
            parts of the framework that need to reload their previous state. Note that its ``optimizer``
            and ``scheduler`` entries are reset to ``None`` in place when the session is resumed with a
            task other than the checkpoint's one.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories.
        config: a dictionary that provides all required data configuration and trainer parameters; see
            :class:`thelper.train.base.Trainer` and :func:`thelper.data.utils.create_loaders` for more
            information. Here, it is only expected to contain a ``name`` field that specifies the name
            of the session. If empty or ``None``, the configuration stored in ``ckptdata`` is used.
        eval_only: specifies whether training should be resumed or the model should only be evaluated.
        task_compat: specifies how to handle discrepancy between old task from checkpoint and new task
            from config; values outside of ``TASK_COMPAT_CHOICES`` are ignored, in which case the
            configuration (and ultimately the user, via a prompt) decides.

    Returns:
        The ``outputs`` attribute of the trainer once training (or evaluation) has completed.

    Raises:
        AssertionError: if the checkpoint data is missing or incomplete, if the session name cannot be
            found, or if no valid task could be selected.

    .. seealso::
        | :class:`thelper.train.base.Trainer`
    """
    logger = thelper.utils.get_func_logger()
    if ckptdata is None or not ckptdata:
        raise AssertionError("must provide valid checkpoint data to resume a session!")
    if not config:
        if "config" not in ckptdata or not ckptdata["config"]:
            raise AssertionError("checkpoint data missing 'config' field")
        config = ckptdata["config"]
    session_name = thelper.utils.get_config_session_name(config)
    if session_name is None:
        # explicit raise (rather than ``assert``) so the check is not stripped under ``python -O``
        raise AssertionError("config missing 'name' field required for output directory")
    logger.info("loading training session '%s' objects...", session_name)
    thelper.utils.setup_globals(config)
    save_dir = thelper.utils.get_save_dir(save_dir, session_name, config, resume=True)
    logger.debug("session will be saved at '%s'", os.path.abspath(save_dir).replace("\\", "/"))
    new_task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config, save_dir)
    if "task" not in ckptdata or not ckptdata["task"] or \
            not isinstance(ckptdata["task"], (thelper.tasks.Task, str)):
        raise AssertionError("invalid checkpoint, cannot reload previous model task")
    old_task = ckptdata["task"]
    if isinstance(old_task, str):
        # tasks may be stored in their string form, in which case they must be re-instantiated
        old_task = thelper.tasks.create_task(old_task)
    if old_task.check_compat(new_task, exact=True):
        task = new_task
    else:
        compat_task = old_task.get_compat(new_task) if old_task.check_compat(new_task) else None
        # resolution priority: explicit argument, then configuration, then interactive prompt
        if task_compat in TASK_COMPAT_CHOICES:
            logger.warning("discrepancy between old task from checkpoint and new task from config resolved by " +
                           f"input argument: task_compat={task_compat}")
            task_compat_mode = task_compat
        else:
            loaders_config = thelper.utils.get_key(["data_config", "loaders"], config)
            task_compat_mode = thelper.utils.get_key_def("task_compat_mode", loaders_config, default=None)
            if task_compat_mode in TASK_COMPAT_CHOICES:
                logger.warning("discrepancy between old task from checkpoint and new task from config resolved by " +
                               f"config: task_compat_mode={task_compat_mode}")
        if task_compat_mode not in TASK_COMPAT_CHOICES:
            task_compat_mode = thelper.utils.query_string(
                "Found discrepancy between old task from checkpoint and new task from config; " +
                "which one would you like to resume the session with?\n" +
                f"\told: {str(old_task)}\n\tnew: {str(new_task)}\n" +
                (f"\tcompat: {str(compat_task)}\n\n" if compat_task is not None else "\n") +
                "WARNING: if resuming with new or compat, some weights might be discarded!",
                choices=list(TASK_COMPAT_CHOICES))
        if task_compat_mode == "old":
            task = old_task
        elif task_compat_mode == "new":
            task = new_task
        else:
            task = compat_task
        if task_compat_mode != "old":
            # saved optimizer state might cause issues with mismatched tasks, let's get rid of it
            logger.warning("dumping optimizer state to avoid issues when resuming with modified task")
            ckptdata["optimizer"], ckptdata["scheduler"] = None, None
    if task is None:
        # happens when 'compat' is selected while the two tasks have no compatible form
        raise AssertionError("invalid task")
    model = thelper.nn.create_model(config, task, save_dir=save_dir, ckptdata=ckptdata)
    loaders = (None if eval_only else train_loader, valid_loader, test_loader)
    trainer = thelper.train.create_trainer(session_name, save_dir, config, model, task, loaders, ckptdata=ckptdata)
    if eval_only:
        logger.info("evaluating session '%s' checkpoint @ epoch %d", trainer.name, trainer.current_epoch)
        trainer.eval()
    else:
        logger.info("resuming training session '%s' @ epoch %d", trainer.name, trainer.current_epoch)
        trainer.train()
    logger.debug("all done")
    thelper.utils.report_orion_results(trainer)
    return trainer.outputs
def visualize_data(config):
    """Displays the images used in a training session.

    This mode does not generate any output, and is only used to visualize the (transformed) images used
    in a training session. This is useful to debug the data augmentation and base transformation
    pipelines and make sure the modified images are valid. It does not attempt to load a model or
    instantiate a trainer, meaning the related fields are not required inside ``config``.

    If the configuration dictionary includes a 'loaders' field, it will be parsed and used. Otherwise,
    if only a 'datasets' field is available, basic loaders will be instantiated to load the data. The
    'loaders' field can also be ignored if 'ignore_loaders' is found within the 'viz' section of the
    config and set to ``True``. Each minibatch will be displayed via pyplot or OpenCV. The display will
    block and wait for user input, unless 'block' is set within the 'viz' section's 'kwargs' config as
    ``False``.

    Args:
        config: a dictionary that provides all required data configuration parameters; see
            :func:`thelper.data.utils.create_loaders` for more information.

    Raises:
        AssertionError: if the 'viz' section or its kwargs are not dictionaries, or if a loader is
            named 'quit' (that name is reserved for leaving the selection loop).

    .. seealso::
        | :func:`thelper.data.utils.create_loaders`
        | :func:`thelper.data.utils.create_parsers`
    """
    logger = thelper.utils.get_func_logger()
    logger.info("creating visualization session...")
    thelper.utils.setup_globals(config)
    viz_config = thelper.utils.get_key_def("viz", config, default={})
    if not isinstance(viz_config, dict):
        raise AssertionError("unexpected viz config type")
    ignore_loaders = thelper.utils.get_key_def("ignore_loaders", viz_config, default=False)
    viz_kwargs = thelper.utils.get_key_def(["params", "kwargs"], viz_config, default={})
    if not isinstance(viz_kwargs, dict):
        raise AssertionError("unexpected viz kwargs type")
    if thelper.utils.get_key_def(["data_config", "loaders"], config, default=None) is None or ignore_loaders:
        datasets, task = thelper.data.create_parsers(config)
        loader_map = {dataset_name: thelper.data.DataLoader(dataset) for dataset_name, dataset in datasets.items()}
        # we assume no transforms were done in the parser, and images are given as read by opencv
        viz_kwargs["ch_transpose"] = thelper.utils.get_key_def("ch_transpose", viz_kwargs, False)
        viz_kwargs["flip_bgr"] = thelper.utils.get_key_def("flip_bgr", viz_kwargs, False)
    else:
        task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config)
        loader_map = {"train": train_loader, "valid": valid_loader, "test": test_loader}
        # we assume transforms were done in the loader, and images are given as expected by pytorch
        viz_kwargs["ch_transpose"] = thelper.utils.get_key_def("ch_transpose", viz_kwargs, True)
        viz_kwargs["flip_bgr"] = thelper.utils.get_key_def("flip_bgr", viz_kwargs, False)
    redraw = None  # handle returned by the drawing call, fed back into the next call
    viz_kwargs["block"] = thelper.utils.get_key_def("block", viz_kwargs, default=True)
    if "quit" in loader_map:
        # explicit raise (rather than ``assert``) so the check is not stripped under ``python -O``
        raise AssertionError("loader name 'quit' is reserved to exit the visualization session")
    choices = list(loader_map.keys()) + ["quit"]
    while True:
        choice = thelper.utils.query_string("Which loader would you like to visualize?", choices=choices)
        if choice == "quit":
            break
        loader = loader_map[choice]
        if loader is None:
            logger.info("loader '%s' is empty", choice)
            continue
        batch_count = len(loader)
        logger.info("initializing loader '%s' with %d batches...", choice, batch_count)
        for sample in tqdm.tqdm(loader):
            # local names avoid shadowing the ``input`` builtin; keyword names are kept for the callee
            input_data = sample[task.input_key]
            target_data = sample[task.gt_key] if task.gt_key in sample else None
            redraw = thelper.draw.draw(task=task, input=input_data, target=target_data, redraw=redraw, **viz_kwargs)
    logger.info("all done")
def annotate_data(config, save_dir):
    """Launches an annotation session for a dataset using a specialized GUI tool.

    Note that the annotation type must be supported by the GUI tool. The annotations created by the
    user during the session will be saved in the session directory.

    Args:
        config: a dictionary that provides all required dataset and GUI tool configuration parameters;
            see :func:`thelper.data.utils.create_parsers` and :func:`thelper.gui.utils.create_annotator`
            for more information.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories.

    Raises:
        AssertionError: if no session name could be obtained from ``config``.

    .. seealso::
        | :func:`thelper.gui.annotators.Annotator`
        | :func:`thelper.gui.annotators.ImageSegmentAnnotator`
    """
    # import gui here since it imports packages that will cause errors in CLI-only environments
    import thelper.gui
    logger = thelper.utils.get_func_logger()
    session_name = thelper.utils.get_config_session_name(config)
    if session_name is None:
        # explicit raise (rather than ``assert``) so the check is not stripped under ``python -O``
        raise AssertionError("config missing 'name' field required for output directory")
    logger.info("creating annotation session '%s'...", session_name)
    thelper.utils.setup_globals(config)
    save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
    logger.debug("session will be saved at '%s'", os.path.abspath(save_dir).replace("\\", "/"))
    datasets, _ = thelper.data.create_parsers(config)  # the task returned by the parsers is not needed here
    annotator = thelper.gui.create_annotator(session_name, save_dir, config, datasets)
    logger.debug("starting annotator")
    annotator.run()
    logger.debug("all done")
def split_data(config, save_dir):
    """Launches a dataset splitting session.

    This mode will generate an HDF5 archive that contains the split datasets defined in the session
    configuration file. This archive can then be reused in a new training session to guarantee a fixed
    distribution of training, validation, and testing samples. It can also be used outside the framework
    in order to reproduce an experiment.

    The configuration dictionary must minimally contain two sections: 'datasets' and 'loaders'. A third
    section, 'split', can be used to provide settings regarding the archive packing and compression
    approaches to use ('compression' must be a dictionary; 'archive_name' defaults to the session name
    followed by ``.hdf5``).

    The HDF5 archive will be saved in the session's output directory.

    Args:
        config: a dictionary that provides all required data configuration parameters; see
            :func:`thelper.data.utils.create_loaders` for more information.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories.

    Raises:
        AssertionError: if no session name could be obtained from ``config``, or if the 'split' section
            or its 'compression' field is not a dictionary.

    .. seealso::
        | :func:`thelper.data.utils.create_loaders`
        | :func:`thelper.data.utils.create_hdf5`
        | :class:`thelper.data.parsers.HDF5Dataset`
    """
    logger = thelper.utils.get_func_logger()
    session_name = thelper.utils.get_config_session_name(config)
    if session_name is None:
        # explicit raise (rather than ``assert``) so the check is not stripped under ``python -O``
        raise AssertionError("config missing 'name' field required for output directory")
    split_config = thelper.utils.get_key_def("split", config, default={})
    if not isinstance(split_config, dict):
        raise AssertionError("unexpected split config type")
    compression = thelper.utils.get_key_def("compression", split_config, default={})
    if not isinstance(compression, dict):
        raise AssertionError("compression params should be given as dictionary")
    archive_name = thelper.utils.get_key_def("archive_name", split_config, default=(session_name + ".hdf5"))
    logger.info("creating new splitting session '%s'...", session_name)
    thelper.utils.setup_globals(config)
    save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
    logger.debug("session will be saved at '%s'", os.path.abspath(save_dir).replace("\\", "/"))
    task, train_loader, valid_loader, test_loader = thelper.data.create_loaders(config, save_dir)
    archive_path = os.path.join(save_dir, archive_name)
    thelper.data.create_hdf5(archive_path, task, train_loader, valid_loader, test_loader, compression, config)
    logger.debug("all done")
def inference_session(config, save_dir=None, ckpt_path=None):
    """Executes an inference session on samples with a trained model checkpoint.

    In order to run inference, a model is mandatory and therefore expected to be provided in the
    configuration. Similarly, a list of input sample file paths are expected for which to run inference
    on. These inputs can provide additional data according to how they are being parsed by lower level
    operations.

    The session will save the current configuration as `config-infer.json` and the employed model's
    training configuration as `config-train.json`. Other outputs depend on the specific implementation
    of the session runner.

    Note that ``config`` is updated in place: the 'model' section may be overridden by ``ckpt_path``,
    missing 'loaders' fields are filled with inference defaults, and the task of every dataset is
    replaced by the string form of the checkpoint's task.

    Args:
        config: a dictionary that provides all required data configuration.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories. If not provided, will use the best value extracted from either
            the configuration path or the configuration dictionary itself.
        ckpt_path: explicit checkpoint path to use for loading a model to execute inference. It is only
            used if it exists on disk. Otherwise look for a model definition in the configuration.

    Raises:
        AssertionError: if no session name could be obtained from ``config``, if the checkpoint file is
            missing, or if the checkpoint does not contain a usable task.
        RuntimeError: if the model, dataset or test loader definitions are missing or unusable.

    .. seealso::
        | :func:`thelper.data.geo.utils.prepare_raster_metadata`
        | :func:`thelper.data.geo.utils.sliding_window_inference`
    """
    import thelper.data.geo
    logger = thelper.utils.get_func_logger()
    session_name = thelper.utils.get_config_session_name(config)
    if session_name is None:
        # the name is required to build the output directory below; fail early with the same
        # message as the other session types rather than with an obscure path-joining error
        raise AssertionError("config missing 'name' field required for output directory")
    thelper.utils.setup_globals(config)

    def is_defined_dict(container, section):
        # type: (thelper.typedefs.ConfigDict, str) -> bool
        """Returns whether ``section`` exists in ``container`` as a non-empty dictionary."""
        return section in container and isinstance(container[section], dict) and bool(len(container[section]))

    if ckpt_path is not None and os.path.exists(ckpt_path):
        logger.warning("Overriding model definition with explicit checkpoint path argument")
        config["model"] = {"ckpt_path": ckpt_path, "params": {"pretrained": True}}
    if not is_defined_dict(config, "model") or "ckpt_path" not in config["model"]:
        raise RuntimeError("Missing a model checkpoint definition with which to run inference")
    ckpt_path = config["model"]["ckpt_path"]
    if not is_defined_dict(config["model"], "params") or "pretrained" not in config["model"]["params"]:
        logger.warning("Adding model pretrained flag to definition")
        config["model"].setdefault("params", {})
        config["model"]["params"].setdefault("pretrained", True)
    if not config["model"]["params"]["pretrained"]:
        raise RuntimeError("Model checkpoint was explicitly defined as not pretrained. It must be for inference.")
    if not os.path.exists(ckpt_path):
        logger.fatal("Model not found: %s", ckpt_path)
        raise AssertionError("Model checkpoint missing to run inference")
    ckptdata = thelper.utils.load_checkpoint(ckpt_path, map_location=None, always_load_latest=False)
    if "task" not in ckptdata or not isinstance(ckptdata["task"], (thelper.tasks.Task, str)):
        raise AssertionError("invalid checkpoint, cannot reload model task")
    task = ckptdata["task"]
    task = thelper.tasks.create_task(task)  # make sure the task is instantiated if string
    if save_dir is None:
        save_dir = thelper.utils.get_checkpoint_session_root(ckpt_path)
    save_dir = os.path.join(save_dir, session_name)
    # 'exist_ok' avoids failing when the directory appears between a check and its creation
    os.makedirs(save_dir, exist_ok=True)

    if not is_defined_dict(config, "datasets"):
        raise RuntimeError("Missing at least one dataset definition in configuration.")
    logger.info("Updating configuration with loaders inputs...")
    if not is_defined_dict(config, "loaders"):
        logger.warning("Missing loaders definition in configuration. Using default.")
        config["loaders"] = dict()
    loaders = config["loaders"]
    if "batch_size" not in loaders or not (isinstance(loaders["batch_size"], int) and loaders["batch_size"] > 0):
        logger.warning("Missing loader 'batch_size' definition in configuration. Using default (batch_size=1).")
        loaders["batch_size"] = 1
    if not is_defined_dict(loaders, "test_split"):
        logger.warning("Missing loader 'test_split' definition in configuration. Using all found datasets.")
        loaders["test_split"] = {dataset_name: 1.0 for dataset_name in config["datasets"].keys()}
    if "task_compat_mode" not in loaders:
        logger.warning("Missing loader 'task_compat_mode' definition in configuration. Using compatible mode.")
        loaders["task_compat_mode"] = "compat"
    if "shuffle" not in loaders:
        logger.warning("Missing loader 'shuffle' definition in configuration. Using 'false' for inference.")
        loaders["shuffle"] = False

    # Because loaders require a 'task' but that we are doing inference instead of training, all class-names are
    # enforced by the model's task definition (cannot adapt). Override datasets' task to make sure of this and to
    # avoid having the user needing to plug us a dummy task just to meet the loader requirement.
    all_datasets = config["datasets"]
    for dataset_name in all_datasets:
        dataset = all_datasets[dataset_name]
        if is_defined_dict(dataset, "task"):
            logger.warning("Overriding task of dataset '%s' with model's task", dataset_name)
        all_datasets[dataset_name]["task"] = str(task)  # use str to allow both json dump and parsing by dataset loader

    _, _, _, test_loader = thelper.data.create_loaders(config, save_dir)
    if not test_loader or not len(test_loader):
        raise RuntimeError("Could not define a test loader for model inference")
    model = thelper.nn.create_model(config, task, save_dir=save_dir, ckptdata=ckptdata)
    loaders = (None, None, test_loader)

    # Avoid passing any checkpoint data so that any argument don't get incorrectly used by the session runner.
    # Since we only call test/eval, these values don't matter but they still get generated otherwise and this leads
    # to weird internal values such as test starting at 'current_epoch' == 'best_epoch' from checkpoint training or
    # optimizers being instantiated.
    tester = thelper.infer.create_tester(session_name, save_dir, config, model, task, loaders)  # ckptdata=ckptdata)

    config_file = "config-train.json"
    config_file_path = os.path.join(save_dir, config_file)
    train_config = ckptdata["config"]
    logger.debug("Writing employed train config: [%s]", config_file_path)
    thelper.utils.save_config(train_config, config_file_path, force_convert=True)
    config_name = "config-infer.json"
    config_file_path = os.path.join(save_dir, config_name)
    logger.debug("Writing employed infer config: [%s]", config_file_path)
    thelper.utils.save_config(config, config_file_path, force_convert=True)

    logger.info("Starting inference.")
    tester.test()
    logger.info("Inference completed.")
def export_model(config, save_dir):
    """Launches a model exportation session.

    This function will export a model defined via a configuration file into a new checkpoint that can
    be loaded elsewhere. The model can be built using the framework, or provided via its type,
    construction parameters, and weights. Its exported format will be compatible with the framework, and
    may also be an optimized/compiled version obtained using PyTorch's JIT tracer.

    The configuration dictionary must minimally contain a 'model' section that provides details on the
    model to be exported. A section named 'export' can be used to provide settings regarding the
    exportation approaches to use, and the task interface to save with the model. If a task is not
    explicitly defined in the 'export' section, the session configuration will be parsed for a
    'datasets' section that can be used to define it. Otherwise, it must be provided through the model.

    The exported checkpoint containing the model will be saved in the session's output directory.

    Args:
        config: a dictionary that provides all required data configuration parameters; see
            :func:`thelper.nn.utils.create_model` for more information.
        save_dir: the path to the root directory where the session directory should be saved. Note that
            this is not the path to the session directory itself, but its parent, which may also contain
            other session directories.

    Raises:
        AssertionError: if no session name could be obtained from ``config``, if the 'export' section
            is not a dictionary, or if no task could be obtained.

    .. seealso::
        | :func:`thelper.nn.utils.create_model`
    """
    logger = thelper.utils.get_func_logger()
    session_name = thelper.utils.get_config_session_name(config)
    if session_name is None:
        # explicit raise (rather than ``assert``) so the check is not stripped under ``python -O``
        raise AssertionError("config missing 'name' field required for output directory")
    export_config = thelper.utils.get_key_def("export", config, default={})
    if not isinstance(export_config, dict):
        raise AssertionError("unexpected export config type")
    ckpt_name = thelper.utils.get_key_def("ckpt_name", export_config, default=(session_name + ".export.pth"))
    trace_name = thelper.utils.get_key_def("trace_name", export_config, default=(session_name + ".trace.zip"))
    save_raw = thelper.utils.get_key_def("save_raw", export_config, default=True)
    trace_input = thelper.utils.get_key_def("trace_input", export_config, default=None)
    task = thelper.utils.get_key_def("task", export_config, default=None)
    if isinstance(task, (str, dict)):
        task = thelper.tasks.create_task(task)
    if task is None and "datasets" in config:
        _, task = thelper.data.create_parsers(config)  # try to load via datasets...
    if task is None:
        raise AssertionError("could not get proper task object from export config or data parsers")
    if isinstance(trace_input, str):
        # NOTE(review): evaluates an arbitrary expression taken from the configuration; this is only
        # safe with trusted configuration files and should be revisited if configs can be external
        trace_input = eval(trace_input)
    logger.info("exporting model '%s'...", session_name)
    thelper.utils.setup_globals(config)
    save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
    logger.debug("exported checkpoint will be saved at '%s'", os.path.abspath(save_dir).replace("\\", "/"))
    model = thelper.nn.create_model(config, task, save_dir=save_dir)
    log_stamp = thelper.utils.get_log_stamp()
    model_type = model.get_name()
    model_params = model.config if model.config else {}
    # the saved state below should be kept compatible with the one in thelper.train.base._save
    export_state = {
        "name": session_name,
        "source": log_stamp,
        "git_sha1": thelper.utils.get_git_stamp(),
        "version": thelper.__version__,
        "task": str(task) if save_raw else task,
        "model_type": model_type,
        "model_params": model_params,
        "config": config
    }
    if trace_input is not None:
        trace_path = os.path.join(save_dir, trace_name)
        torch.jit.trace(model, trace_input).save(trace_path)
        export_state["model"] = trace_name  # will be loaded in thelper.utils.load_checkpoint
    else:
        export_state["model"] = model.state_dict() if save_raw else model
    torch.save(export_state, os.path.join(save_dir, ckpt_name))
    logger.debug("all done")
def make_argparser():
    # type: () -> argparse.ArgumentParser
    """Creates the (default) argument parser to use for the main entrypoint.

    The argument parser will contain different "operating modes" that dictate the high-level behavior
    of the CLI. This function may be modified in branches of the framework to add project-specific
    features.

    The selected operating mode is stored under the ``mode`` attribute of the parsed arguments. The
    task compatibility choices of the ``resume`` mode are listed in sorted order so that the generated
    help and error messages are identical from one run to the next.

    Returns:
        The argument parser, with one sub-parser per operating mode.
    """
    ap = argparse.ArgumentParser(description='thelper model trainer application')
    ap.add_argument("--version", default=False, action="store_true",
                    help="prints the version of the library and exits")
    ap.add_argument("-l", "--log", default=None, type=str,
                    help="path to the top-level log file (default: None)")
    ap.add_argument("-v", "--verbose", action="count", default=3,
                    help="set logging terminal verbosity level (additive)")
    ap.add_argument("--silent", action="store_true", default=False,
                    help="deactivates all console logging activities")
    ap.add_argument("--force-stdout", action="store_true", default=False,
                    help="force logging output to stdout instead of stderr")
    subparsers = ap.add_subparsers(title="Operating mode", dest="mode")

    new_ap = subparsers.add_parser("new", help="creates a new session from a config file")
    new_ap.add_argument("-c", "--config", required=True, type=str,
                        help="path to the session configuration file")
    new_ap.add_argument("-d", "--save-dir", required=True, type=str,
                        help="path to the session output root directory")

    cl_new_ap = subparsers.add_parser("cl_new", help="creates a new session from a config file for the cluster")
    cl_new_ap.add_argument("-c", "--config", required=True, type=str,
                           help="path to the session configuration file")
    cl_new_ap.add_argument("-d", "--save-dir", required=True, type=str,
                           help="path to the session output root directory")

    resume_ap = subparsers.add_parser("resume", help="resume a session from a checkpoint file")
    resume_ap.add_argument("--ckpt-path", default=None, type=str,
                           help="path to the checkpoint (or directory) to reload")
    resume_ap.add_argument("-d", "--save-dir", default=None, type=str,
                           help="path to the session output root directory")
    resume_ap.add_argument("-m", "--map-location", default=None,
                           help="map location for loading data (default=None)")
    resume_ap.add_argument("-c", "--override-config", default=None,
                           help="override config file path (default=None)")
    resume_ap.add_argument("-e", "--eval-only", default=False, action="store_true",
                           help="only run evaluation pass (valid+test)")
    # a sorted list (instead of the raw frozenset) keeps the order shown in usage/error text stable
    resume_ap.add_argument("-t", "--task-compat", default=None, type=str, choices=sorted(TASK_COMPAT_CHOICES),
                           help="task compatibility mode to use to resolve any discrepancy between loaded tasks")

    viz_ap = subparsers.add_parser("viz", help="visualize the loaded data for a training/eval session")
    viz_ap.add_argument("-c", "--config", required=True, type=str,
                        help="path to the session configuration file (or session directory)")

    annot_ap = subparsers.add_parser("annot", help="launches a dataset annotation session with a GUI tool")
    annot_ap.add_argument("-c", "--config", required=True, type=str,
                          help="path to the session configuration file (or session directory)")
    annot_ap.add_argument("-d", "--save-dir", required=True, type=str,
                          help="path to the session output root directory")

    split_ap = subparsers.add_parser("split", help="launches a dataset splitting session from a config file")
    split_ap.add_argument("-c", "--config", required=True, type=str,
                          help="path to the session configuration file (or session directory)")
    split_ap.add_argument("-d", "--save-dir", required=True, type=str,
                          help="path to the session output root directory")

    export_ap = subparsers.add_parser("export", help="launches a model exportation session from a config file")
    export_ap.add_argument("-c", "--config", required=True, type=str,
                           help="path to the session configuration file (or session directory)")
    export_ap.add_argument("-d", "--save-dir", required=True, type=str,
                           help="path to the session output root directory")

    infer_ap = subparsers.add_parser("infer", help="creates a inference session from a config file")
    infer_ap.add_argument("--ckpt-path", type=str,
                          help="path to the checkpoint (or directory) to use for inference "
                               "(otherwise uses model checkpoint from configuration)")
    infer_ap.add_argument("-c", "--config", type=str,
                          help="path to the session configuration file (or session directory)")
    infer_ap.add_argument("-d", "--save-dir", type=str,
                          help="path to the session output root directory")
    return ap
[docs]def setup(args=None, argparser=None): # type: (Any, argparse.ArgumentParser) -> Union[int, argparse.Namespace] """Sets up the argument parser (if not already done externally) and parses the input CLI arguments. This function may return an error code (integer) if the program should exit immediately. Otherwise, it will return the parsed arguments to use in order to redirect the execution flow of the entrypoint. """ argparser = argparser or make_argparser() args = argparser.parse_args(args=args) if args.version: print(thelper.__version__) return 0 if args.mode is None: argparser.print_help() return 1 if args.silent and args.verbose > 0: raise AssertionError("contradicting verbose/silent arguments provided") log_level = logging.INFO if args.verbose < 1 else logging.DEBUG if args.verbose < 2 else logging.NOTSET thelper.utils.init_logger(log_level, args.log, args.force_stdout) return args
def main(args=None, argparser=None):
    """Main entrypoint to use with console applications.

    The command line arguments are parsed through :func:`thelper.cli.setup`, and the execution is then
    dispatched to the session function that matches the selected operating mode. Run with ``--help``
    for information on the available arguments.

    .. warning::
        If you are trying to resume a session that was previously executed using a now unavailable GPU,
        you will have to force the checkpoint data to be loaded on CPU using ``--map-location=cpu`` (or
        using ``-m=cpu``).

    Args:
        args: the list of arguments to parse; forwarded to :func:`thelper.cli.setup`.
        argparser: the argument parser to use; forwarded to :func:`thelper.cli.setup`.

    Returns:
        The integer code produced by :func:`thelper.cli.setup` if the program must exit right away,
        or ``0`` once the selected session has completed.

    .. seealso::
        | :func:`thelper.cli.create_session`
        | :func:`thelper.cli.resume_session`
        | :func:`thelper.cli.visualize_data`
        | :func:`thelper.cli.annotate_data`
        | :func:`thelper.cli.split_data`
        | :func:`thelper.cli.inference_session`
    """
    args = setup(args=args, argparser=argparser)
    if isinstance(args, int):
        # CLI must exit immediately with provided error code
        return args
    mode = args.mode

    if mode in ("new", "cl_new"):
        thelper.logger.debug("parsing config at '%s'" % args.config)
        session_config = thelper.utils.load_config(args.config)
        if mode == "cl_new":
            trainer_section = thelper.utils.get_key_def("trainer", session_config, {})
            if thelper.utils.get_key_def("device", trainer_section, None) is not None:
                raise AssertionError("cannot specify device in config for cluster sessions, it is determined at runtime")
        create_session(session_config, args.save_dir)
        return 0

    if mode == "resume":
        checkpoint = None
        if args.ckpt_path is not None:
            checkpoint = thelper.utils.load_checkpoint(
                args.ckpt_path,
                map_location=args.map_location,
                always_load_latest=(not args.eval_only),
            )
        new_config = None
        if args.override_config:
            thelper.logger.debug(f"parsing override config at: {args.override_config}")
            new_config = thelper.utils.load_config(args.override_config)
        # output root priority: explicit argument, then checkpoint location, then framework default
        out_dir = args.save_dir
        if out_dir is None and args.ckpt_path is not None:
            out_dir = thelper.utils.get_checkpoint_session_root(args.ckpt_path)
        if out_dir is None:
            out_dir = thelper.utils.get_save_dir(out_root=None, dir_name=None, config=new_config)
        resume_session(
            checkpoint,
            out_dir,
            config=new_config,
            eval_only=args.eval_only,
            task_compat=args.task_compat,
        )
        return 0

    if mode == "infer":
        thelper.logger.debug(f"parsing config at: {args.config}")
        session_config = thelper.utils.load_config(args.config)
        inference_session(session_config, save_dir=args.save_dir, ckpt_path=args.ckpt_path)
        return 0

    # remaining modes all start from a session configuration file
    thelper.logger.debug("parsing config at '%s'" % args.config)
    session_config = thelper.utils.load_config(args.config)
    if mode == "viz":
        visualize_data(session_config)
    elif mode == "annot":
        annotate_data(session_config, args.save_dir)
    elif mode == "export":
        export_model(session_config, args.save_dir)
    else:
        # any other mode (i.e. "split") falls back to the dataset splitting session
        split_data(session_config, args.save_dir)
    return 0
if __name__ == "__main__":
    # forward the status code returned by ``main`` to the calling shell instead of dropping it
    raise SystemExit(main())