Source code for thelper.session.base

import functools
import json
import logging
import os
import pickle
import platform
import random
import time
from copy import deepcopy
from typing import Any, AnyStr, Optional

import cv2 as cv
import numpy as np
import torch

import thelper.data
import thelper.nn
import thelper.optim
import thelper.tasks
import thelper.typedefs
import thelper.utils
import thelper.viz

logger = logging.getLogger(__name__)


class SessionRunner:
    """Abstract session runner interface that defines basic session i/o and setup operations.

    This class offers the most basic methods that can be employed by more specialized training or
    inference sessions. By itself, it doesn't actually run anything.

    Attributes:
        checkpoint_dir: session checkpoint output directory (located within the 'session directory').
        config: session configuration dictionary holding all original settings, including trainer configuration.
        devices: list of (cuda) device IDs to upload the model/tensors to; can be empty if only the CPU is available.
        epochs: number of epochs to train the model for.
        logger: used to output debug/warning/error messages to session log.
        model: reference to the model being trained or used for evaluation/prediction.
        monitor: name of the training/validation metric that should be monitored for model improvement.
        name: name of the session, used for printing and creating log folders.
        optimization_config: dictionary of optim-related parameters, parsed at training time.
        output_paths: map of session output paths where training/evaluation results should be saved.
        save_freq: frequency of checkpoint saves while training (i.e. save every X epochs).
        save_raw: specifies whether to save raw types or thelper objects in checkpoints.
        skip_eval_iter: number of evaluation iterations to skip (useful for resuming a session).
        skip_tbx_histograms: flag used to skip the generation of graph histograms in tbx (useful for large models).
        task: reference to the object used to specialize the model and that holds task metainformation.
        tbx_histogram_freq: frequency of tbx histogram saves while training (i.e. save every X epochs).
        use_tbx: defines whether to use tensorboardX writers for logging or not.
        writers: map of tbx writers used to save training/evaluation events.

    .. seealso::
        | :class:`thelper.train.base.Trainer`
    """
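    # Illustrative note (derived from ``_load_devices``/``_upload_model`` below): the 'device'/'devices'
    # config field parsed in the constructor accepts, for example:
    #   "cpu"              -> run everything on CPU (empty device list)
    #   "cuda:0"           -> use a single GPU by index
    #   "cuda:0,cuda:1"    -> use multiple GPUs (model gets wrapped in torch.nn.DataParallel)
    #   "cuda", "cuda:all", or None -> auto-detect all available CUDA devices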
    def __init__(self,
                 session_name,   # type: AnyStr
                 session_dir,    # type: AnyStr
                 model,          # type: thelper.typedefs.ModelType
                 task,           # type: thelper.tasks.Task
                 loaders,        # type: thelper.typedefs.MultiLoaderType
                 config,         # type: thelper.typedefs.ConfigDict
                 ckptdata=None   # type: Optional[thelper.typedefs.CheckpointContentType]
                 ):
        """Receives the trainer configuration dictionary, parses it, and sets up the session."""
        assert isinstance(model, (thelper.nn.Module, torch.nn.Module)), "unknown model object type"
        assert isinstance(task, thelper.tasks.Task), "unknown task object type"
        assert isinstance(loaders, (list, tuple, np.ndarray)) and len(loaders) == 3, "invalid loaders array"
        assert isinstance(config, dict), "invalid config type"
        self.task = task
        self.model = model
        self.config = config
        # parse basic training config args
        # use 'trainer' key first for backward compatibility and to prioritize it - most configs will define it as so
        trainer_config = thelper.utils.get_key(["trainer", "runner", "tester"], config)
        os.makedirs(session_dir, exist_ok=True)
        logs_dir = os.path.join(session_dir, "logs")
        os.makedirs(logs_dir, exist_ok=True)
        thelper.utils.init_logger()  # make sure all logging is initialized before attaching this part
        thelper.utils.save_env_list(os.path.join(logs_dir, "packages.log"))
        train_logger_path = os.path.join(logs_dir, "trainer.log")
        train_logger_format = logging.Formatter("[%(asctime)s - %(process)s] %(levelname)s : %(message)s")
        train_logger_fh = logging.FileHandler(train_logger_path)
        train_logger_fh.setLevel(logging.NOTSET)
        train_logger_fh.setFormatter(train_logger_format)
        self.logger = thelper.utils.get_class_logger()
        self.logger.addHandler(train_logger_fh)
        self.logger.info(f"created training log for session '{session_name}'")
        self.logger.debug(f"session directory = {os.path.abspath(session_dir)}")
        self.logger.debug(f"logs directory = {os.path.abspath(logs_dir)}")
        logstamp = thelper.utils.get_log_stamp()
        repover = thelper.__version__ + ":" + thelper.utils.get_git_stamp()
        self.logger.debug(f"logstamp = {logstamp}")
        self.logger.debug(f"version = {repover}")
        self.name = session_name
        self.epochs = 1
        self.save_freq = int(thelper.utils.get_key_def("save_freq", trainer_config, 1))
        assert self.save_freq >= 1, "checkpoint save frequency should be a strictly positive integer"
        self.save_raw = thelper.utils.str2bool(thelper.utils.get_key_def("save_raw", trainer_config, True))
        self.checkpoint_dir = os.path.join(session_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        output_root_dir = thelper.utils.get_key_def("output_dir", trainer_config)
        if not output_root_dir:
            # append session name for cleaner TBX folder merging
            output_root_dir = os.path.join(session_dir, "output", self.name)
        assert isinstance(output_root_dir, str) and len(output_root_dir), "invalid output directory path"
        self.logger.debug(f"output directory = {os.path.abspath(output_root_dir)}")
        os.makedirs(output_root_dir, exist_ok=True)
        unique_output_dir = thelper.utils.get_key_def("unique_output_dir", trainer_config, True)
        assert isinstance(unique_output_dir, bool), "invalid unique_output_dir flag (should be bool)"
        self.logger.debug(f"output subdirectories {'will' if unique_output_dir else 'will not'} have unique names")
        devices_str = thelper.utils.get_key_def(["device", "devices", "train_device"], trainer_config, None)
        self.devices = self._load_devices(devices_str)
        self.skip_eval_iter = thelper.utils.get_key_def("skip_eval_iter", trainer_config, 0)
        # parse and prepare tbx stuff
        tbx_config_flags = ["use_tbx", "tbx", "use_tb", "tb", "tensorboard"]
        self.use_tbx = thelper.utils.str2bool(thelper.utils.get_key_def(tbx_config_flags, trainer_config, False))
        if self.use_tbx:
            try:
                import tensorboardX
                self.tbx = tensorboardX
                logger.debug("using external tensorboard")
            except ImportError:
                import torch.utils.tensorboard as tensorboard
                self.tbx = tensorboard
                logger.debug("using PyTorch's tensorboard")
            self.logger.debug(
                f"tensorboard init : tensorboard --logdir {os.path.abspath(output_root_dir)} --port <your_port>")
        self.skip_tbx_histograms = thelper.utils.str2bool(
            thelper.utils.get_key_def("skip_tbx_histograms", trainer_config, False))
        self.tbx_histogram_freq = int(thelper.utils.get_key_def("tbx_histogram_freq", trainer_config, 5))
        assert self.tbx_histogram_freq >= 1, "histogram output frequency should be a strictly positive integer"
        timestr = time.strftime("%Y%m%d-%H%M%S")
        self.writers, self.output_paths = {}, {}
        for cname, loader in zip(["train", "valid", "test"], loaders):
            if loader:
                folder_name = f"{cname}-{str(platform.node())}-{timestr}" if unique_output_dir else cname
                self.output_paths[cname] = os.path.join(output_root_dir, folder_name)
                self.logger.debug(f"output {cname} directory = {os.path.abspath(self.output_paths[cname])}")
                os.makedirs(self.output_paths[cname], exist_ok=True)
            else:
                self.output_paths[cname] = None
            self.writers[cname] = None  # will be instantiated only when needed based on above path
        # split loaders
        train_loader, valid_loader, test_loader = loaders
        assert (train_loader or valid_loader or test_loader), "must provide at least one loader with available data"
        self.train_loader, self.valid_loader, self.test_loader = train_loader, valid_loader, test_loader
        if train_loader:
            assert "epochs" in trainer_config and int(trainer_config["epochs"]) > 0, "bad trainer config epoch count"
            self.epochs = int(trainer_config["epochs"])
            # loading optimization stuff later since model needs to be on correct device
            self.optimization_config = thelper.utils.get_key_def("optimization", trainer_config, {})
        else:
            self.logger.info("no training data provided, will run a single epoch on valid/test data")
        # parse metrics
        assert "metrics" not in trainer_config or "base_metrics" not in trainer_config, \
            "trainer config should have only one of 'metrics' and 'base_metrics'"
        metrics = {}
        if "metrics" in trainer_config:
            self.logger.debug("loading metrics defined in trainer config")
            metrics = thelper.train.create_consumers(trainer_config["metrics"])
        elif "base_metrics" in trainer_config:
            self.logger.debug("loading base metrics defined in trainer config")
            metrics = thelper.train.create_consumers(trainer_config["base_metrics"])
        self.train_metrics, self.valid_metrics, self.test_metrics = \
            deepcopy(metrics), deepcopy(metrics), deepcopy(metrics)
        for skey, sval in zip(["train_metrics", "valid_metrics", "test_metrics"],
                              [self.train_metrics, self.valid_metrics, self.test_metrics]):
            if skey in trainer_config:
                new_metrics = thelper.train.create_consumers(trainer_config[skey])
                for mkey, mval in new_metrics.items():
                    assert mkey not in sval, f"metric name '{mkey}' duplicated in set '{skey}'"
                    sval[mkey] = mval
            for mkey, mval in sval.items():
                self.logger.info(f"parsed metric '{mkey}': {str(mval)}")
        # check for monitored metric
        self.monitor, self.monitor_best, self.monitor_best_epoch = None, None, -1
        if "monitor" in trainer_config and trainer_config["monitor"]:
            self.monitor = trainer_config["monitor"]
            if self.monitor == "loss":
                self.monitor_goal = thelper.optim.Metric.minimize
                self.monitor_best = thelper.optim.Metric.maximize
            else:
                assert any([self.monitor in mset for mset in [self.train_metrics, self.valid_metrics]]), \
                    f"metric with name '{self.monitor}' could not be found in training/validation metrics"
                metric = self.valid_metrics[self.monitor] if self.monitor in self.valid_metrics \
                    else self.train_metrics[self.monitor]  # makes no sense to search for it in test metrics...
                assert isinstance(metric, thelper.optim.metrics.Metric), \
                    "monitoring target should be an actual 'metric' class that returns a scalar!"
                assert metric.goal in [thelper.optim.Metric.minimize, thelper.optim.Metric.maximize], \
                    "monitored metric does not return proper optimization goal"
                self.monitor_goal = metric.goal
                self.monitor_best = thelper.optim.Metric.minimize if metric.goal == thelper.optim.Metric.maximize \
                    else thelper.optim.Metric.maximize
            self.logger.debug(f"will monitor metric '{self.monitor}' for best state checkpointing/early stopping")
        # parse checkpoint data from previous run (if available)
        ckptdata = {} if ckptdata is None else ckptdata
        self.monitor_best = thelper.utils.get_key_def("monitor_best", ckptdata, self.monitor_best)
        self.monitor_best_epoch = thelper.utils.get_key_def("monitor_best_epoch", ckptdata, -1)
        self.optimizer_state = thelper.utils.get_key_def("optimizer", ckptdata, None)
        self.scheduler_state = thelper.utils.get_key_def("scheduler", ckptdata, None)
        self.current_iter = thelper.utils.get_key_def("iter", ckptdata, 0)
        self.current_epoch = thelper.utils.get_key_def("epoch", ckptdata, 0)
        self.outputs = thelper.utils.get_key_def("outputs", ckptdata, {})
        # parse callbacks (see ``thelper.typedefs.IterCallbackType`` and ``thelper.typedefs.IterCallbackParams``)
        for cname, mset in zip(["train", "valid", "test"],
                               [self.train_metrics, self.valid_metrics, self.test_metrics]):
            # parse user (custom) callback
            user_callback_keys = [f"{cname}_iter_callback", f"{cname}_callback", "callback"]
            user_callback = thelper.utils.get_key_def(
                user_callback_keys, trainer_config)  # type: Optional[thelper.typedefs.IterCallbackType]
            if user_callback is not None:
                assert f"{cname}_user_callback" not in mset, f"metrics set already had a '{cname}_user_callback' in it"
                mset[f"{cname}_user_callback"] = thelper.train.utils.PredictionCallback(user_callback)
            # parse display callback
            display_callback_keys = [f"display_{cname}_preds", f"display_{cname}_predictions", f"display_{cname}",
                                     "display_preds", "display_predictions", "display"]
            display_callback = thelper.utils.get_key_def(display_callback_keys, trainer_config)
            if display_callback:
                assert f"{cname}_display_callback" not in mset, \
                    f"metrics set already had a '{cname}_display_callback' in it"
                if isinstance(display_callback, bool):
                    # if simply toggled on, use default draw function wrapper
                    display_callback = {"type": "thelper.train.utils._draw_wrapper", "params": {"save": False}}
                mset[f"{cname}_display_callback"] = thelper.train.utils.PredictionCallback(display_callback)
            # parse logging callback
            logging_callback_keys = \
                [f"{cname}_logger", f"{cname}_log", f"logger_{cname}", f"log_{cname}", "log", "logger"]
            logging_callback = \
                thelper.utils.get_key_def(logging_callback_keys, trainer_config, self._iter_logger_callback)
            if logging_callback:
                assert f"{cname}_logger_callback" not in mset, \
                    f"metrics set already had a '{cname}_logger_callback' in it"
                logging_kwargs = {"set_name": cname, "writers": self.writers}  # pass writers by ref, fill later
                mset[f"{cname}_logger_callback"] = \
                    thelper.train.utils.PredictionCallback(logging_callback, logging_kwargs)
            else:
                logger.warning("logging is disabled by user, internal iteration count might never be updated")
        # parse visualization config (if any)
        self.viz = thelper.utils.get_key_def(["viz", "visualization", "visualizations"], trainer_config, {})
        assert isinstance(self.viz, dict), "invalid visualization dictionary config"
        for viz_key, viz_config in self.viz.items():
            assert isinstance(viz_key, str) and viz_key in thelper.viz.supported_types, \
                f"invalid visualization type '{viz_key}' (not in available modules)"
            assert isinstance(viz_config, dict), f"invalid visualization configuration dictionary for type '{viz_key}'"
    def _init_writer(self, writer, path):
        if self.use_tbx and not writer:
            writer = self.tbx.SummaryWriter(path, comment=self.name)
            writer.add_text("config", json.dumps(self.config, indent=4, sort_keys=False, default=lambda x: str(x)))
            thelper.utils.save_config(self.config, os.path.join(path, "config.json"))
        return writer

    @staticmethod
    def _set_rng_state(seeds, epoch):
        if "torch" in seeds:
            torch.manual_seed(seeds["torch"] + epoch)
            torch.cuda.manual_seed_all(seeds["torch"] + epoch)
        if "numpy" in seeds:
            np.random.seed(seeds["numpy"] + epoch)
        if "random" in seeds:
            random.seed(seeds["random"] + epoch)

    @staticmethod
    def _upload_model(model, dev):
        """Uploads a model to a specific device, wrapping it in ``torch.nn.DataParallel`` if needed."""
        if isinstance(dev, list):
            if len(dev) == 0:
                return model.cpu()
            elif len(dev) == 1:
                return model.cuda(dev[0])
            else:
                return torch.nn.DataParallel(model, device_ids=dev).cuda(dev[0])
        else:
            return model.to(dev)

    @staticmethod
    def _move_tensor(tensor, dev, non_blocking=True, detach=False):
        """Uploads a tensor to a specific device."""
        if isinstance(tensor, (list, tuple)):
            return [SessionRunner._move_tensor(t, dev) for t in tensor]
        if isinstance(tensor, dict):
            return {k: SessionRunner._move_tensor(t, dev) for k, t in tensor.items()}
        if not isinstance(tensor, torch.Tensor):
            return tensor  # ignored (cannot upload)
        if isinstance(dev, list):
            if len(dev) == 0:
                out = tensor.cpu()
            else:
                # no reason to have multiple devices if not cuda-enabled GPUs
                out = tensor.cuda(dev[0], non_blocking=non_blocking)
        else:
            out = tensor.to(dev, non_blocking=non_blocking)
        return out.detach() if detach else out

    def _load_optimization(self, model, dev):
        """Instantiates and returns all optimization objects required for training the model."""
        config = self.optimization_config  # for abbrev only
        assert isinstance(config, dict), "optimization config should be provided as a dictionary"
        assert self.train_loader is not None and self.train_loader, "optimization only useful with training data"
        loss = None  # can be omitted if using custom trainer
        if "loss" in config:
            uploader = functools.partial(self._move_tensor, dev=dev)
            loss = thelper.optim.create_loss_fn(config["loss"], model, self.train_loader, uploader)
        optimizer = None  # can be omitted if using custom trainer
        if "optimizer" in config:
            optimizer = thelper.optim.create_optimizer(config["optimizer"], model)
        scheduler, scheduler_step_metric = None, None
        if "scheduler" in config and config["scheduler"]:  # can always be omitted
            scheduler, scheduler_step_metric = thelper.optim.create_scheduler(config["scheduler"], optimizer)
        return loss, optimizer, scheduler, scheduler_step_metric

    def _load_devices(self, devices_str=None):
        """Validates and returns the list of CUDA devices available on the system."""
        self.logger.debug("loading available devices")
        if devices_str is not None:
            devices = []
            available_cuda_devices = None
            assert isinstance(devices_str, (str, list)), "unexpected device string type"
            if isinstance(devices_str, str):
                assert devices_str, "cannot specify empty device name, use 'None' to auto-detect"
                devices_str = devices_str.split(",")
            elif isinstance(devices_str, list):
                assert devices_str, "cannot specify empty device list, use 'None' to auto-detect"
                assert all([isinstance(dev_str, str) for dev_str in devices_str]), "unexpected type in dev list"
            for dev_idx, dev_str in enumerate(devices_str):
                assert "cuda" in dev_str or dev_str == "cpu", \
                    f"unknown device type '{dev_str}' (expecting 'cpu' or 'cuda:X')"
                if dev_str == "cpu":
                    assert len(devices_str) == 1, "cannot combine cpu with other devices"
                    return []
                if dev_str == "cuda" or dev_str == "cuda:all":
                    assert len(devices_str) == 1, "must specify device index (e.g. 'cuda:0') if combining devices"
                    if available_cuda_devices is None:
                        available_cuda_devices = thelper.utils.get_available_cuda_devices()
                    assert available_cuda_devices, "could not find any available cuda devices"
                    return available_cuda_devices
                assert "cuda:" in dev_str, "expecting cuda device format to be 'cuda:X' (where X is device index)"
                cuda_dev_idx = int(dev_str.rsplit(":", 1)[-1])
                assert thelper.utils.test_cuda_device_availability(cuda_dev_idx), f"cuda device '{dev_str}' unavailable"
                devices.append(cuda_dev_idx)
            return devices
        else:
            return thelper.utils.get_available_cuda_devices()

    def _to_tensor(self, sample):
        """Fetches and returns tensors of input and groundtruth data from a batched sample dictionary.

        The specifics of how to unpack a sample dictionary into usable parts is tied to the trainer, so it
        cannot be defined in a perfectly generic way here. The implementation below is given as a baseline
        to support some visualization techniques (see :mod:`thelper.viz` for more info). Derived trainers
        (both custom and framework-provided) are likely to override this function to properly unpack
        groundtruth data.

        Args:
            sample: the (batched) sample to unpack into tensors, obtained directly from a data loader.

        Returns:
            A tuple of input data and groundtruth data tensors. In this implementation, the groundtruth
            data tensor is always ``None``.
        """
        assert isinstance(sample, dict), "trainer expects samples to come in dicts for key-based usage"
        assert self.task.input_key in sample, f"could not find input key '{self.task.input_key}' in sample dict"
        return torch.FloatTensor(sample[self.task.input_key]), None

    def _iter_logger_callback(self,  # see `thelper.typedefs.IterCallbackParams` for more info
                              task,          # type: thelper.tasks.utils.Task
                              input,         # type: thelper.typedefs.InputType
                              pred,          # type: thelper.typedefs.AnyPredictionType
                              target,        # type: thelper.typedefs.AnyTargetType
                              sample,        # type: thelper.typedefs.SampleType
                              loss,          # type: Optional[float]
                              iter_idx,      # type: int
                              max_iters,     # type: int
                              epoch_idx,     # type: int
                              max_epochs,    # type: int
                              output_path,   # type: AnyStr
                              # note: kwargs must contain two args here: 'set_name' and 'writers'
                              **kwargs,      # type: Any
                              ):             # type: (...) -> None
        """Receives callback data for logging loss/monitored metric values each training/eval iteration."""
        # NOTE: THIS FUNCTION IS RESPONSIBLE FOR INCREASING THE INTERNAL ITERATION COUNTER.
        set_name = thelper.utils.get_key("set_name", kwargs, "missing set name in iter logger args")
        assert set_name in ["train", "valid", "test"], "unrecognized iter logger set name"
        metrics = self.train_metrics if set_name == "train" else self.valid_metrics if set_name == "valid" \
            else self.test_metrics
        monitor_val = None
        monitor_str = ""
        if self.monitor is not None and self.monitor in metrics:
            assert isinstance(metrics[self.monitor], thelper.optim.metrics.Metric), "unexpected metric type"
            if metrics[self.monitor].live_eval:
                monitor_val = metrics[self.monitor].eval()
                monitor_str = f" {self.monitor}: {monitor_val:.2f}"
        loss_str = ""
        if loss is not None:
            loss_str = f" loss: {loss:.6f}"
        assert self.current_epoch == epoch_idx, "something's messed up"
        self.logger.info(
            f"{set_name} epoch#{epoch_idx} (iter#{self.current_iter})" +
            f" batch: {iter_idx + 1}/{max_iters} ({((iter_idx + 1) / max_iters) * 100.0:.0f}%)" +
            f"{loss_str}{monitor_str}"
        )
        writers = thelper.utils.get_key("writers", kwargs, msg="missing writers dict in iter logger args")
        if (set_name == "train" or iter_idx == max_iters - 1) and writers[set_name]:
            if loss is not None:
                writers[set_name].add_scalar("iter/loss", loss, self.current_iter)
            for metric_name, metric in metrics.items():
                if isinstance(metric, thelper.optim.metrics.Metric):
                    if metric_name == self.monitor and monitor_val is not None:
                        writers[set_name].add_scalar(f"iter/{self.monitor}", monitor_val, self.current_iter)
                    elif metric.live_eval:
                        # if live eval is not true, metric might be too heavy to compute at each iteration
                        writers[set_name].add_scalar(f"iter/{metric_name}", metric.eval(), self.current_iter)
        if set_name == "train":
            self.current_iter += 1

    def _write_data(self, data, writer_prefix, file_suffix, writer, output_path, idx=None):
        """Writes a generic chunk of data passed as a dictionary to the specified output path."""
        os.makedirs(output_path, exist_ok=True)
        assert isinstance(data, dict) and all([isinstance(key, str) for key in data]), \
            "unexpected data chunk formatting (should be dict with str-based keys)"
        reserved_keys = ["/image", "/extension", "/json", "/text", "/pickle"]
        for key, val in data.items():
            if thelper.utils.is_scalar(val) and not any([key.endswith(s) for s in reserved_keys]):
                if writer is not None:
                    if isinstance(val, str):
                        writer.add_text(f"{writer_prefix}{key}", val, idx)
                    else:
                        writer.add_scalar(f"{writer_prefix}{key}", val, idx)
            if key.endswith("/image") and val is not None:  # some metrics got the callable but return None
                assert isinstance(val, np.ndarray) and len(val.shape) == 3 and val.shape[2] == 3, \
                    "unexpected image format (should be numpy array with RGB channels)"
                image_ext = thelper.utils.get_key_def(key + "/extension", data, "png")
                image_path = os.path.join(output_path, f"{''.join(key.rsplit('/image', 1))}{file_suffix}.{image_ext}")
                self.logger.debug(f"writing {key} to {os.path.abspath(image_path)}")
                cv.imwrite(image_path, val[..., ::-1])  # flip to BGR for opencv compat
                if writer is not None:
                    writer.add_image(f"{writer_prefix}{key}", val, idx, dataformats="HWC")
            if key.endswith("/json"):
                json_ext = thelper.utils.get_key_def(key + "/extension", data, "json")
                json_path = os.path.join(output_path, f"{''.join(key.rsplit('/json', 1))}{file_suffix}.{json_ext}")
                self.logger.debug(f"writing {key} to {os.path.abspath(json_path)}")
                with open(json_path, "w") as fd:
                    json.dump(val, fd)
            if key.endswith("/text"):
                txt_ext = thelper.utils.get_key_def(key + "/extension", data, "txt")
                txt_path = os.path.join(output_path, f"{''.join(key.rsplit('/text', 1))}{file_suffix}.{txt_ext}")
                self.logger.debug(f"writing {key} to {os.path.abspath(txt_path)}")
                with open(txt_path, "w") as fd:
                    fd.write(val)
            if key.endswith("/pickle"):
                pkl_ext = thelper.utils.get_key_def(key + "/extension", data, "pkl")
                pkl_path = os.path.join(output_path, f"{''.join(key.rsplit('/pickle', 1))}{file_suffix}.{pkl_ext}")
                self.logger.debug(f"writing {key} to {os.path.abspath(pkl_path)}")
                with open(pkl_path, "wb") as fd:
                    pickle.dump(val, fd)

    def _write_metrics_data(self, epoch, metrics, tbx_writer, output_path, loss=None, optimizer=None, use_suffix=True):
        """Writes the cumulative evaluation result of all metrics using a specific writer."""
        os.makedirs(output_path, exist_ok=True)
        if tbx_writer is not None:
            if loss is not None:
                tbx_writer.add_scalar("epoch/loss", loss, epoch)
            if optimizer is not None:
                tbx_writer.add_scalar("epoch/lr", thelper.optim.get_lr(optimizer), epoch)
        writer_prefix = "epoch/"
        file_suffix = f"-{epoch:04d}" if use_suffix else ""
        for metric_name, metric in metrics.items():
            output = {}
            if hasattr(metric, "render") and callable(metric.render):
                output[f"{metric_name}/image"] = metric.render()
                output[f"{metric_name}/image/extension"] = "png"
            if hasattr(metric, "report") and callable(metric.report):
                output[f"{metric_name}/text"] = metric.report()
                output[f"{metric_name}/text/extension"] = getattr(metric, "ext", "txt")
            if hasattr(metric, "eval") and callable(metric.eval):
                eval_res = metric.eval()
                if f"{metric_name}/text" not in output and eval_res is not None:
                    if isinstance(eval_res, float):
                        output[f"{metric_name}/text"] = f"{eval_res:.4f}"
                    else:
                        output[f"{metric_name}/text"] = str(eval_res)
                    output[f"{metric_name}/text/extension"] = getattr(metric, "ext", "txt")
                output[metric_name] = eval_res
            self._write_data(output, writer_prefix, file_suffix, tbx_writer, output_path, epoch)

    def _save(self, epoch, iter, optimizer, scheduler, save_best=False):
        """Saves a session checkpoint containing all the information required to resume training."""
        # logically, this should only be called during training (i.e. with a valid optimizer)
        log_stamp = thelper.utils.get_log_stamp()
        # the saved state below should be kept compatible with the one in thelper.cli.export_model
        curr_state = {
            "name": self.name,
            "epoch": epoch,
            "iter": iter,
            "source": log_stamp,
            "git_sha1": thelper.utils.get_git_stamp(),
            "version": thelper.__version__,
            "task": str(self.task) if self.save_raw else self.task,
            "outputs": self.outputs,
            # we save model type/params here in case those are not in the current config
            "model": self.model.state_dict() if self.save_raw else self.model,
            "model_type": self.model.get_name(),
            "model_params": self.model.config if self.model.config else {},
            "optimizer": optimizer.state_dict() if optimizer is not None else None,
            "scheduler": scheduler.state_dict() if (scheduler is not None and
                                                    hasattr(scheduler, "state_dict")) else None,
            "monitor_best": self.monitor_best,
            "monitor_best_epoch": self.monitor_best_epoch,
            "config": self.config  # note: this is the global app config
        }
        filename = f"ckpt.{epoch:04d}.{log_stamp}.pth"
        filename = os.path.join(self.checkpoint_dir, filename)
        self.logger.debug(f"writing checkpoint to {os.path.abspath(filename)}")
        torch.save(curr_state, filename)
        if save_best:
            filename_best = os.path.join(self.checkpoint_dir, "ckpt.best.pth")
            self.logger.debug(f"writing checkpoint to {os.path.abspath(filename_best)}")
            torch.save(curr_state, filename_best)