"""General utilities module.
This module only contains non-ML specific functions, i/o helpers,
and matplotlib/pyplot drawing calls.
"""
import copy
import errno
import functools
import glob
import hashlib
import importlib
import importlib.util
import inspect
import io
import json
import logging
import os
import pathlib
import pickle
import platform
import re
import sys
import time
from distutils.version import LooseVersion
from typing import TYPE_CHECKING
import h5py
import hdf5plugin
import numpy as np
import torch
import yaml
import thelper.typedefs # noqa: F401
if TYPE_CHECKING:
from typing import Any, AnyStr, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union # noqa: F401
from types import FunctionType # noqa: F401
from thelper.session.base import SessionRunner
# module-level logger, named after this module
logger = logging.getLogger(__name__)
# global flag switched to True by `setup_globals` when the config holds a truthy 'bypass_queries' value
bypass_queries = False
# NOTE(review): presumably a warn-once guard for generic drawing calls; its users are not visible here
warned_generic_draw = False
# NOTE(review): presumably a run-once guard for a yaml parsing fix; its users are not visible here
fixed_yaml_parsing = False
# `approach` values for which `encode_data`/`decode_data` return the data untouched
no_compression_flags = ["None", "none", "raw", "", None]
# NOTE(review): looks like compression names applied per-chunk by a storage backend (hdf5?); confirm against users
chunk_compression_flags = ["chunk_lz4", "gzip", "lzf", "szip"]
class Struct:
    """Generic runtime-defined C-like data structure.

    Every keyword argument given to the constructor becomes an attribute of the
    instance, e.g. ``Struct(a=1).a == 1``.
    """

    def __init__(self, **kwargs):
        """Stores each keyword argument as an attribute of the same name."""
        for field_name in kwargs:
            setattr(self, field_name, kwargs[field_name])

    def __repr__(self):
        """Returns ``<module>.<qualname>(<field>=<repr>, ...)`` built from the instance attributes."""
        cls = self.__class__
        fields = ", ".join([f"{name}={repr(value)}" for name, value in self.__dict__.items()])
        return cls.__module__ + "." + cls.__qualname__ + "(" + fields + ")"
[docs]def test_cuda_device_availability(device_idx):
# type: (int) -> bool
"""Tests the availability of a single cuda device and returns its status."""
# noinspection PyBroadException
try:
torch.cuda.set_device(device_idx)
test_val = torch.cuda.FloatTensor([1])
return test_val.cpu().item() == 1.0
except Exception:
return False
def get_available_cuda_devices(attempts_per_device=5):
    # type: (Optional[int]) -> List[int]
    """Tests all visible cuda devices and returns a list of available ones.

    Every device that has not passed yet is re-tested on each attempt, through
    :func:`test_cuda_device_availability`. The debug message naming a device is
    only emitted during the first attempt.

    Args:
        attempts_per_device: maximum number of times each device is tested.

    Returns:
        List of available cuda device IDs (integers). An empty list means no
        cuda device is available, and the app should fallback to cpu.
    """
    device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if device_count == 0:
        return []
    status = [False] * device_count
    for attempt_idx in range(attempts_per_device):
        first_attempt = attempt_idx == 0
        for device_id in range(device_count):
            if status[device_id]:
                continue  # already validated by an earlier attempt
            if first_attempt:
                logger.debug("testing availability of cuda device #%d (%s)" % (
                    device_id, torch.cuda.get_device_name(device_id)
                ))
            status[device_id] = test_cuda_device_availability(device_id)
    return [device_id for device_id, usable in enumerate(status) if usable]
def setup_plt(config):
    """Parses the provided config for matplotlib flags and sets up its global state accordingly.

    The matplotlib section is fetched with ``get_key_def`` using the candidate
    keys ``plt``, ``pyplot`` and ``matplotlib`` (empty dict if none is found). Its
    optional ``backend`` entry is forwarded to ``matplotlib.use``, and its
    ``interactive`` entry (default ``False``) to ``pyplot.interactive``.

    Args:
        config: configuration dictionary that may contain a matplotlib section.
    """
    import matplotlib.pyplot as plt
    plt_config = get_key_def(["plt", "pyplot", "matplotlib"], config, {})
    if "backend" in plt_config:
        import matplotlib
        backend_name = get_key("backend", plt_config)
        matplotlib.use(backend_name)
    interactive_flag = get_key_def("interactive", plt_config, False)
    plt.interactive(interactive_flag)
# noinspection PyUnusedLocal
def setup_cv2(config):
    """Parses the provided config for OpenCV flags and sets up its global state accordingly.

    For now the config is not read: OpenCV's thread count is set to zero and its
    OpenCL usage is turned off unconditionally.

    Args:
        config: configuration dictionary (currently unused).
    """
    # https://github.com/pytorch/pytorch/issues/1355
    import cv2
    cv2.setNumThreads(0)
    cv2.ocl.setUseOpenCL(False)
    # todo: add more global opencv flags setups here
def setup_gdal(config):
    """Parses the provided config for GDAL flags and sets up its global state accordingly.

    The ``gdal`` section is fetched with ``get_key_def`` (empty dict if missing).
    If it contains a ``proj_search_path`` entry, that value is forwarded to
    ``osr.SetPROJSearchPath``. GDAL is only imported when that entry exists.

    Args:
        config: configuration dictionary that may contain a ``gdal`` section.
    """
    config = get_key_def("gdal", config, {})
    if "proj_search_path" in config:
        try:
            # namespaced import; the flat `import osr` spelling is deprecated by the GDAL bindings
            from osgeo import osr
        except ImportError:
            # fall back to the legacy top-level module for old GDAL installations
            import osr
        osr.SetPROJSearchPath(config["proj_search_path"])
def setup_cudnn(config):
    """Parses the provided config for CUDNN flags and sets up PyTorch accordingly.

    Two layouts are understood: a ``cudnn`` sub-dictionary holding ``benchmark``
    and/or ``deterministic`` entries, or top-level ``cudnn_benchmark`` and/or
    ``cudnn_deterministic`` entries. Each value found is converted with
    ``str2bool`` and assigned to the matching ``torch.backends.cudnn`` attribute.

    Args:
        config: configuration dictionary that may contain cudnn flags.
    """
    if "cudnn" in config and isinstance(config["cudnn"], dict):
        flags_cfg, key_prefix = config["cudnn"], ""
    else:
        # flat layout: the flags live next to everything else, with a 'cudnn_' prefix
        flags_cfg, key_prefix = config, "cudnn_"
    flag_table = (
        ("benchmark", "cudnn benchmark mode = %s"),
        ("deterministic", "cudnn deterministic mode = %s"),
    )
    for flag_name, log_format in flag_table:
        flag_key = key_prefix + flag_name
        if flag_key in flags_cfg:
            flag_value = str2bool(flags_cfg[flag_key])
            logger.debug(log_format % str(flag_value))
            setattr(torch.backends.cudnn, flag_name, flag_value)
def setup_sys(config):
    """Parses the provided config for PYTHON sys paths and sets up its global state accordingly.

    The ``sys`` entry of the config may be a single path (string) or a list of
    paths; any other type is ignored. Each path that is an existing directory is
    appended to ``sys.path``; the others are reported with a warning.

    Args:
        config: configuration dictionary that may contain a ``sys`` entry.
    """
    sys_entry = config["sys"] if "sys" in config else None
    if isinstance(sys_entry, list):
        candidate_dirs = sys_entry
    elif isinstance(sys_entry, str):
        candidate_dirs = [sys_entry]
    else:
        candidate_dirs = []
    for dir_path in candidate_dirs:
        if not os.path.isdir(dir_path):
            logger.warning(f"could not append to syspaths, invalid dir: {dir_path}")
            continue
        logger.debug(f"will append path to python's syspaths: {dir_path}")
        sys.path.append(dir_path)
def setup_globals(config):
    """Parses the provided config for global flags and sets up the global state accordingly.

    Turns on the module-level ``bypass_queries`` flag when the config holds a
    truthy ``bypass_queries`` entry (the flag is never turned back off here), then
    runs the sys, pyplot, OpenCV, GDAL and CUDNN setup functions in that order.

    Args:
        config: configuration dictionary forwarded to every setup function.
    """
    global bypass_queries
    if "bypass_queries" in config and config["bypass_queries"]:
        bypass_queries = True
    for setup_step in (setup_sys, setup_plt, setup_cv2, setup_gdal, setup_cudnn):
        setup_step(config)
def _find_checkpoint_in_dir(search_dir, always_load_latest):
    """Returns the path of the checkpoint to load from ``search_dir``, or ``None`` if no file was retained."""
    ckpt_paths = glob.glob(os.path.join(search_dir, "ckpt.*.pth"))
    if not ckpt_paths:
        raise AssertionError("could not find any valid checkpoint files in directory '%s'" % search_dir)
    selected_path, latest_stamp = None, (-1, -1, -1)
    for ckpt_path in ckpt_paths:
        # note: the 2nd field in the name should be the epoch index, or 'best' if final checkpoint
        split = os.path.basename(ckpt_path).split(".")
        tag = split[1]
        if tag == "best":
            if not always_load_latest:
                # if eval-only, always pick the best checkpoint
                return ckpt_path
            if latest_stamp[0] == -1:
                # otherwise, only pick it if nothing else has been found (so far)
                selected_path = ckpt_path
            continue
        # the 3rd field is expected to look like '<host>-<day>-<time>'; fall back to zeros otherwise
        log_stamp = split[2] if len(split) > 2 else ""
        log_stamp = "fake-0-0" if log_stamp.count("-") != 2 else log_stamp
        _, day_str, time_str = log_stamp.split("-")
        stamp = (int(tag), int(day_str), int(time_str))
        # lexicographic comparison: epoch first, day/time only break ties; this keeps the
        # result independent from the (arbitrary) order in which glob lists the files
        if stamp > latest_stamp:
            selected_path, latest_stamp = ckpt_path, stamp
    return selected_path


def load_checkpoint(ckpt,  # type: thelper.typedefs.CheckpointLoadingType
                    map_location=None,  # type: Optional[thelper.typedefs.MapLocationType]
                    always_load_latest=False,  # type: Optional[bool]
                    check_version=True,  # type: Optional[bool]
                    ):  # type: (...) -> thelper.typedefs.CheckpointContentType
    """Loads a session checkpoint via PyTorch, checks its compatibility, and returns its data.

    If the ``ckpt`` parameter is a path to a valid directory, then that directory (or its
    ``checkpoints`` subdirectory, if it exists) will be searched for files named ``ckpt.*.pth``.
    If multiple checkpoints are found, the latest will be returned, i.e. the one with the highest
    epoch index in its name; the day and time stamps of the name are only used to break ties.
    If ``always_load_latest`` is set to False and if a checkpoint named ``ckpt.best.pth`` is
    found, it will be returned instead.

    If the loaded data has a ``model`` entry that is a string pointing to an existing file
    (as given, or relative to the checkpoint's directory), that file is loaded in its place
    (``.pth`` via ``torch.load``, ``.zip`` via ``torch.jit.load``).

    Args:
        ckpt: a file-like object or a path to the checkpoint file or session directory.
        map_location: a function, string or a dict specifying how to remap storage
            locations. See ``torch.load`` for more information. If ``None`` and no cuda
            device is available, ``'cpu'`` is used.
        always_load_latest: toggles whether to always try to load the latest checkpoint
            if a session directory is provided (instead of loading the 'best' checkpoint).
        check_version: toggles whether the checkpoint's version should be checked for
            compatibility issues, and query the user for how to proceed.

    Returns:
        Content of the checkpoint (a dictionary).

    Raises:
        AssertionError: if no checkpoint file can be found, if the loaded data is not a
            dictionary, or if its version tag is badly formatted.
        SystemExit: if the user chooses to abort on a version mismatch.
    """
    if map_location is None and not get_available_cuda_devices():
        map_location = 'cpu'
    if isinstance(ckpt, str) and os.path.isdir(ckpt):
        logger.debug("will search directory '%s' for a checkpoint to load..." % ckpt)
        search_ckpt_dir = os.path.join(ckpt, "checkpoints")
        search_dir = search_ckpt_dir if os.path.isdir(search_ckpt_dir) else ckpt
        found_path = _find_checkpoint_in_dir(search_dir, always_load_latest)
        if found_path is not None:
            ckpt = found_path
        if not os.path.isfile(ckpt):
            raise AssertionError("could not find valid checkpoint at '%s'" % ckpt)
    basepath = None
    if isinstance(ckpt, str):
        logger.debug("parsing checkpoint at '%s'" % ckpt)
        basepath = os.path.dirname(os.path.abspath(ckpt))
    else:
        if hasattr(ckpt, "name"):
            logger.debug("parsing checkpoint provided via file object")
            basepath = os.path.dirname(os.path.abspath(ckpt.name))
    # NOTE(security): torch.load relies on pickle; only load checkpoints coming from a trusted source
    ckptdata = torch.load(ckpt, map_location=map_location)
    if not isinstance(ckptdata, dict):
        raise AssertionError("unexpected checkpoint data type")
    if check_version:
        good_version = False
        from thelper import __version__ as curr_ver
        if "version" not in ckptdata:
            logger.warning("checkpoint missing internal version tag")
            ckpt_ver_str = "0.0.0"
        else:
            ckpt_ver_str = ckptdata["version"]
        if not isinstance(ckpt_ver_str, str) or len(ckpt_ver_str.split(".")) != 3:
            raise AssertionError("unexpected checkpoint version formatting")
        # by default, checkpoints should be from the same minor version, we warn otherwise
        versions = [curr_ver.split("."), ckpt_ver_str.split(".")]
        if versions[0][0] != versions[1][0]:
            logger.error("incompatible checkpoint, major version mismatch (%s vs %s)" % (curr_ver, ckpt_ver_str))
        elif versions[0][1] != versions[1][1]:
            logger.warning("outdated checkpoint, minor version mismatch (%s vs %s)" % (curr_ver, ckpt_ver_str))
        else:
            good_version = True
        if not good_version:
            answer = query_string("Checkpoint version unsupported (framework=%s, checkpoint=%s); how do you want to proceed?" %
                                  (curr_ver, ckpt_ver_str), choices=["continue", "migrate", "abort"], default="migrate", bypass="migrate")
            if answer == "abort":
                logger.error("checkpoint out-of-date; user aborted")
                sys.exit(1)
            elif answer == "continue":
                logger.warning("will attempt to load checkpoint anyway (might crash later due to incompatibilities)")
            elif answer == "migrate":
                ckptdata = migrate_checkpoint(ckptdata)
    # load model trace if needed (we do it here since we can locate the neighboring file)
    if "model" in ckptdata and isinstance(ckptdata["model"], str):
        trace_path = None
        if os.path.isfile(ckptdata["model"]):
            trace_path = ckptdata["model"]
        elif basepath is not None and os.path.isfile(os.path.join(basepath, ckptdata["model"])):
            trace_path = os.path.join(basepath, ckptdata["model"])
        if trace_path is not None:
            if trace_path.endswith(".pth"):
                ckptdata["model"] = torch.load(trace_path, map_location=map_location)
            elif trace_path.endswith(".zip"):
                ckptdata["model"] = torch.jit.load(trace_path, map_location=map_location)
    return ckptdata
def _parse_loose_version(version_str):
    """Splits a version string into int/str components, the way ``distutils``' ``LooseVersion`` used to."""
    # runs of digits, runs of lowercase letters and dots delimit components; dots themselves are dropped
    components = [part for part in re.split(r"(\d+|[a-z]+|\.)", version_str) if part and part != "."]
    for idx, part in enumerate(components):
        try:
            components[idx] = int(part)
        except ValueError:
            pass  # non-numeric component (e.g. 'rc', '-'), kept as a string
    return components


def check_version(version_check, version_required):
    # type: (AnyStr, AnyStr) -> Tuple[bool, List[Union[int, AnyStr]], List[Union[int, AnyStr]]]
    """Verifies that the checked version is not greater than the required one (ie: not a future version).

    Version format is ``MAJOR[.MINOR[.PATCH[[-]<RELEASE>]]]``.

    Note that for ``RELEASE`` part, comparison depends on alphabetical order if all other previous parts were equal
    (i.e.: ``alpha`` will be lower than ``beta``, which in turn is lower than ``rc`` and so on). The ``-`` is optional
    and will be removed for comparison (i.e.: ``0.5.0-rc`` is exactly the same as ``0.5.0rc`` and the additional ``-``
    will not result in evaluating ``0.5.0a0`` as a greater version because of ``-`` being lower ascii than ``a``).

    The parsing is done locally with a regular expression that mimics ``distutils.version.LooseVersion``, since
    ``distutils`` is no longer part of the standard library starting with Python 3.12.

    Args:
        version_check: the version string that needs to be verified and compared for lower than the required version.
        version_required: the control version against which the check is done.

    Returns:
        Tuple of the validated check, and lists of both parsed version parts as ``[MAJOR, MINOR, PATCH, 'RELEASE']``.
        The returned lists are *guaranteed* to be formed of 4 elements, adding 0 or '' as applicable for missing parts.

    Raises:
        TypeError: if the two versions hold components of different types at the same position (number vs. text),
            as those cannot be ordered.
    """
    parsed_versions = []
    for version_str in (version_check, version_required):
        components = _parse_loose_version(version_str)
        numeric_parts = [components[idx] if idx < len(components) else 0 for idx in range(3)]
        release = ''
        if len(components) >= 4:
            # skip the optional dash that separates the patch number from the release tag
            release_idx = 4 if len(components) >= 5 and components[3] == '-' else 3
            release = ''.join(str(part) for part in components[release_idx:])
        parsed_versions.append(numeric_parts + [release])
    l_check, l_req = parsed_versions
    # check with re-parsed version after fixing release dash
    v_ok = _parse_loose_version('.'.join(str(v) for v in l_check)) <= \
        _parse_loose_version('.'.join(str(v) for v in l_req))
    return v_ok, l_check, l_req
def migrate_checkpoint(ckptdata,  # type: thelper.typedefs.CheckpointContentType
                       ):  # type: (...) -> thelper.typedefs.CheckpointContentType
    """Migrates the content of an incompatible or outdated checkpoint to the current version of the framework.

    This function might not be able to fix all backward compatibility issues (e.g. it cannot fix class interfaces
    that were changed). Perfect reproducibility of tests cannot be guaranteed either if this migration tool is used.

    The configuration stored in the checkpoint is migrated on a deep copy through :func:`migrate_config`, and
    the result replaces the ``config`` entry of the checkpoint.

    Args:
        ckptdata: checkpoint data in dictionary form obtained via ``thelper.utils.load_checkpoint``. Note that
            the data contained in this dictionary will be modified in-place.

    Returns:
        An updated checkpoint dictionary that should be compatible with the current version of the framework.

    Raises:
        AssertionError: if ``ckptdata`` is not a dictionary, if it has no ``config`` entry, or if its version
            is greater than the version of the installed framework.
    """
    if not isinstance(ckptdata, dict):
        raise AssertionError("unexpected ckptdata type")
    from thelper import __version__ as curr_ver_str
    # checkpoints without a version tag are treated as v0.0.0
    ckpt_ver_str = ckptdata["version"] if "version" in ckptdata else "0.0.0"
    ok_ver, ckpt_ver, _ = check_version(ckpt_ver_str, curr_ver_str)
    if not ok_ver:
        raise AssertionError("cannot migrate checkpoints from future versions! You need to update your thelper package")
    if "config" not in ckptdata:
        raise AssertionError("checkpoint migration requires config")
    new_config = migrate_config(copy.deepcopy(ckptdata["config"]), ckpt_ver_str)
    if ckpt_ver[:3] == [0, 0, 0]:
        logger.warning("trying to migrate checkpoint data from v0.0.0; all bets are off")
    else:
        logger.info("trying to migrate checkpoint data from v%s" % ckpt_ver_str)
    major, minor = ckpt_ver[0], ckpt_ver[1]
    if major <= 0 and minor <= 1:
        # combine 'host' and 'time' fields into 'source'
        if "host" in ckptdata and "time" in ckptdata:
            source = ckptdata["host"] + ckptdata["time"]
            ckptdata["source"] = source
            del ckptdata["host"]
            del ckptdata["time"]
        # update classif task interface
        if "task" in ckptdata and isinstance(ckptdata["task"], thelper.tasks.classif.Classification):
            old_task = ckptdata["task"]
            new_task = thelper.tasks.classif.Classification(class_names=old_task.class_names,
                                                            input_key=old_task.input_key,
                                                            label_key=old_task.label_key,
                                                            meta_keys=old_task.meta_keys,
                                                            multi_label=False)
            ckptdata["task"] = str(new_task)
        # move 'state_dict' field to 'model'
        if "state_dict" in ckptdata:
            ckptdata["model"] = ckptdata["state_dict"]
            del ckptdata["state_dict"]
        # create 'model_type' and 'model_params' fields
        if "model" in new_config:
            model_config = new_config["model"]
            ckptdata["model_type"] = model_config["type"] if "type" in model_config else None
            ckptdata["model_params"] = copy.deepcopy(model_config["params"]) if "params" in model_config else {}
        # TODO: create 'scheduler' field to restore previous state? (not so important for early versions)
        # ckpt_ver = [0, 2, 0]  # set ver for next update step
    # if ckpt_ver[0] <= x and ckpt_ver[1] <= y and ckpt_ver[2] <= z:
    #     ... add more compatibility fixes here
    ckptdata["config"] = new_config
    return ckptdata
def migrate_config(config,  # type: thelper.typedefs.ConfigDict
                   cfg_ver_str,  # type: str
                   ):  # type: (...) -> thelper.typedefs.ConfigDict
    """Migrates the content of an incompatible or outdated configuration to the current version of the framework.

    This function might not be able to fix all backward compatibility issues (e.g. it cannot fix class interfaces
    that were changed). Perfect reproducibility of tests cannot be guaranteed either if this migration tool is used.

    The migration is a cascade of steps, each one applied when the (running) config version is strictly older
    than the version that introduced the change; versions are compared as ``(major, minor, patch)`` tuples.
    After a step is applied, the running version is bumped so that all following steps are applied as well.

    Args:
        config: session configuration dictionary obtained e.g. by parsing a JSON file. Note that the data contained
            in this dictionary will be modified in-place.
        cfg_ver_str: string representing the version for which the configuration was created (e.g. "0.2.0").

    Returns:
        An updated configuration dictionary that should be compatible with the current version of the framework.

    Raises:
        AssertionError: if ``config`` is not a dictionary, if its version is greater than the version of the
            installed framework, or if a metric section does not have the expected layout.
    """
    if not isinstance(config, dict):
        raise AssertionError("unexpected config type")
    from thelper import __version__ as curr_ver_str
    ok_ver, cfg_ver, curr_ver = check_version(cfg_ver_str, curr_ver_str)
    if not ok_ver:
        raise AssertionError("cannot migrate checkpoints from future versions! You need to update your thelper package")
    if cfg_ver[:3] == [0, 0, 0]:
        logger.warning("trying to migrate config from v0.0.0; all bets are off")
    else:
        logger.info("trying to migrate config from v%s" % cfg_ver_str)

    def is_older_than(version, reference):
        # lexicographic comparison; comparing each component on its own would e.g. treat 0.2.7 as
        # "not older" than 0.3.6 because 7 > 6, and silently skip the corresponding migration step
        return tuple(version[:3]) < reference

    if is_older_than(cfg_ver, (0, 1, 0)):
        # must search for name-value parameter lists and convert them to dictionaries
        def name_value_replacer(cfg):
            if isinstance(cfg, dict):
                # iterate over a snapshot: the 'params'/'parameters' keys may be added/removed below
                for key, val in list(cfg.items()):
                    if (key == "params" or key == "parameters") and isinstance(val, list) and \
                            all([isinstance(p, dict) and set(p.keys()) == {"name", "value"} for p in val]):
                        cfg["params"] = {param["name"]: name_value_replacer(param["value"]) for param in val}
                        if key == "parameters":
                            del cfg["parameters"]
                    elif isinstance(val, (dict, list)):
                        cfg[key] = name_value_replacer(val)
            elif isinstance(cfg, list):
                for idx, val in enumerate(cfg):
                    cfg[idx] = name_value_replacer(val)
            return cfg
        config = name_value_replacer(config)
        # must replace "data_config" section by "loaders"
        if "data_config" in config:
            config["loaders"] = config["data_config"]
            del config["data_config"]
        # remove deprecated name attribute for models
        if "model" in config and isinstance(config["model"], dict) and "name" in config["model"]:
            del config["model"]["name"]
        # must update import targets wrt class name refactorings
        def import_refactoring(cfg):  # noqa: E306
            if isinstance(cfg, dict):
                for key, val in cfg.items():
                    cfg[key] = import_refactoring(val)
            elif isinstance(cfg, list):
                for idx, val in enumerate(cfg):
                    cfg[idx] = import_refactoring(val)
            elif isinstance(cfg, str) and cfg.startswith("thelper."):
                cfg = thelper.utils.resolve_import(cfg)
            return cfg
        config = import_refactoring(config)
        if "trainer" in config and isinstance(config["trainer"], dict):
            trainer_cfg = config["trainer"]  # type: thelper.typedefs.ConfigDict
            # move 'loss' section to 'optimization' section
            if "loss" in trainer_cfg:
                if "optimization" not in trainer_cfg or not isinstance(trainer_cfg["optimization"], dict):
                    trainer_cfg["optimization"] = {}
                trainer_cfg["optimization"]["loss"] = trainer_cfg["loss"]
                del trainer_cfg["loss"]
            # replace all devices with cuda:all
            for device_key in ("train_device", "valid_device", "test_device"):
                if device_key in trainer_cfg:
                    del trainer_cfg[device_key]
            if "device" not in trainer_cfg:
                trainer_cfg["device"] = "cuda:all"
            # remove params from trainer config
            if "params" in trainer_cfg:
                if not isinstance(trainer_cfg["params"], (dict, list)) or trainer_cfg["params"]:
                    logger.warning("removing non-empty parameter section from trainer config")
                del trainer_cfg["params"]
        cfg_ver = [0, 1, 0]  # set ver for next update step
    if is_older_than(cfg_ver, (0, 2, 0)):
        # remove 'force_convert' flags from all transform pipelines + build augment pipeline wrappers
        def remove_force_convert(cfg):  # noqa: E306
            if isinstance(cfg, list):
                for idx, stage in enumerate(cfg):
                    cfg[idx] = remove_force_convert(stage)
            elif isinstance(cfg, dict):
                if "parameters" in cfg:
                    cfg["params"] = cfg["parameters"]
                    del cfg["parameters"]
                if "operation" in cfg and cfg["operation"] == "thelper.transforms.TransformWrapper":
                    if "params" in cfg and "force_convert" in cfg["params"]:
                        del cfg["params"]["force_convert"]
                for key, stage in cfg.items():
                    cfg[key] = remove_force_convert(stage)
            return cfg
        for pipeline in ["base_transforms", "train_augments", "valid_augments", "test_augments"]:
            if "loaders" in config and isinstance(config["loaders"], dict) and pipeline in config["loaders"]:
                if pipeline.endswith("_augments"):
                    stages = config["loaders"][pipeline]
                    for stage in stages:
                        if "append" in stage:
                            if stage["append"]:
                                logger.warning("overriding augmentation stage ordering")
                            del stage["append"]
                        if "operation" in stage and stage["operation"] == "Augmentor.Pipeline":
                            if "params" in stage:
                                stage["params"] = stage["params"]["operations"]
                            elif "parameters" in stage:
                                stage["params"] = stage["parameters"]["operations"]
                                del stage["parameters"]
                    # augmentation stages are now wrapped in a dictionary that also holds the 'append' flag
                    config["loaders"][pipeline] = {"append": False, "transforms": remove_force_convert(stages)}
                else:
                    config["loaders"][pipeline] = remove_force_convert(config["loaders"][pipeline])
        cfg_ver = [0, 2, 0]  # set ver for next update step
    if is_older_than(cfg_ver, (0, 2, 5)):
        # TODO: add scheduler 0-based step fix here? (unlikely to cause serious issues)
        cfg_ver = [0, 2, 5]  # set ver for next update step
    if is_older_than(cfg_ver, (0, 3, 6)):
        if "trainer" in config:
            if "eval_metrics" in config["trainer"]:
                assert "valid_metrics" not in config["trainer"]
                config["trainer"]["valid_metrics"] = config["trainer"]["eval_metrics"]
                del config["trainer"]["eval_metrics"]
            for mset in ["train_metrics", "valid_metrics", "test_metrics", "metrics"]:
                if mset in config["trainer"]:
                    metrics_config = config["trainer"][mset]
                    for mcfg in metrics_config.values():
                        if "type" in mcfg and mcfg["type"].endswith("ExternalMetric"):
                            assert "params" in mcfg
                            assert "goal" in mcfg["params"]
                            # the 'goal' parameter was renamed to 'metric_goal'
                            mcfg["params"]["metric_goal"] = mcfg["params"]["goal"]
                            del mcfg["params"]["goal"]
                            if "metric_params" in mcfg["params"]:
                                if isinstance(mcfg["params"]["metric_params"], list):
                                    assert not mcfg["params"]["metric_params"], "cannot fill in kw names"
                                    mcfg["params"]["metric_params"] = {}
                        elif "type" in mcfg and mcfg["type"].endswith("ROCCurve"):
                            assert "params" in mcfg
                            if "log_params" in mcfg["params"]:
                                logger.warning("disabling logging via ROCCurve metric")
                                del mcfg["params"]["log_params"]
        cfg_ver = [0, 3, 6]  # set ver for next update step
    if is_older_than(cfg_ver, (0, 4, 2)):
        if "model" in config and isinstance(config["model"], dict):
            model_config = thelper.utils.get_key("model", config)
            model_type = thelper.utils.get_key_def("type", model_config, None)
            if model_type == "thelper.nn.resnet.ResNet":
                model_params = thelper.utils.get_key_def("params", model_config, {})
                coordconv_flag = thelper.utils.get_key_def("coordconv", model_params, False)
                if coordconv_flag:
                    logger.warning("coordconv implementation for resnets changed in v0.4.2; "
                                   "beware if reloading old model weights!")
        cfg_ver = [0, 4, 2]
    # if is_older_than(cfg_ver, (x, y, z)):
    #     ... add more compatibility fixes here
    return config
def download_file(url, root, filename, md5=None):
    """Downloads a file from a given URL to a local destination.

    The destination folder is created if needed. If the destination file already
    exists, nothing is downloaded (but the md5 check, if requested, is still done).
    If the download fails, the partially written file is removed so that it cannot
    be mistaken for a complete download on a later call.

    Args:
        url: path to query for the file (query will be based on urllib).
        root: destination folder where the file should be saved (``~`` is expanded).
        filename: destination name for the file.
        md5: optional, for md5 integrity check (hexadecimal digest string).

    Returns:
        The path to the downloaded file.

    Raises:
        AssertionError: if the md5 digest of the file does not match ``md5``.
    """
    # inspired from torchvision.datasets.utils.download_url; no dep check
    import urllib.request
    root = os.path.expanduser(root)
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    if not os.path.isfile(fpath):
        logger.info("Downloading %s to %s ..." % (url, fpath))
        try:
            urllib.request.urlretrieve(url, fpath, reporthook)
        except BaseException:
            # do not leave a truncated file behind (it would be skipped by the isfile check next time)
            if os.path.isfile(fpath):
                os.remove(fpath)
            raise
        # move the cursor back to the start of the progress line printed by the report hook
        sys.stdout.write("\r")
        sys.stdout.flush()
    if md5 is not None:
        md5o = hashlib.md5()
        with open(fpath, 'rb') as f:
            # hash the file by chunks of 1MB to keep the memory footprint low
            for chunk in iter(lambda: f.read(1024 * 1024), b''):
                md5o.update(chunk)
        md5c = md5o.hexdigest()
        if md5c != md5:
            raise AssertionError("md5 check failed for '%s'" % fpath)
    return fpath
def reporthook(count, block_size, total_size):
    """Report hook used to display a download progression bar when using urllib requests.

    The first call (``count == 0``) only records the start time in the module-level
    ``start_time`` variable; later calls print the progress on a single line of stdout.

    Args:
        count: number of blocks transferred so far.
        block_size: size of a block, in bytes.
        total_size: total size of the file, in bytes; if it is zero or negative (i.e. the
            size is unknown), the percentage is displayed as ``?``.
    """
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    speed = str(int(progress_size / (1024 * duration))) if duration > 0 else "?"
    if total_size > 0:
        # the last block is usually incomplete, hence the clamping
        percent = str(min(int(count * block_size * 100 / total_size), 100))
    else:
        percent = "?"
    sys.stdout.write("\r\t=> downloaded %s%% (%d MB) @ %s KB/s..." %
                     (percent, progress_size / (1024 * 1024), speed))
    sys.stdout.flush()
def init_logger(log_level=logging.NOTSET, filename=None, force_stdout=False):
    """Initializes the framework logger with a specific filter level, and optional file output.

    The initialization is done only once: the ``LOGGER_INITIALIZED`` attribute set on the
    ``thelper`` package turns all later calls into no-ops.

    Args:
        log_level: level given to the console handler.
        filename: optional path of a log file; its handler is not filtered (``NOTSET`` level).
        force_stdout: if True, the console handler writes to ``sys.stdout``; otherwise, the
            default stream of ``logging.StreamHandler`` is used.
    """
    if getattr(thelper, "LOGGER_INITIALIZED", None) is not None:
        return  # already initialized by a previous call
    logging.getLogger().setLevel(logging.NOTSET)
    thelper.logger.propagate = 0
    formatter = logging.Formatter("[%(asctime)s - %(name)s] %(levelname)s : %(message)s")
    if filename is not None:
        file_handler = logging.FileHandler(filename)
        file_handler.setLevel(logging.NOTSET)
        file_handler.setFormatter(formatter)
        thelper.logger.addHandler(file_handler)
    console_handler = logging.StreamHandler(stream=sys.stdout if force_stdout else None)
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    thelper.logger.addHandler(console_handler)
    thelper.LOGGER_INITIALIZED = True
def resolve_import(fullname):
    # type: (str) -> str
    """Class name resolver.

    Takes a string corresponding to a module and class fullname to be imported with :func:`thelper.utils.import_class`
    and resolves any back compatibility issues related to renamed or moved classes. Every known outdated name found
    in the string is substituted by its replacement, and a warning is logged if the name was changed.

    Args:
        fullname: the fully qualified class name to be resolved.

    Returns:
        The resolved class fullname.

    Raises:
        AssertionError: if the name refers to a class that was removed from the framework.
    """
    removed_cases = (
        'thelper.optim.metrics.RawPredictions',  # removed in 0.3.5
    )
    if fullname in removed_cases:
        raise AssertionError(f"class {repr(fullname)} was deprecated and removed in a previous version")
    # (outdated name, replacement) pairs, applied in order
    refactor_cases = (
        ('thelper.modules', 'thelper.nn'),
        ('thelper.samplers', 'thelper.data.samplers'),
        ('thelper.optim.BinaryAccuracy', 'thelper.optim.metrics.Accuracy'),
        ('thelper.optim.CategoryAccuracy', 'thelper.optim.metrics.Accuracy'),
        ('thelper.optim.ClassifLogger', 'thelper.train.utils.ClassifLogger'),
        ('thelper.optim.ClassifReport', 'thelper.train.utils.ClassifReport'),
        ('thelper.optim.ConfusionMatrix', 'thelper.train.utils.ConfusionMatrix'),
        ('thelper.optim.metrics.BinaryAccuracy', 'thelper.optim.metrics.Accuracy'),
        ('thelper.optim.metrics.CategoryAccuracy', 'thelper.optim.metrics.Accuracy'),
        ('thelper.optim.metrics.ClassifLogger', 'thelper.train.utils.ClassifLogger'),
        ('thelper.optim.metrics.ClassifReport', 'thelper.train.utils.ClassifReport'),
        ('thelper.optim.metrics.ConfusionMatrix', 'thelper.train.utils.ConfusionMatrix'),
        ('thelper.transforms.ImageTransformWrapper', 'thelper.transforms.TransformWrapper'),
        ('thelper.transforms.wrappers.ImageTransformWrapper', 'thelper.transforms.wrappers.TransformWrapper'),
    )
    resolved = fullname
    for outdated, replacement in refactor_cases:
        resolved = resolved.replace(outdated, replacement)
    if resolved != fullname:
        logger.warning(f"class fullname '{str(fullname)}' was resolved to '{str(resolved)}'")
    return resolved
def import_class(fullname):
    # type: (str) -> Type
    """General-purpose runtime class importer.

    Supported syntax:

        1. ``module.package.Class`` will import the fully qualified ``Class`` located
           in ``package`` from the *installed* ``module``
        2. ``/some/path/mod.pkg.Cls`` will import ``Cls`` as fully qualified ``mod.pkg.Cls`` from
           ``/some/path`` directory

    In both cases, the dotted name goes through :func:`resolve_import` before being imported.

    Args:
        fullname: the fully qualified class name to be imported. If a class is given instead
            of a string, it is returned as-is.

    Returns:
        The imported class.
    """
    if inspect.isclass(fullname):
        return fullname  # useful shortcut for hacky configs
    assert isinstance(fullname, str), "should specify class by its (fully qualified) name"
    posix_name = pathlib.Path(fullname).as_posix()
    if "/" not in posix_name:
        # installed module: regular import based on the dotted name
        resolved_name = resolve_import(posix_name)
        module_name, class_name = resolved_name.rsplit('.', 1)
        return getattr(importlib.import_module(module_name), class_name)
    # file-based import: everything before the last slash is the directory holding the package
    base_dir, _, dotted_name = posix_name.rpartition("/")
    pkg_name = dotted_name.rsplit(".", 1)[0]
    pkg_file = os.path.join(base_dir, pkg_name.replace(".", "/")) + ".py"
    resolved_name = resolve_import(dotted_name)
    spec = importlib.util.spec_from_file_location(resolved_name, pkg_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    class_name = resolved_name.rsplit('.', 1)[-1]
    return getattr(module, class_name)
def import_function(func,  # type: Union[Callable, AnyStr, List, Dict]
                    params=None  # type: Optional[thelper.typedefs.ConfigDict]
                    ):  # type: (...) -> FunctionType
    """General-purpose runtime function importer, with support for parameter binding.

    Args:
        func: the fully qualified function name to be imported, or a dictionary with
            two members (a ``type`` and optional ``params``), or a list of any of these.
            A callable is also accepted and used directly.
        params: optional params dictionary to bind to the function call via functools.
            If a dictionary of parameters is also provided in ``func``, both will be merged
            (the parameters found in ``func`` have precedence).

    Returns:
        The imported function, with optionally bound parameters. If a list was given, the
        returned function calls every target with the same arguments and returns the list
        of their results.
    """
    assert isinstance(func, (str, dict, list)) or callable(func), "invalid target function type"
    assert params is None or isinstance(params, dict), "invalid target function parameters"
    bound_params = params if params is not None else {}
    if isinstance(func, list):
        targets = [import_function(entry, bound_params) for entry in func]

        def multi_caller(funcs, *args, **kwargs):
            return [fn(*args, **kwargs) for fn in funcs]

        return functools.partial(multi_caller, targets)
    if isinstance(func, dict):
        errmsg = "dynamic function import via dictionary must provide 'type' and 'params' members"
        fn_type = thelper.utils.get_key(["type", "func", "function", "op", "operation", "name"], func, msg=errmsg)
        fn_params = thelper.utils.get_key_def(["params", "param", "parameters", "kwargs"], func, None)
        if fn_params is None:
            fn_params = {}
        merged_params = {**bound_params, **fn_params}
        return import_function(fn_type, params=merged_params)
    target = import_class(func) if isinstance(func, str) else func
    assert callable(target), f"unsupported function type ({type(target)})"
    # only wrap the target when there is something to bind
    return functools.partial(target, **bound_params) if bound_params else target
def get_func_params(func,  # type: Callable
                    ):  # type: (...) -> Mapping[str]
    """Returns the parameters expected when calling the given function. Supports class constructors.

    Args:
        func: the function to inspect; if a class is given, its ``__init__`` is inspected instead.

    Returns:
        The ``parameters`` mapping of the ``inspect`` signature of the target.
    """
    target = func.__init__ if inspect.isclass(func) else func
    signature = inspect.signature(target)
    return signature.parameters
def check_func_signature(func,  # type: FunctionType
                         params  # type: List[str]
                         ):  # type: (...) -> None
    """Checks whether the signature of a function matches the expected parameter list.

    Args:
        func: the function (or class) whose signature must be checked.
        params: list of parameter names that the signature must contain; if ``None``, only
            the callable check is done.

    Raises:
        AssertionError: if ``func`` is not callable, if ``params`` is not a list of strings,
            or if one of the expected parameters is missing from the signature.
    """
    if func is None or not callable(func):
        raise AssertionError("invalid function object")
    if params is None:
        return
    is_name_list = isinstance(params, list) and all([isinstance(name, str) for name in params])
    if not is_name_list:
        raise AssertionError("unexpected param name list format")
    available_params = get_func_params(func)
    for name in params:
        if name not in available_params:
            raise AssertionError("function missing parameter '%s'" % name)
def encode_data(data, approach="lz4", **kwargs):
    """Encodes a numpy array using a given coding approach.

    Args:
        data: the numpy array to encode.
        approach: the encoding; supports `none` (or any other value listed in
            ``no_compression_flags``), `lz4`, `jpg` (or `jpeg`), and `png`.
        kwargs: extra arguments forwarded to the encoder (``lz4.frame.compress`` or
            ``cv2.imencode``); must be empty if no encoding is applied.

    Returns:
        The data itself if no encoding is applied; otherwise, the output of the encoder.

    Raises:
        AssertionError: if the approach is not supported, if extra arguments are given while
            no encoding is applied, or if OpenCV reports an encoding failure.

    .. seealso::
        | :func:`thelper.utils.decode_data`
    """
    supported_approaches = [*no_compression_flags, "lz4", "jpg", "jpeg", "png"]
    assert approach in supported_approaches, f"unexpected approach '{approach}'"
    if approach in no_compression_flags:
        assert not kwargs
        return data
    elif approach == "lz4":
        # the 'frame' submodule must be imported explicitly, `import lz4` alone does not load it
        import lz4.frame
        return lz4.frame.compress(data, **kwargs)
    elif approach == "jpg" or approach == "jpeg":
        import cv2 as cv
        ret, buf = cv.imencode(".jpg", data, **kwargs)
    elif approach == "png":
        import cv2 as cv
        ret, buf = cv.imencode(".png", data, **kwargs)
    else:
        raise NotImplementedError
    assert ret, "failed to encode data"
    return buf
def decode_data(data, approach="lz4", **kwargs):
    """Decodes a binary array using a given coding approach.

    Args:
        data: the binary array to decode.
        approach: the encoding; supports `none` (or any other value listed in
            ``no_compression_flags``), `lz4`, `jpg` (or `jpeg`), and `png`.
        kwargs: extra arguments forwarded to the decoder (``lz4.frame.decompress`` or
            ``cv2.imdecode``); must be empty if no decoding is applied. For image decoding,
            ``flags`` is mandatory, and may be given as a string that is evaluated with the
            OpenCV module in scope as ``cv`` (e.g. ``"cv.IMREAD_COLOR"``).

    Returns:
        The data itself if no decoding is applied; otherwise, the output of the decoder.

    Raises:
        AssertionError: if the approach is not supported, or if extra arguments are given
            while no decoding is applied.
        KeyError: if ``flags`` is missing while decoding an image.

    .. seealso::
        | :func:`thelper.utils.encode_data`
    """
    supported_approaches = [*no_compression_flags, "lz4", "jpg", "jpeg", "png"]
    assert approach in supported_approaches, f"unexpected approach '{approach}'"
    if approach in no_compression_flags:
        assert not kwargs
        return data
    elif approach == "lz4":
        # the 'frame' submodule must be imported explicitly, `import lz4` alone does not load it
        import lz4.frame
        return lz4.frame.decompress(data, **kwargs)
    elif approach in ["jpg", "jpeg", "png"]:
        # imported before the flags are evaluated so that strings such as "cv.IMREAD_COLOR" resolve
        import cv2 as cv
        kwargs = copy.deepcopy(kwargs)
        if isinstance(kwargs["flags"], str):  # required arg by opencv
            # NOTE(security): eval executes arbitrary code; the flags string must come from a trusted config
            kwargs["flags"] = eval(kwargs["flags"])
        return cv.imdecode(data, **kwargs)
    else:
        raise NotImplementedError
def get_class_logger(skip=0, base=False):
    """Returns the logger named after the class found in the calling frame.

    Args:
        skip: number of extra stack levels to skip when looking for the caller.
        base: forwarded as ``base_class`` to :func:`get_caller_name`.
    """
    # must be called from this frame directly: the skip count depends on the stack depth
    caller = get_caller_name(skip + 1, base_class=base)
    # the last dotted component is dropped to keep only the part that precedes it
    class_path, *_ = caller.rsplit(".", 1)
    return logging.getLogger(class_path)
def get_func_logger(skip=0):
    """Returns the logger named after the function found in the calling frame.

    Args:
        skip: number of extra stack levels to skip when looking for the caller.
    """
    # must be called from this frame directly: the skip count depends on the stack depth
    caller = get_caller_name(skip + 1)
    return logging.getLogger(caller)
def get_caller_name(skip=2, base_class=False):
    # source: https://gist.github.com/techtonik/2151727
    """Returns the name of a caller in the format module.class.method.

    Args:
        skip: specifies how many levels of stack to skip while getting the caller.
        base_class: specifies if the module of the base class should be kept instead of the
            module of the top-most class in case of inheritance. If the caller is not a
            method (no ``self`` local), this doesn't do anything.

    Returns:
        An empty string is returned if skipped levels exceed stack height; otherwise,
        returns the requested caller name.
    """
    # collect the frames of the call stack, starting from the direct caller of this function
    # noinspection PyProtectedMember
    current = sys._getframe(1)
    frames = []
    while current:
        frames.append(current)
        current = current.f_back
    if skip >= len(frames):
        return ""
    target = frames[skip]
    parts = []
    owner_module = inspect.getmodule(target)
    # the module can be None when the frame is executed directly in a console
    if owner_module:
        # frame module in case of inherited classes will point to base class
        # but frame local will still refer to top-most class when checking for 'self'
        # (stack: top(mid).__init__ -> mid(base).__init__ -> base.__init__)
        parts.append(owner_module.__name__)
    frame_locals = target.f_locals
    if "self" in frame_locals:
        # there is no way to detect a static method call; it looks like a plain function call
        klass = frame_locals["self"].__class__
        if owner_module and not base_class and inspect.isclass(klass):
            parts[0] = klass.__module__
        parts.append(klass.__name__)
    code_name = target.f_code.co_name
    if code_name != "<module>":  # top level usually
        parts.append(code_name)  # function or a method
    del target
    return ".".join(parts)
def get_key(key, config, msg=None, delete=False):
    """Returns a value given a dictionary key, throwing if not available.

    Args:
        key: the key to look for, or a list of (at least two) alternative keys that are
            tested in order; the value of the first one found is returned.
        config: the dictionary to search.
        msg: optional message used for the raised exception instead of the default one.
        delete: if ``True``, the matched entry is removed from the dictionary.

    Returns:
        The value associated with the (first matching) key.

    Raises:
        AssertionError: if no key is matched, or if a list with fewer than two keys is given.
    """
    multiple = isinstance(key, list)
    if multiple and len(key) <= 1:
        raise AssertionError(msg if msg is not None else "must provide at least two valid keys to test")
    for candidate in (key if multiple else [key]):
        if candidate in config:
            val = config[candidate]
            if delete:
                del config[candidate]
            return val
    if msg is not None:
        raise AssertionError(msg)
    if multiple:
        raise AssertionError("config dictionary missing a field named as one of '%s'" % str(key))
    # wrapped in a tuple so that tuple keys are formatted instead of being unpacked by '%'
    raise AssertionError("config dictionary missing '%s' field" % (key,))
def get_key_def(key, config, default=None, msg=None, delete=False):
    """Returns a value given a dictionary key, or the default value if it cannot be found.

    Args:
        key: the key to look for, or a list of (at least two) alternative keys that are
            tested in order; the value of the first one found is returned.
        config: the dictionary to search.
        default: the value returned when no key is matched.
        msg: optional message used for the exception raised on an invalid key list.
        delete: if ``True``, the matched entry is removed from the dictionary.

    Returns:
        The value associated with the (first matching) key, or ``default``.

    Raises:
        AssertionError: if a list with fewer than two keys is given.
    """
    if isinstance(key, list):
        if len(key) <= 1:
            raise AssertionError("must provide at least two valid keys to test" if msg is None else msg)
        candidates = key
    else:
        candidates = (key,)
    for candidate in candidates:
        if candidate in config:
            found = config[candidate]
            if delete:
                del config[candidate]
            return found
    return default
def get_log_stamp():
    """Returns an identification string made of the platform node name and the current local time.

    The result is formatted as ``<node>-<YYYYmmdd>-<HHMMSS>``.
    """
    host = str(platform.node())
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    return f"{host}-{timestamp}"
def get_git_stamp():
    """Returns the SHA of the HEAD commit of the git repository that contains this file.

    Returns:
        The commit SHA as a string, or ``"unknown"`` if the ``git`` package cannot be
        imported or if no valid repository is found.
    """
    fallback = "unknown"
    try:
        import git
        try:
            here = os.path.abspath(__file__)
            head_commit = git.Repo(path=here, search_parent_directories=True).head.object
            return str(head_commit.hexsha)
        except (AttributeError, git.InvalidGitRepositoryError):
            return fallback
    except (ImportError, AttributeError):
        # also reached if the imported module does not expose the expected exception type
        return fallback
def get_env_list():
    """Returns a list of all packages installed in the current environment.

    Three detection strategies are tried in order: the legacy ``pip`` API, ``pkg_resources``,
    and finally the standard library ``importlib.metadata`` module. Note that some packages
    may not be properly detected by this approach, and it is pretty hacky, so use it with
    a grain of salt (i.e. logging is fine).

    Returns:
        A sorted list of ``"<name> <version>"`` strings; the list is empty if none of the
        strategies is usable.
    """
    try:
        import pip
        # this entry point is missing from recent pip releases (AttributeError is expected there)
        # noinspection PyUnresolvedReferences
        pkgs = pip.get_installed_distributions()
        return sorted(["%s %s" % (pkg.key, pkg.version) for pkg in pkgs])
    except (ImportError, AttributeError):
        pass
    try:
        import pkg_resources as pkgr
        return sorted([str(pkg) for pkg in pkgr.working_set])
    except (ImportError, AttributeError):
        pass
    try:
        # standard library fallback, so that environments without the two packages above
        # still produce a usable list
        import importlib.metadata as importlib_metadata
        found = []
        for dist in importlib_metadata.distributions():
            dist_name = dist.metadata["Name"]
            if dist_name is not None:  # skip distributions with unreadable metadata
                found.append(f"{dist_name} {dist.version}")
        return sorted(found)
    except (ImportError, AttributeError):
        return []
def str2size(input_str):
    """Returns a (WIDTH, HEIGHT) integer size tuple from a string formatted as 'WxH'.

    Each parsed value is clamped to a minimum of 1.

    Raises:
        AssertionError: if the input is not a string or does not contain exactly one 'x'.
        ValueError: if one of the two parts cannot be converted to an integer.
    """
    if not isinstance(input_str, str):
        raise AssertionError("unexpected input type")
    tokens = input_str.split('x')
    if len(tokens) != 2:
        raise AssertionError("bad size string formatting")
    width, height = (max(int(token), 1) for token in tokens)
    return width, height
def str2bool(s):
    """Converts a string to a boolean.

    Booleans are returned as-is and numbers are compared against zero. For strings, the
    function returns ``True`` if the lower case version of the string matches any of
    'true', '1', or 'yes'.

    Raises:
        AssertionError: if the input is not a boolean, a number or a string.
    """
    if isinstance(s, bool):
        return s
    if isinstance(s, (int, float)):
        return s != 0
    if not isinstance(s, str):
        raise AssertionError("unrecognized input type")
    return s.lower() in ("true", "1", "yes")
def clipstr(s, size, fill=" "):
    """Clips a string to a specific length, with an optional fill character.

    Strings longer than ``size`` are truncated on the right; shorter ones are padded on
    the left with ``fill``.
    """
    clipped = s[:size] if len(s) > size else s
    missing = size - len(clipped)
    if missing > 0:
        clipped = fill * missing + clipped
    return clipped
def lreplace(string, old_prefix, new_prefix):
    """Replaces the leading occurrences of `old_prefix` in the given string by `new_prefix`.

    Every consecutive repetition of `old_prefix` found at the very beginning of the string
    is replaced (e.g. ``lreplace("aab", "a", "c")`` gives ``"ccb"``); occurrences located
    elsewhere are left untouched.

    Args:
        string: the string to process.
        old_prefix: the prefix to look for; if empty, the string is returned unchanged.
        new_prefix: the replacement inserted once per matched repetition.

    Returns:
        The string with its leading prefix repetitions replaced.
    """
    if not old_prefix:
        # nothing to match; also avoids a division by zero in the repetition count below
        return string
    pattern = r'^(?:%s)+' % re.escape(old_prefix)
    # the match length divided by the prefix length gives the number of repetitions
    return re.sub(pattern, lambda m: new_prefix * (m.end() // len(old_prefix)), string)
def query_yes_no(question, default=None, bypass=None):
    """Asks the user a yes/no question and returns the answer.

    Args:
        question: the string that is presented to the user.
        default: the presumed answer if the user just hits ``<Enter>``. It can be a boolean,
            'yes', 'no' (or their variations, case-insensitive), or ``None`` (meaning an
            answer is required). Unrecognized strings are treated as ``None``.
        bypass: the option to select if the ``bypass_queries`` global variable is set to
            ``True``. Can be ``None``, in which case the function will throw an exception.

    Returns:
        ``True`` for 'yes', or ``False`` for 'no' (or their respective variations).

    Raises:
        AssertionError: if ``bypass`` is not a recognized answer, or if queries are bypassed
            while ``bypass`` is ``None``.
    """
    valid = {"yes": True, "ye": True, "y": True, "no": False, "n": False}
    if bypass is not None and (not isinstance(bypass, str) or bypass not in valid):
        raise AssertionError("unexpected bypass value")
    clean_q = clipstr(question.replace("\n", " ").replace("\t", " ").replace("'", "`"), 45)
    if bypass_queries:
        if bypass is None:
            raise AssertionError("cannot bypass interactive query, no default value provided")
        logger.debug(f"bypassed query '{clean_q}...' with {valid[bypass]}")
        return valid[bypass]
    # resolve the default once, case-insensitively, so that the prompt and the value returned
    # on <Enter> always agree (a raw lookup would fail on e.g. "Yes" or on unknown words)
    if isinstance(default, str):
        default_answer = valid.get(default.lower())
    else:
        default_answer = default
    if isinstance(default_answer, bool):
        prompt = " [Y/n] " if default_answer else " [y/N] "
    else:
        prompt = " [y/n] "
    sys.stdout.flush()
    sys.stderr.flush()
    time.sleep(0.25)  # to make sure all debug/info prints are done, and we see the question
    while True:
        sys.stdout.write(question + prompt + "\n>> ")
        choice = input().lower()
        if default_answer is not None and choice == "":
            sys.stdout.write("\n")
            logger.debug(f"defaulted query '{clean_q}...' with {default_answer}")
            return default_answer
        elif choice in valid:
            sys.stdout.write("\n")
            logger.debug(f"answered query '{clean_q}...' with {valid[choice]}")
            return valid[choice]
        else:
            sys.stdout.write("Please respond with 'yes/y' or 'no/n'.\n")
def query_string(question, choices=None, default=None, allow_empty=False, bypass=None):
    """Asks the user a question and returns the answer (a generic string).

    Args:
        question: the string that is presented to the user.
        choices: a list of predefined choices that the user can pick from. If
            ``None``, then whatever the user types will be accepted.
        default: the presumed answer if the user just hits ``<Enter>``. If ``None``,
            then an answer is required to continue.
        allow_empty: defines whether an empty answer should be accepted.
        bypass: the returned value if the ``bypass_queries`` global variable is set to
            ``True``. Can be ``None``, in which case the function will throw an exception.

    Returns:
        The string entered by the user (or the default/bypass value).
    """
    clean_q = clipstr(question.replace("\n", " ").replace("\t", " ").replace("'", "`"), 45)
    if bypass_queries:
        if bypass is None:
            raise AssertionError("cannot bypass interactive query, no default value provided")
        logger.debug(f"bypassed query '{clean_q}...' with {bypass}")
        return bypass
    sys.stdout.flush()
    sys.stderr.flush()
    time.sleep(0.25)  # to make sure all debug/info prints are done, and we see the question
    # the prompt does not change between attempts, so it is assembled only once
    prompt = question
    if choices is not None:
        prompt += "\n\t(choices=%s)" % str(choices)
    if default is not None:
        prompt += "\n\t(default=%s)" % default
    prompt += "\n>> "
    while True:
        sys.stdout.write(prompt)
        answer = input()
        if answer == "":
            if default is not None:
                sys.stdout.write("\n")
                logger.debug(f"defaulted query '{clean_q}...' with {default}")
                return default
            if allow_empty:
                sys.stdout.write("\n")
                logger.debug(f"answered query '{clean_q}...' with empty string")
                return answer
        elif choices is None:
            sys.stdout.write("\n")
            logger.debug(f"answered query '{clean_q}...' with {answer}")
            return answer
        elif answer in choices:
            sys.stdout.write("\n")
            logger.debug(f"answered query '{clean_q}...' with choice '{answer}'")
            return answer
        sys.stdout.write("Please respond with a valid string.\n")
def get_config_session_name(config):
    # type: (thelper.typedefs.ConfigDict) -> Optional[str]
    """Returns the 'name' of a session as defined inside a configuration dictionary.

    Several keywords are tested in order, and the value of the first one found is returned.

    Args:
        config: the configuration dictionary to parse for a name.

    Returns:
        The name that should be given to the session (or ``None`` if no keyword is matched).
    """
    name_keys = ["output_dir_name", "output_directory_name", "session_name", "name"]
    return thelper.utils.get_key_def(name_keys, config, None)
def get_config_output_root(config):
    # type: (thelper.typedefs.ConfigDict) -> Optional[str]
    """Returns the output root directory as defined inside a configuration dictionary.

    Several keywords are tested in order, and the value of the first one found is returned.

    Args:
        config: the configuration dictionary to parse for a root output directory.

    Returns:
        The path to the output root directory (or ``None`` if no keyword is matched). The
        existence of the directory is not checked here.
    """
    root_keys = ["output_root_dir", "output_root_directory"]
    return thelper.utils.get_key_def(root_keys, config, None)
def get_checkpoint_session_root(ckpt_path):
    # type: (str) -> Optional[str]
    """Returns the session root directory associated with a checkpoint path.

    The given path can point to a checkpoint file or to a directory that contains checkpoints.
    The lookup relies on the presence of a ``logs`` directory: if it is found next to the
    checkpoints, the parent of the checkpoint directory is returned; if it is found one level
    above, the grandparent is returned.

    Args:
        ckpt_path: the path to a checkpoint or to an existing directory that contains checkpoints.

    Returns:
        The absolute path to the deduced root directory, or ``None`` if no ``logs`` directory
        is found.
    """
    assert os.path.exists(ckpt_path), "input path should point to valid filesystem node"
    ckpt_dir = os.path.abspath(ckpt_path)
    if not os.path.isdir(ckpt_path):
        ckpt_dir = os.path.dirname(ckpt_dir)
    # (relative location of the 'logs' directory, relative location of the matching root)
    layouts = (("logs", ".."), ("../logs", "../.."))
    for logs_rel, root_rel in layouts:
        if os.path.isdir(os.path.join(ckpt_dir, logs_rel)):
            return os.path.abspath(os.path.join(ckpt_dir, root_rel))
    return None  # cannot be found... giving up
def get_save_dir(out_root, dir_name, config=None, resume=False, backup_ext=".json"):
    """Returns a directory path in which the app can save its data.

    The output root and the session directory name can both be overridden by the configuration
    dictionary. For a new session, if the save directory already exists and overwriting is not
    allowed by the configuration, the user is asked whether to overwrite it or to provide a new
    path. For a resumed session, if the latest configuration backup differs from the given one
    and the user refuses to overwrite it, ``sys.exit(1)`` is called. If config is not ``None``,
    it is saved to the output directory and to its ``logs`` subdirectory.

    Args:
        out_root: path to the directory root where the save directory should be created.
        dir_name: name of the save directory to create. If it already exists, a new one may be requested.
        config: dictionary of app configuration parameters. Used to overwrite i/o queries, and will be
            written to the save directory to test writing. Default is ``None``.
        resume: specifies whether this session is new, or resumed from an older one (in the latter
            case, overwriting is allowed, and the user will never have to choose a new folder)
        backup_ext: extension to use when creating configuration file backups.

    Returns:
        The path to the created save directory for this session.
    """
    # must stay in this frame: the logger name is derived from the calling function
    func_logger = get_func_logger()
    has_config = config is not None
    if has_config:
        cfg_root = thelper.utils.get_config_output_root(config)
        if cfg_root is not None:
            if out_root is None:
                out_root = cfg_root
            elif out_root != cfg_root:
                answer = query_string("Received conflicting output root directory paths; which one should be used?\n"
                                      f"\t[config] = {cfg_root}\n\t[cli] = {out_root}", choices=["config", "cli"],
                                      default="config", bypass="config")
                if answer == "config":
                    out_root = cfg_root
        cfg_name = thelper.utils.get_config_session_name(config)
        if cfg_name is not None:
            assert isinstance(cfg_name, str), "config session/directory name should be given as string"
            assert not os.path.isabs(cfg_name), "config session/directory name should never be full (abs) path"
            if dir_name is not None and dir_name != cfg_name:
                func_logger.warning(f"overriding output session directory name '{dir_name}' to '{cfg_name}'")
            dir_name = cfg_name
    if out_root is None:
        time.sleep(0.25)  # to make sure all debug/info prints are done, and we see the question
        out_root = query_string("Please provide the path to where session directories should be created/saved:")
    func_logger.info(f"output root directory = {os.path.abspath(out_root)}")
    os.makedirs(out_root, exist_ok=True)
    save_dir = out_root if dir_name is None else os.path.join(out_root, dir_name)
    if not resume:
        overwrite = False
        if has_config and "overwrite" in config:
            overwrite = str2bool(config["overwrite"])
        time.sleep(0.25)  # to make sure all debug/info prints are done, and we see the question
        # keep asking until an unused path is given or overwriting is accepted
        while os.path.exists(save_dir) and not overwrite:
            shown_dir = os.path.abspath(save_dir).replace("\\", "/")
            overwrite = query_yes_no("Training session at '%s' already exists; overwrite?" % shown_dir, bypass="y")
            if not overwrite:
                save_dir = query_string("Please provide a new save directory path:")
    func_logger.info(f"output session directory = {os.path.abspath(save_dir)}")
    os.makedirs(save_dir, exist_ok=True)
    logs_dir = os.path.join(save_dir, "logs")
    func_logger.info(f"output logs directory = {os.path.abspath(logs_dir)}")
    os.makedirs(logs_dir, exist_ok=True)
    if has_config:
        latest_path = os.path.join(save_dir, "config.latest" + backup_ext)
        if resume and os.path.exists(latest_path):
            previous = thelper.utils.load_config(latest_path, add_name_if_missing=False)
            if previous != config:  # TODO make config dict comparison smarter...?
                query_msg = f"Config backup in '{latest_path}' differs from config loaded through checkpoint; overwrite?"
                if query_yes_no(query_msg, bypass="y"):
                    func_logger.warning("config mismatch with previous run; "
                                        "will overwrite latest backup in save directory")
                else:
                    func_logger.error("config mismatch with previous run; user aborted")
                    sys.exit(1)
        save_config(config, latest_path)
        stamped_path = os.path.join(logs_dir, "config." + thelper.utils.get_log_stamp() + backup_ext)
        save_config(config, stamped_path)
    return save_dir
def load_config(path, as_json=False, add_name_if_missing=True, **kwargs):
    # type: (str, bool, bool, **Any) -> thelper.typedefs.ConfigDict
    """Loads the configuration dictionary from the provided path.

    The type of file that is loaded is based on the extension in the path: ``.json``, ``.yml``,
    ``.yaml`` and ``.conf`` files are parsed with the YAML safe loader, and ``.pkl`` files are
    unpickled (unpickling can execute arbitrary code, so only load trusted files this way).

    If the loaded configuration dictionary does not contain a session name field, the name of
    the file itself (without extension) will be inserted under the 'name' key.

    Args:
        path: the path specifying which configuration to be loaded.
            only supported types are loaded unless `as_json` is `True`.
        as_json: specifies if an alternate extension should be considered as JSON format.
        add_name_if_missing: specifies whether the file name should be added to the config
            dictionary if it is missing a 'name' field.
        kwargs: forwarded to the extension-specific importer (must be empty for YAML/JSON).

    Returns:
        The loaded configuration dictionary.

    Raises:
        AssertionError: if the extension is not supported, or if extra arguments are given
            for a YAML/JSON file.
    """
    global fixed_yaml_parsing
    if not fixed_yaml_parsing:
        # registered once on the safe loader so that floats such as '5e-6' are not parsed as strings
        # https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number
        float_pattern = re.compile(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X)
        yaml.SafeLoader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            float_pattern,
            list(u'-+0123456789.'))
        fixed_yaml_parsing = True
    suffix = os.path.splitext(path)[-1]
    if as_json or suffix in (".json", ".yml", ".yaml", ".conf"):
        with open(path) as fd:
            assert not kwargs, "yaml safe load takes no extra args"
            config = yaml.safe_load(fd)  # also supports json
    elif suffix == ".pkl":
        with open(path, "rb") as fd:
            config = pickle.load(fd, **kwargs)
    else:
        raise AssertionError(f"unknown input file type: {suffix}")
    needs_name = add_name_if_missing and thelper.utils.get_config_session_name(config) is None
    if needs_name:
        config["name"] = os.path.splitext(os.path.basename(path))[0]
    return config
def save_config(config, path, force_convert=True, as_json=False, **kwargs):
    # type: (thelper.typedefs.ConfigDict, str, bool, bool, **Any) -> None
    """Saves the given session/object configuration dictionary to the provided path.

    The type of file that is created is based on the extension specified in the path. When
    dumping as JSON, objects that cannot be serialized are converted to strings, unless
    `force_convert` is set to `False` (in which case the JSON exporter will raise an exception).

    Args:
        config: the session/object configuration dictionary to save.
        path: the path specifying where to create the output file. The extension used will determine
            what type of backup to create (e.g. Pickle = .pkl, JSON = .json, YAML = .yml/.yaml).
            if `as_json` is `True`, then any specified extension will be preserved but dumped as JSON.
        force_convert: specifies whether non-serializable types should be converted if necessary.
        as_json: specifies if an alternate extension should be considered as JSON format.
        kwargs: forwarded to the extension-specific exporter.

    Raises:
        AssertionError: if the extension is not supported.
    """
    suffix = os.path.splitext(path)[-1]
    dump_json = as_json or suffix == ".json"
    if dump_json or suffix in (".yml", ".yaml"):
        with open(path, "w") as fd:
            kwargs.setdefault("indent", 4)
            if dump_json:
                # fallback serializer used by json for objects it cannot handle natively
                kwargs.setdefault("default", str if force_convert else None)
                kwargs.setdefault("sort_keys", False)
                json.dump(config, fd, **kwargs)
            else:
                yaml.dump(config, fd, **kwargs)
    elif suffix == ".pkl":
        with open(path, "wb") as fd:
            pickle.dump(config, fd, **kwargs)
    else:
        raise AssertionError(f"unknown output file type: {suffix}")
def save_env_list(path):
    """Saves a list of all packages installed in the current environment to a log file.

    One package is written per line; if no package could be listed, a single ``<n/a>``
    line is written instead.

    Args:
        path: the path where the log file should be created.
    """
    with open(path, "w") as fd:
        pkgs_list = thelper.utils.get_env_list()
        lines = ["%s\n" % pkg for pkg in pkgs_list] if pkgs_list else ["<n/a>\n"]
        fd.writelines(lines)
def is_scalar(val):
    """Returns whether the input value is a scalar according to numpy and PyTorch.

    A value qualifies if ``np.isscalar`` accepts it, or if it is a tensor that is
    zero-dimensional or holds exactly one element.
    """
    single_elem_tensor = isinstance(val, torch.Tensor) and (val.dim() == 0 or val.numel() == 1)
    return bool(np.isscalar(val)) or single_elem_tensor
def to_numpy(array):
    """Converts a list, tuple or PyTorch tensor to numpy. Does nothing if already a numpy array.

    Args:
        array: the list, tuple, tensor, or numpy array to convert.

    Returns:
        The numpy array equivalent of the input (the input itself if already a numpy array).

    Raises:
        AssertionError: if the input type is not supported.
    """
    if isinstance(array, np.ndarray):
        return array
    if isinstance(array, (list, tuple)):
        return np.asarray(array)
    if isinstance(array, torch.Tensor):
        # detach first: `numpy()` refuses tensors that require grad
        return array.detach().cpu().numpy()
    raise AssertionError(f"unexpected input type ({type(array)})")
def stringify_confmat(confmat, class_list, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
    """Transforms a confusion matrix given as a numpy array into a printable string.

    Rows are indexed by true label and columns by predicted label (as hinted by the 't/p'
    corner cell); an extra 'total' row and column hold the sums.

    Args:
        confmat: the confusion matrix, as a 2D numpy array.
        class_list: the list of class labels, in matrix order.
        hide_zeroes: if ``True``, cells with a zero count are left blank.
        hide_diagonal: if ``True``, cells on the diagonal are left blank.
        hide_threshold: if set (and non-zero), cells with a value that is not above it are left blank.

    Returns:
        The formatted table, with one tab-prefixed line per row.

    Raises:
        AssertionError: if ``confmat`` is not a numpy array or ``class_list`` is not a list.
    """
    if not isinstance(confmat, np.ndarray) or not isinstance(class_list, list):
        raise AssertionError("invalid inputs")
    column_width = 9
    text_fmt = "%{0}s".format(column_width)
    count_fmt = "%{0}d".format(column_width)
    blank = " " * column_width
    side_pad = (column_width - 3) // 2 * " "
    corner = side_pad + "t/p" + side_pad
    if len(corner) < len(blank):
        corner = " " * (len(blank) - len(corner)) + corner
    # header line
    chunks = ["\t" + corner + " "]
    chunks.extend(text_fmt % clipstr(label, column_width) + " " for label in class_list)
    chunks.append(text_fmt % "total" + "\n")
    # one line per true label
    for idx_true, label in enumerate(class_list):
        chunks.append("\t" + text_fmt % clipstr(label, column_width) + " ")
        for idx_pred in range(len(class_list)):
            value = confmat[idx_true, idx_pred]
            cell = count_fmt % int(value)
            if hide_zeroes and int(value) == 0:
                cell = blank
            if hide_diagonal and idx_true == idx_pred:
                cell = blank
            if hide_threshold and not value > hide_threshold:
                cell = blank
            chunks.append(cell + " ")
        chunks.append(count_fmt % int(confmat[idx_true, :].sum()) + "\n")
    # footer line with the column sums
    chunks.append("\t" + text_fmt % "total" + " ")
    chunks.extend(count_fmt % int(confmat[:, idx_pred].sum()) + " " for idx_pred in range(len(class_list)))
    chunks.append(count_fmt % int(confmat.sum()) + "\n")
    return "".join(chunks)
def get_glob_paths(input_glob_pattern, can_be_dir=False):
    """Parse a wildcard-compatible file name pattern for valid file paths.

    Args:
        input_glob_pattern: the pattern to expand.
        can_be_dir: specifies whether matched directories are accepted in addition to files.

    Returns:
        The list of matched paths.

    Raises:
        AssertionError: if nothing is matched, or if a match is neither a file nor an
            accepted directory.
    """
    matches = glob.glob(input_glob_pattern)
    if not matches:
        raise AssertionError("invalid input glob pattern '%s' (no matches found)" % input_glob_pattern)
    for match in matches:
        acceptable = os.path.isfile(match) or (can_be_dir and os.path.isdir(match))
        if not acceptable:
            raise AssertionError("invalid input file at globed path '%s'" % match)
    return matches
def get_file_paths(input_path, data_root, allow_glob=False, can_be_dir=False):
    """Parse a wildcard-compatible file name pattern at a given root level for valid file paths.

    Args:
        input_path: the file path or pattern; relative paths are resolved against ``data_root``.
        data_root: the root directory used for relative paths (unused for absolute paths).
        allow_glob: specifies whether a path containing ``*`` should be expanded as a pattern.
        can_be_dir: specifies whether directories are accepted in addition to files; this also
            applies to the paths matched by a pattern.

    Returns:
        The list of valid paths (a single-element list when no pattern is expanded).

    Raises:
        AssertionError: if the root directory is invalid for a relative path, or if the path
            (or one of the pattern matches) is neither a file nor an accepted directory.
    """
    if os.path.isabs(input_path):
        error_fmt = "invalid input file at absolute path '%s'"
    else:
        if not os.path.isdir(data_root):
            raise AssertionError("invalid dataset root directory at '%s'" % data_root)
        input_path = os.path.join(data_root, input_path)
        error_fmt = "invalid input file at path '%s'"
    if '*' in input_path and allow_glob:
        # forward the directory flag so that patterns follow the same rule as plain paths
        return get_glob_paths(input_path, can_be_dir=can_be_dir)
    if not os.path.isfile(input_path) and not (can_be_dir and os.path.isdir(input_path)):
        raise AssertionError(error_fmt % input_path)
    return [input_path]
def get_params_hash(*args, **kwargs):
    """Returns a sha1 hash for the given list of parameters (useful for caching).

    The hash is computed on the string representation of the positional and keyword
    arguments, from which the ' at 0x...' memory addresses are removed.
    """
    raw = str(args) + str(kwargs)
    # addresses are stripped so they do not end up in the hashed string
    stable = re.sub(r" at 0x[a-fA-F\d]+", "", raw)
    return hashlib.sha1(stable.encode()).hexdigest()
def check_installed(package_name):
    """Attempts to import a specified package by name, returning a boolean indicating success.

    Args:
        package_name: the name of the package (or module) to import.

    Returns:
        ``True`` if the import succeeds, or ``False`` if it raises an ``ImportError``.
    """
    try:
        importlib.import_module(package_name)
    except ImportError:
        return False
    return True
def set_matplotlib_agg():
    """Selects the 'Agg' backend of matplotlib.

    The package is imported locally, so it is only required when this function is called.
    """
    import matplotlib as mpl
    mpl.use('Agg')
def create_hdf5_dataset(fd, name, max_len, batch_like, compression="chunk_lz4", chunk_size=None, flatten=True):
    """Creates an HDF5 dataset inside the provided HDF5.File object descriptor.

    Three layouts are possible depending on the reference batch:
      - multi-dimensional samples with ``flatten``: variable-length ``uint8`` elements;
      - multi-dimensional samples without ``flatten``: a chunked dataset with the batch dtype;
      - scalar samples: one element per sample, with the batch dtype for numeric data or
        variable-length ``uint8`` elements otherwise.

    The original sample shape, the original dtype and the compression flag are stored as
    attributes of the created dataset.

    Args:
        fd: the HDF5 file object in which the dataset is created.
        name: the name of the dataset to create.
        max_len: the number of samples the dataset can hold.
        batch_like: an array shaped like a minibatch (first dimension is the batch dimension).
        compression: the compression flag, or a ``(flag, args)`` pair where ``args`` is a
            dictionary of extra arguments forwarded to the dataset creation call.
        chunk_size: the chunk shape for non-flattened datasets; if ``None``, one sample per chunk.
        flatten: specifies whether multi-dimensional samples are stored as flat byte arrays.

    Returns:
        The created dataset object.
    """
    assert batch_like.ndim >= 1, "minibatch must always contain at least batch dim"
    compression_args = {}
    if isinstance(compression, (tuple, list)) and len(compression) == 2:
        compression, compression_args = compression[0], compression[1]
    flat_dtype = h5py.special_dtype(vlen=np.uint8)
    sample_shape = batch_like.shape[1:]  # removes batch dim
    is_multidim = batch_like.ndim > 1
    if is_multidim and flatten:
        dset = fd.create_dataset(name, shape=(max_len,), maxshape=(max_len,), dtype=flat_dtype)
        dset.attrs["orig_shape"] = sample_shape
    elif is_multidim:
        assert compression in no_compression_flags or compression in chunk_compression_flags, \
            f"unsupported chunk-compress filter '{compression}'"
        assert np.issubdtype(batch_like.dtype, np.number), "invalid non-flattened array subtype"
        auto_chunker = chunk_size is None
        if auto_chunker:
            chunk_size = (1, *sample_shape)
        chunk_byte_size = np.multiply.reduce(chunk_size) * batch_like.dtype.itemsize
        assert auto_chunker or 10 * (2 ** 10) <= chunk_byte_size < 2 ** 20, \
            f"unrecommended chunk byte size ({chunk_byte_size}) should be in [10KiB,1MiB];" \
            " see http://docs.h5py.org/en/stable/high/dataset.html#chunked-storage"
        full_shape = (max_len, *sample_shape)
        if compression == "chunk_lz4":
            dset = fd.create_dataset(
                name=name,
                shape=full_shape,
                chunks=chunk_size,
                dtype=batch_like.dtype,
                **hdf5plugin.LZ4(nbytes=0)
            )
        else:
            assert compression not in no_compression_flags or len(compression_args) == 0
            dset = fd.create_dataset(
                name=name,
                shape=full_shape,
                chunks=chunk_size,
                dtype=batch_like.dtype,
                compression=None if compression in no_compression_flags else compression,
                **compression_args
            )
        dset.attrs["orig_shape"] = sample_shape
    else:
        assert thelper.utils.is_scalar(batch_like[0])
        if np.issubdtype(batch_like.dtype, np.number):
            assert compression in no_compression_flags, "cannot compress scalar elements"
            elem_dtype = batch_like.dtype
        else:
            elem_dtype = flat_dtype
        dset = fd.create_dataset(name, shape=(max_len,), maxshape=(max_len,), dtype=elem_dtype)
        dset.attrs["orig_shape"] = ()
    dset.attrs["orig_dtype"] = batch_like.dtype.str
    dset.attrs["compression"] = "none" if compression in no_compression_flags else compression
    return dset
[docs]def fill_hdf5_sample(dset, dset_idx, array_idx, array, compression="chunk_lz4", **compr_kwargs):
"""Fills a sample inside the specified HDF5 dataset object.
The element ``array[array_idx]`` is converted as needed and assigned to ``dset[dset_idx]``.
Args:
dset: the HDF5 dataset object to write into (indexed assignment is used).
dset_idx: index of the destination slot inside the dataset.
array_idx: index of the sample to read inside ``array``.
array: the array holding the samples; its ``dtype`` drives the conversions that are applied.
compression: the compression flag; flags listed in ``chunk_compression_flags`` skip the
call to :func:`thelper.utils.encode_data`.
compr_kwargs: extra arguments forwarded to :func:`thelper.utils.encode_data`.
"""
sample = array[array_idx]
# flags handled as chunk compression bypass the manual encoding step
if compression not in chunk_compression_flags:
sample = thelper.utils.encode_data(sample, compression, **compr_kwargs)
# when the flag is not a 'no compression' one, the encoder output is viewed as a flat uint8 array
if compression not in no_compression_flags:
sample = np.frombuffer(sample, dtype=np.uint8)
# NOTE(review): indentation was lost in this listing; confirm against the original file whether
# the non-numeric handling below is nested under the compression check above or not
if not np.issubdtype(array.dtype, np.number):
if np.issubdtype(array.dtype, np.dtype(str).type):
assert len(array.shape) == 1, "missing impl for string array reconstr"
# strings are turned into bytes so that they can be viewed as uint8 values below
sample = sample.encode()
sample = np.frombuffer(sample, dtype=np.uint8)
dset[dset_idx] = sample
[docs]def fetch_hdf5_sample(dset, idx, dtype="auto", shape="auto", compression="auto", **decompr_kwargs):
"""Returns a sample from the specified HDF5 dataset object.
Args:
dset: the HDF5 dataset object to read from; its ``attrs`` provide the ``compression``,
``orig_shape`` and ``orig_dtype`` values used when the matching argument is ``"auto"``.
idx: index of the sample to read inside the dataset.
dtype: the expected type of the returned sample, or ``"auto"`` to rely on the dataset attributes.
shape: the expected shape of the returned sample, or ``"auto"`` to rely on the dataset attributes.
compression: the compression flag, or ``"auto"`` to rely on the dataset attributes; flags listed
in ``chunk_compression_flags`` skip the call to :func:`thelper.utils.decode_data`.
decompr_kwargs: extra arguments forwarded to :func:`thelper.utils.decode_data`.
Returns:
The sample read at ``idx``, after the decoding, type and shape conversions below.
"""
# resolve the 'auto' settings from the attributes stored on the dataset
if compression == "auto":
compression = dset.attrs.get("compression")
if shape == "auto":
shape = dset.attrs.get("orig_shape")
sample = dset[idx]
# flags handled as chunk compression bypass the manual decoding step
if compression not in chunk_compression_flags:
sample = thelper.utils.decode_data(sample, compression, **decompr_kwargs)
if dtype == "auto":
dtype = np.dtype(dset.attrs.get("orig_dtype"))
if dtype is not None:
if np.issubdtype(dtype, np.dtype(str).type):
assert shape is None or len(shape) == 0, "missing impl for string array reconstr"
# strings are rebuilt by decoding the raw bytes of the sample
sample = sample.tobytes().decode()
elif sample.dtype != dtype:
# reinterpret the sample buffer using the requested type
sample = np.frombuffer(sample, dtype=dtype)
# NOTE(review): indentation was lost in this listing; confirm against the original file which
# 'if' this 'else' belongs to (the compression check or the 'dtype is not None' check)
else:
assert dtype == "auto" or dtype == sample.dtype
# restore the expected shape when one is known and the sample does not already have it
if shape is not None and len(shape) > 0 and sample.shape != tuple(shape):
sample = sample.reshape(shape)
return sample
def get_slurm_tmpdir() -> "Optional[str]":
    """Returns the local SLURM_TMPDIR path if available, or ``None``.

    Returns:
        The value of the ``SLURM_TMPDIR`` environment variable, or ``None`` if it is not set.

    Raises:
        AssertionError: if the variable is set but does not point to a writable directory.
    """
    slurm_tmpdir = os.getenv("SLURM_TMPDIR")
    if slurm_tmpdir is None:
        return None
    # explicit raises (instead of assert statements) so the checks also run under `python -O`
    if not os.path.isdir(slurm_tmpdir):
        raise AssertionError("invalid SLURM_TMPDIR path (not directory)")
    if not os.access(slurm_tmpdir, os.W_OK):
        raise AssertionError("invalid SLURM_TMPDIR path (not writable)")
    return slurm_tmpdir
def report_orion_results(session_runner: "SessionRunner") -> None:
    """Reports the results of a session runner, but only if the config allows it (true by default).

    The 'report' flag is read from the 'orion' section of the session configuration. When a
    metric is monitored, its best value is reported as the objective; the value is negated
    unless the monitoring goal is to minimize the metric.

    Args:
        session_runner: the session runner that provides the configuration and monitored values.
    """
    orion_config = get_key_def("orion", session_runner.config, default={})
    if not get_key_def("report", orion_config, default=True):
        return
    import orion.client
    if session_runner.monitor is not None:
        minimizing = session_runner.monitor_goal == thelper.optim.Metric.minimize
        best_val = session_runner.monitor_best
        report_val = best_val if minimizing else -best_val
        objective = {
            "name": session_runner.monitor,
            "type": "objective",
            "value": report_val,
        }
        orion.client.report_results([objective])