Source code for thelper.data.loaders

"""Dataset loaders module.

This module contains a dataset loader specialization used to properly seed samplers and workers.
"""

import copy
import logging
import math
import random
import sys
import time
from collections import Counter

import numpy as np
import torch
import torch.utils.data
import torch.utils.data.sampler
import tqdm

import thelper.tasks
import thelper.transforms
import thelper.utils

logger = logging.getLogger(__name__)


def default_collate(batch, force_tensor=True):
    """Puts each data field into a tensor with outer dimension batch size.

    This function is copied from PyTorch's `torch.utils.data._utils.collate.default_collate`, but
    additionally supports custom objects from the framework (such as bounding boxes). These will not
    be converted to tensors, and it will be up to the trainer to handle them accordingly.

    See ``torch.utils.data.DataLoader`` for more information.

    Args:
        batch: non-empty sequence of samples (or of sample fields) to merge; the type of the first
            element decides how the whole batch is collated.
        force_tensor: when ``True``, an element type that cannot be collated triggers an
            ``AssertionError``; when ``False``, such a batch is returned untouched.

    Returns:
        ``None`` if all elements are ``None``; a stacked tensor for tensors, numpy arrays and numbers;
        the batch itself for strings and lists of bounding box lists; a dictionary, named tuple or
        list of recursively collated fields for mappings, named tuples and other sequences.
    """
    import re
    try:
        from torch._six import container_abcs, string_classes, int_classes
    except ImportError:
        # these private aliases are missing from newer PyTorch releases; fall back on their builtin equivalents
        import collections.abc as container_abcs
        string_classes, int_classes = (str, bytes), int
    error_msg_fmt = "batch must contain tensors, numbers, dicts or lists; found {}"
    # format: X.Y.Z[+cu101]; only the leading digits of X and Y are kept so that pre-release
    # suffixes (e.g. '1.8.0a0') cannot break the integer conversion
    ver_numbers = re.findall(r"\d+", torch.__version__.split("+")[0])
    torch_ver = (tuple(int(v) for v in ver_numbers[:2]) + (0, 0))[:2]
    elem_type = type(batch[0])
    if any(b is None for b in batch):
        assert all(b is None for b in batch), "cannot mix ``None`` and non-``None`` types"
        return None  # compress and return entire field as unavailable
    elif isinstance(batch[0], torch.Tensor):
        # the way to detect that we are running inside a background worker changed across versions
        if torch_ver > (1, 1):
            in_worker = torch.utils.data.get_worker_info() is not None
        elif torch_ver == (1, 1):  # pragma: no cover
            in_worker = torch.utils.data._utils.collate._use_shared_memory
        else:  # ver < 1.1  # pragma: no cover
            in_worker = torch.utils.data.dataloader._use_shared_memory
        out = None
        if in_worker:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"):
        elem = batch[0]
        if elem_type.__name__ == "ndarray":
            # array of string classes and object
            if torch_ver > (1, 0):
                assert torch.utils.data._utils.collate.np_str_obj_array_pattern.search(elem.dtype.str) is None, \
                    error_msg_fmt.format(elem.dtype)
            else:  # ver <= 1.0  # pragma: no cover
                if re.search('[SaUO]', elem.dtype.str) is not None:
                    raise TypeError(error_msg_fmt.format(elem.dtype))
            return default_collate([torch.from_numpy(b) for b in batch], force_tensor=force_tensor)
        if elem.shape == ():  # scalars  # pragma: no cover
            # simplified as of PyTorch v1.2.0, and similar to <1.1.0
            return torch.as_tensor(batch)
        # any other numpy type falls through to the final 'force_tensor' check below
    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch], force_tensor=force_tensor) for key in batch[0]}
    elif isinstance(batch[0], tuple) and hasattr(batch[0], "_fields"):  # namedtuple
        return type(batch[0])(*(default_collate(samples, force_tensor=force_tensor) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        # lists of bounding box lists are forwarded as-is (the trainer is expected to handle them)
        if isinstance(batch, list) and all(isinstance(lbl, list) for lbl in batch) and \
                all(isinstance(b, thelper.data.BoundingBox) for lbl in batch for b in lbl):
            return batch
        transposed = zip(*batch)
        return [default_collate(samples, force_tensor=force_tensor) for samples in transposed]
    assert not force_tensor, error_msg_fmt.format(type(batch[0]))
    return batch
class DataLoader(torch.utils.data.DataLoader):
    """Specialized data loader used to load minibatches from a dataset parser.

    This specialization handles the seeding of samplers and workers.

    See ``torch.utils.data.DataLoader`` for more information on attributes/methods.
    """

    def __init__(self, *args, seeds=None, epoch=0, collate_fn=default_collate, **kwargs):
        """Forwards the arguments to the base loader, and stores the seed pack and epoch number.

        Args:
            *args: positional arguments forwarded to ``torch.utils.data.DataLoader``.
            seeds: optional dictionary of RNG seeds; the keys read by this class are ``"torch"``,
                ``"numpy"`` and ``"random"``.
            epoch: non-negative integer used to offset the seeds.
            collate_fn: function used to merge samples into minibatches.
            **kwargs: keyword arguments forwarded to ``torch.utils.data.DataLoader``; note that
                ``worker_init_fn`` is always provided by this class.

        Raises:
            AssertionError: if ``seeds`` is not a dictionary, or if ``epoch`` is not a non-negative integer.
        """
        super().__init__(*args, collate_fn=collate_fn, worker_init_fn=self._worker_init_fn, **kwargs)
        self.seeds = {}
        if seeds is not None:
            if not isinstance(seeds, dict):
                raise AssertionError("unexpected seed pack type")
            self.seeds = seeds
        if not isinstance(epoch, int) or epoch < 0:
            raise AssertionError("invalid epoch value")
        self.epoch = epoch
        # the base constructor already assigned 'num_workers'; it is only refreshed here when given
        # as a keyword so that a positionally-provided worker count is not silently reset to zero
        if "num_workers" in kwargs:
            self.num_workers = kwargs["num_workers"]

    def __iter__(self):
        """Advances the epoch number for the workers initialization function."""
        self.set_epoch(self.epoch)  # preset for all attributes
        if self.num_workers == 0:
            # no worker will call the init function below; seed the RNGs of the current process instead
            if "torch" in self.seeds:
                torch.manual_seed(self.seeds["torch"] + self.epoch)
                torch.cuda.manual_seed_all(self.seeds["torch"] + self.epoch)
            if "numpy" in self.seeds:
                np.random.seed(self.seeds["numpy"] + self.epoch)
            if "random" in self.seeds:
                random.seed(self.seeds["random"] + self.epoch)
        result = super().__iter__()
        self.epoch += 1
        return result

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to offset RNG states for the workers and the sampler.

        The value is also forwarded to the sampler, to the dataset, and to the dataset's transforms
        when these objects expose a callable ``set_epoch`` attribute.

        Raises:
            AssertionError: if ``epoch`` is not a non-negative integer.
        """
        if not isinstance(epoch, int) or epoch < 0:
            raise AssertionError("invalid epoch value")
        self.epoch = epoch
        if self.sampler is not None:
            if hasattr(self.sampler, "set_epoch") and callable(self.sampler.set_epoch):
                self.sampler.set_epoch(self.epoch)
        if hasattr(self.dataset, "set_epoch") and callable(self.dataset.set_epoch):
            self.dataset.set_epoch(epoch)
        if hasattr(self.dataset, "transforms"):
            if hasattr(self.dataset.transforms, "set_epoch") and callable(self.dataset.transforms.set_epoch):
                self.dataset.transforms.set_epoch(epoch)

    def _worker_init_fn(self, worker_id):
        """Sets up the RNGs state of each worker based on their unique id and the epoch number."""
        # offsetting by 'num_workers * epoch' gives a distinct seed to each (epoch, worker) pair
        seed_offset = self.num_workers * self.epoch
        if "torch" in self.seeds:
            torch.manual_seed(self.seeds["torch"] + seed_offset + worker_id)
            torch.cuda.manual_seed_all(self.seeds["torch"] + seed_offset + worker_id)
        if "numpy" in self.seeds:
            np.random.seed(self.seeds["numpy"] + seed_offset + worker_id)
        if "random" in self.seeds:
            random.seed(self.seeds["random"] + seed_offset + worker_id)

    @property
    def sample_count(self):
        """Returns the length of the sampler if one is set, or the length of the dataset otherwise."""
        return len(self.sampler) if self.sampler is not None else len(self.dataset)
class DataLoaderWrapper(DataLoader):
    """Data loader wrapper used to transform all loaded samples with an external function.

    This can be useful to convert the samples before the user gets to access them, or to upload
    them on a specific device. The wrapped data loader should be compatible with
    :class:`thelper.data.loaders.DataLoader`.
    """

    def __init__(self, loader, callback):
        """Turns this object into a proxy of ``loader`` whose samples go through ``callback``.

        The base constructor is not called. Instead, the class of this instance is swapped for a
        dynamically created type (named after the wrapped loader's class) that derives from both
        the current class and the wrapped loader's class, and the instance attributes become a
        copy of the wrapped loader's attributes plus ``_wrapped_loader`` and ``_callback``.

        Args:
            loader: the data loader to wrap.
            callback: function called with each sample produced by the wrapped loader.
        """
        wrapped_type = loader.__class__
        merged_bases = (self.__class__, wrapped_type)
        self.__class__ = type(wrapped_type.__name__, merged_bases, {})
        self.__dict__ = dict(loader.__dict__, _wrapped_loader=loader, _callback=callback)

    def __iter__(self):
        """Yields the samples of the wrapped loader after passing each of them to the callback."""
        for loaded_sample in self._wrapped_loader:
            yield self._callback(loaded_sample)
class LoaderFactory:
    """Factory used for preparing and splitting dataset parsers into usable data loader objects.

    This class is responsible for parsing the parameters contained in the 'loaders' field of a
    configuration dictionary, instantiating the data loaders, and shuffling/splitting the samples.
    An example configuration is presented in :func:`thelper.data.utils.create_loaders`.

    .. seealso::
        | :func:`thelper.data.utils.create_loaders`
        | :func:`thelper.transforms.utils.load_augments`
        | :func:`thelper.transforms.utils.load_transforms`
    """

    def __init__(self, config):
        """Receives and parses the data configuration dictionary.

        Besides filling in the attributes of the factory, this seeds the global ``torch``, ``numpy``
        and ``random`` generators with the seeds found (or generated) for the configuration. The
        split ratio dictionaries of ``config`` may be normalized in place.

        Raises:
            AssertionError: if a global setting is specified along with its loader-specific
                variants, if a value has an unexpected type, or if no split is defined at all.
        """
        logger.debug("loading data configuration...")
        default_batch_size = 1
        if "batch_size" in config:
            if any(v in config for v in ["train_batch_size", "valid_batch_size", "test_batch_size"]):
                raise AssertionError("specifying 'batch_size' overrides all other (loader-specific) values")
            default_batch_size = int(thelper.utils.get_key("batch_size", config))
        self.train_batch_size = thelper.utils.get_key_def("train_batch_size", config, default_batch_size)
        self.valid_batch_size = thelper.utils.get_key_def("valid_batch_size", config, default_batch_size)
        self.test_batch_size = thelper.utils.get_key_def("test_batch_size", config, default_batch_size)
        assert self.train_batch_size > 0, "batch size should be strictly positive integer"
        assert self.valid_batch_size > 0, "batch size should be strictly positive integer"
        assert self.test_batch_size > 0, "batch size should be strictly positive integer"
        self.train_scale = thelper.utils.get_key_def("train_scale", config, 1.0)
        self.valid_scale = thelper.utils.get_key_def("valid_scale", config, 1.0)
        self.test_scale = thelper.utils.get_key_def("test_scale", config, 1.0)
        default_collate_fn = default_collate
        if any(s in config for s in ["collate", "collate_fn"]):
            if any(s in config for s in ["train_collate", "valid_collate", "test_collate",
                                         "train_collate_fn", "valid_collate_fn", "test_collate_fn"]):
                raise AssertionError("specifying 'collate_fn' overrides all other (loader-specific) values")
            # both spellings are accepted by the check above, so both must be readable here
            default_collate_fn = thelper.utils.import_function(
                self._get_first_available(["collate_fn", "collate"], config))
        self.train_collate_fn = thelper.utils.import_function(
            self._get_first_available(["train_collate_fn", "train_collate"], config, default_collate_fn))
        self.valid_collate_fn = thelper.utils.import_function(
            self._get_first_available(["valid_collate_fn", "valid_collate"], config, default_collate_fn))
        self.test_collate_fn = thelper.utils.import_function(
            self._get_first_available(["test_collate_fn", "test_collate"], config, default_collate_fn))
        self.train_shuffle = thelper.utils.str2bool(thelper.utils.get_key_def(["shuffle", "train_shuffle"], config, True))
        self.valid_shuffle = thelper.utils.str2bool(thelper.utils.get_key_def(["shuffle", "valid_shuffle"], config, False))
        self.test_shuffle = thelper.utils.str2bool(thelper.utils.get_key_def(["shuffle", "test_shuffle"], config, False))
        np.random.seed()  # for seed generation below (if needed); will be reseeded afterwards
        test_seed = self._get_seed(["test_seed", "test_split_seed"], config, (int, str))
        valid_seed = self._get_seed(["valid_seed", "valid_split_seed"], config, (int, str))
        torch_seed = self._get_seed(["torch_seed"], config, int)
        numpy_seed = self._get_seed(["numpy_seed"], config, int)
        random_seed = self._get_seed(["random_seed"], config, int)
        torch.manual_seed(torch_seed)
        torch.cuda.manual_seed_all(torch_seed)
        np.random.seed(numpy_seed)
        random.seed(random_seed)
        self.seeds = {
            "test": test_seed,
            "valid": valid_seed,
            "torch": torch_seed,
            "numpy": numpy_seed,
            "random": random_seed
        }
        self.workers = config["workers"] if "workers" in config and config["workers"] >= 0 else 1
        self.pin_memory = thelper.utils.str2bool(config["pin_memory"]) if "pin_memory" in config else False
        self.drop_last = thelper.utils.str2bool(config["drop_last"]) if "drop_last" in config else False
        default_sampler_config = None
        if "sampler" in config:
            if any(s in config for s in ["train_sampler", "valid_sampler", "test_sampler"]):
                raise AssertionError("specifying 'sampler' overrides all other (loader-specific) values")
            default_sampler_config = config["sampler"]
            assert isinstance(default_sampler_config, dict), "invalid sampler config (should be dict)"
        self.train_sampler = thelper.utils.get_key_def("train_sampler", config, default_sampler_config)
        assert self.train_sampler is None or isinstance(self.train_sampler, (dict, torch.utils.data.sampler.Sampler)), \
            "invalid training sampler (should be config dictionary or already-instantiated object)"
        self.valid_sampler = thelper.utils.get_key_def("valid_sampler", config, default_sampler_config)
        assert self.valid_sampler is None or isinstance(self.valid_sampler, (dict, torch.utils.data.sampler.Sampler)), \
            "invalid valid sampler (should be config dictionary or already-instantiated object)"
        self.test_sampler = thelper.utils.get_key_def("test_sampler", config, default_sampler_config)
        assert self.test_sampler is None or isinstance(self.test_sampler, (dict, torch.utils.data.sampler.Sampler)), \
            "invalid test sampler (should be config dictionary or already-instantiated object)"
        # augmentation fields are tried in order, from the most global to the most loader-specific
        train_augs_targets = ["augments", "trainvalid_augments", "train_augments"]
        valid_augs_targets = ["augments", "trainvalid_augments", "eval_augments", "validtest_augments", "valid_augments"]
        test_augs_targets = ["augments", "eval_augments", "validtest_augments", "test_augments"]
        self.train_augments, self.train_augments_append = self._get_augments(train_augs_targets, "train", config)
        self.valid_augments, self.valid_augments_append = self._get_augments(valid_augs_targets, "valid", config)
        self.test_augments, self.test_augments_append = self._get_augments(test_augs_targets, "test", config)
        self.base_transforms = None
        if "base_transforms" in config and config["base_transforms"]:
            self.base_transforms = thelper.transforms.load_transforms(config["base_transforms"])
        self.train_split = self._get_ratios_split("train", config)
        self.valid_split = self._get_ratios_split("valid", config)
        self.test_split = self._get_ratios_split("test", config)
        if not self.train_split and not self.valid_split and not self.test_split:
            raise AssertionError("data config must define a split for at least one loader type (train/valid/test)")
        # total ratio of each dataset over all loaders (entries with a null total are dropped by the sum)
        self.total_usage = Counter(self.train_split) + Counter(self.valid_split) + Counter(self.test_split)
        self.skip_split_norm = thelper.utils.str2bool(thelper.utils.get_key_def(
            ["skip_norm", "skip_split_norm"], config, False))
        self.skip_class_balancing = thelper.utils.str2bool(thelper.utils.get_key_def(
            ["skip_balancing", "skip_class_balancing", "skip_rebalancing", "skip_class_rebalancing"], config, False))
        for name, usage in self.total_usage.items():
            if usage != 1:
                normalize_ratios = None
                assert usage >= 0
                if 0 < usage < 1 and not math.isclose(usage, 1) and not self.skip_split_norm:
                    query_msg = f"dataset split for {name} has a ratio sum less than 1; do you want to normalize the split?\n\t("
                    query_msg += f"train={self.train_split[name] if name in self.train_split else 0:.03f}, "
                    query_msg += f"valid={self.valid_split[name] if name in self.valid_split else 0:.03f}, "
                    query_msg += f"test={self.test_split[name] if name in self.test_split else 0:.03f})\n"
                    query_msg += "\t\tto\n\t("
                    query_msg += f"train={self.train_split[name] / usage if name in self.train_split else 0:.03f}, "
                    query_msg += f"valid={self.valid_split[name] / usage if name in self.valid_split else 0:.03f}, "
                    query_msg += f"test={self.test_split[name] / usage if name in self.test_split else 0:.03f})"
                    normalize_ratios = thelper.utils.query_yes_no(query_msg, bypass="n")
                # ratio sums above one are always normalized; sums below one only upon confirmation
                if (normalize_ratios or usage > 1) and usage > 0:
                    if usage > 1:
                        logger.warning("dataset split for '%s' sums to more than 1; will normalize..." % name)
                    for subset in [self.train_split, self.valid_split, self.test_split]:
                        if name in subset:
                            subset[name] /= usage
        self.skip_verif = thelper.utils.str2bool(config["skip_verif"]) if "skip_verif" in config else True
        logger.debug("batch sizes:" +
                     (f"\n\ttrain = {self.train_batch_size}" if self.train_split else "") +
                     (f"\n\tvalid = {self.valid_batch_size}" if self.valid_split else "") +
                     (f"\n\ttest = {self.test_batch_size}" if self.test_split else ""))
        logger.debug("samplers configs:" +
                     (f"\n\ttrain = {self.train_sampler}" if self.train_sampler else "") +
                     (f"\n\tvalid = {self.valid_sampler}" if self.valid_sampler else "") +
                     (f"\n\ttest = {self.test_sampler}" if self.test_sampler else ""))
        logger.debug("scaling factors:" +
                     (f"\n\ttrain = {self.train_scale}" if self.train_split else "") +
                     (f"\n\tvalid = {self.valid_scale}" if self.valid_split else "") +
                     (f"\n\ttest = {self.test_scale}" if self.test_split else ""))
        if self.drop_last:
            logger.debug("loaders will drop last batch if sample count not multiple of batch size")
        if self.base_transforms:
            logger.debug("base transforms: %s" % str(self.base_transforms))

    @staticmethod
    def _get_first_available(keys, config, default=None):
        """Returns the config value of the first key of ``keys`` found in ``config``, or ``default``."""
        for key in keys:
            if key in config:
                return config[key]
        return default

    @staticmethod
    def _get_seed(prefixes, config, stype):
        """Returns the seed stored under the first available field name, or a newly generated one.

        Args:
            prefixes: list of field names to look for (in order) in the configuration.
            config: the data configuration dictionary.
            stype: type (or tuple of types) expected for the seed value.

        Raises:
            AssertionError: if the value found in the configuration is not an instance of ``stype``.
        """
        key = next((prefix for prefix in prefixes if prefix in config), None)
        if key is not None:
            if not isinstance(config[key], stype):
                raise AssertionError("unexpected value type for field '%s'" % key)
            return config[key]
        seed = np.random.randint(2 ** 16)
        logger.info("setting '%s' to %d" % (prefixes[0], seed))
        return seed

    @staticmethod
    def _get_ratios_split(prefix, config):
        """Returns the ratio map stored in the ``<prefix>_split`` field, or an empty map if missing/empty.

        The dictionary of the configuration is returned directly (it is not copied).
        """
        key = prefix + "_split"
        if key not in config or not config[key]:
            return {}
        split = config[key]
        assert not any(ratio < 0 for ratio in split.values())
        return split

    @staticmethod
    def _get_augments(targets, name, config):
        """Loads the augmentations of the first non-empty field of ``targets`` found in ``config``.

        Returns:
            The pair produced by ``thelper.transforms.load_augments`` (augmentation operations, and
            whether to append them), or ``(None, False)`` if none of the fields is available.
        """
        logger.debug("loading %s augments..." % name)
        for target in targets:
            if target in config and config[target]:
                augments, augments_append = thelper.transforms.load_augments(config[target])
                if augments:
                    logger.debug("will %s %s augments: %s" % ("append" if augments_append else "prefix", name, str(augments)))
                # NOTE(review): the nesting level of this return statement could not be recovered with
                # certainty from the flattened original; it is assumed to follow the first matching field
                return augments, augments_append
        return None, False

    def _get_raw_split(self, indices):
        """Splits the indices of each dataset between the loaders according to the parsed ratios.

        The test indices are taken first, followed by the validation and training ones. When any
        loader requires shuffling, the indices are shuffled using the test seed beforehand, the
        leftovers are reshuffled using the validation seed once the test indices are picked, and
        the numpy RNG is finally reseeded with the default numpy seed.

        Args:
            indices: map of dataset names to their sequence of indices; it is left untouched as
                the split is performed on deep copies.

        Returns:
            A three-element tuple containing the maps of training, validation, and test indices.
        """
        for name in self.total_usage:
            assert name in indices, f"dataset '{name}' does not exist"
        pool, train_idxs, valid_idxs, test_idxs = {}, {}, {}, {}
        for name, idxs in indices.items():
            pool[name] = copy.deepcopy(idxs)
            train_idxs[name] = []
            valid_idxs[name] = []
            test_idxs[name] = []
        shuffle = any([self.train_shuffle, self.valid_shuffle, self.test_shuffle])
        if shuffle:
            np.random.seed(self.seeds["test"])  # test idxs will be picked first, then valid+train
            for idxs in pool.values():
                np.random.shuffle(idxs)
        offsets = dict.fromkeys(self.total_usage, 0)
        split_order = zip([test_idxs, valid_idxs, train_idxs],
                          [self.test_split, self.valid_split, self.train_split])
        for loader_idx, (idxs_map, ratio_map) in enumerate(split_order):
            for name in self.total_usage:
                if name in ratio_map:
                    count = int(round(ratio_map[name] * len(pool[name])))
                    assert count >= 0, "ratios should be non-negative values"
                    begidx = offsets[name]
                    endidx = min(begidx + count, len(pool[name]))  # rounding must not overshoot the pool
                    idxs_map[name] = pool[name][begidx:endidx]
                    offsets[name] = endidx
            if loader_idx == 0 and shuffle:
                np.random.seed(self.seeds["valid"])  # all test idxs are now picked, reshuffle for train/valid
                for name in self.total_usage:
                    trainvalid_idxs = pool[name][offsets[name]:]
                    np.random.shuffle(trainvalid_idxs)
                    pool[name][offsets[name]:] = trainvalid_idxs
        if shuffle:
            np.random.seed(self.seeds["numpy"])  # back to default random state for future use
        return train_idxs, valid_idxs, test_idxs

    def get_split(self, datasets, task):
        r"""Returns the train/valid/test sample indices split for a given dataset (name-parser) map.

        Note that the returned indices are unique, possibly shuffled, and never duplicated between sets.
        If the samples have a class attribute (i.e. the task is related to classification), the split
        will respect the initial distribution and apply the ratios within the classes themselves. For
        example, consider a dataset of three classes (:math:`A`, :math:`B`, and :math:`C`) that contains
        100 samples such as:

        .. math::

            |A| = 50,\;|B| = 30,\;|C| = 20

        If we require a 80%-10%-10% ratio distribution for the training, validation, and test loaders
        respectively, the resulting split will contain the following sample counts:

        .. math::

            \text{training loader} = {40A + 24B + 16C}

        .. math::

            \text{validation loader} = {5A + 3B + 2C}

        .. math::

            \text{test loader} = {5A + 3B + 2C}

        In the case of multi-label classification datasets, there is no guarantee that the classes will be
        balanced across the training/validation/test sets. Instead, for a given class list, the classes
        with fewer samples will be split first.

        Args:
            datasets: the map of datasets to split, where each has a name (key) and a parser (value).
            task: a task object that should be compatible with all provided datasets (can be ``None``).

        Returns:
            A three-element tuple containing the maps of the training, validation, and test sets
            respectively. These maps associate dataset names to a list of sample indices.
        """
        dataset_sizes = {}
        shared_datasets = {}
        global_size = 0
        for dataset_name, dataset in datasets.items():
            assert isinstance(dataset, thelper.data.Dataset) or isinstance(dataset, thelper.data.ExternalDataset), \
                f"unexpected dataset type for '{dataset_name}'"
            dataset_sizes[dataset_name] = len(dataset)
            global_size += dataset_sizes[dataset_name]
            # if a single dataset is used in more than a single loader, we cannot skip the rebalancing below
            shared_datasets[dataset_name] = sum(
                [dataset_name in split for split in [self.train_split, self.valid_split, self.test_split]]) > 1
        logger.info("splitting datasets with parsed sizes = %s" % str(dataset_sizes))
        must_split = any(shared_datasets.values())
        if task is not None and isinstance(task, thelper.tasks.Classification) and not self.skip_class_balancing and must_split:
            # note: with current impl, all class sets will be shuffled the same way... (shouldnt matter, right?)
            logger.debug("will split evenly over %d classes..." % len(task.class_names))
            unset_class_key = "<unset>"
            global_class_names = task.class_names + [unset_class_key]  # extra name added for unlabeled samples (if needed!)
            sample_maps, sample_counts = {}, {cname: 0 for cname in global_class_names}
            for dataset_name, dataset in datasets.items():
                # fetching a reference to the list of samples here allows us to bypass the 'loading' process and possibly
                # directly access sample labels/groundtruth (assuming it is already loaded)
                samples = dataset.samples if hasattr(dataset, "samples") and dataset.samples is not None \
                    and len(dataset.samples) == len(dataset) else dataset
                if isinstance(dataset, thelper.data.ExternalDataset):
                    if hasattr(samples, "samples") and samples.samples is not None and len(samples.samples) == len(samples):
                        sample_maps[dataset_name] = task.get_class_sample_map(samples.samples, unset_class_key)
                    else:
                        logger.warning(f"must fully parse the external dataset '{dataset_name}' for balanced intra-class shuffling;" +
                                       " this might take a while!\n\t...consider making a dataset interface that can return labels" +
                                       " only, it would greatly speed up the analysis of class distributions\n\t...you could also" +
                                       " add the 'skip_class_balancing' flag to your data configuration to skip this rebalancing")
                        # to allow glitch-less tqdm printing after latest logger output
                        sys.stdout.flush(), sys.stderr.flush(), time.sleep(0.01)
                        samples = []
                        for sample in tqdm.tqdm(dataset):
                            assert task.gt_key in sample, f"could not find label key ('{task.gt_key}') in sample dict"
                            samples.append({task.gt_key: sample[task.gt_key]})
                        sample_maps[dataset_name] = task.get_class_sample_map(samples, unset_class_key)
                else:
                    sample_maps[dataset_name] = task.get_class_sample_map(samples, unset_class_key)
                for class_name, class_samples in sample_maps[dataset_name].items():
                    assert class_name in sample_counts
                    sample_counts[class_name] += len(class_samples)
            # classes with fewer samples are split first (see the multi-label note in the docstring)
            sample_counts = dict(sorted(sample_counts.items(), key=lambda i: i[1]))
            # indices already assigned to a loader, per dataset (needed for multi-label samples)
            backlist_idxs = {d: np.asarray([], np.int32) for d in datasets}
            train_idxs, valid_idxs, test_idxs = {d: [] for d in datasets}, {d: [] for d in datasets}, {d: [] for d in datasets}
            for class_name in sample_counts:
                curr_class_samples = {}
                for dataset_name in datasets:
                    class_samples = sample_maps[dataset_name][class_name] if class_name in sample_maps[dataset_name] else []
                    if task.multi_label:
                        class_samples = np.setdiff1d(class_samples, backlist_idxs[dataset_name])
                    else:
                        assert len(np.intersect1d(class_samples, backlist_idxs[dataset_name])) == 0, \
                            "duplicated sample idx across classes"
                    curr_class_samples[dataset_name] = class_samples
                    dataset_size = dataset_sizes[dataset_name]
                    # empty datasets must not break the (purely informative) ratios logged below
                    logger.debug("dataset '{}' class #{} '{}' sample count: {} ({:0.1f}% of dataset, {:0.1f}% of total)".format(
                        dataset_name, global_class_names.index(class_name), class_name, len(class_samples),
                        int(100 * len(class_samples) / dataset_size) if dataset_size else 0,
                        int(100 * len(class_samples) / global_size) if global_size else 0))
                class_train_idxs, class_valid_idxs, class_test_idxs = self._get_raw_split(curr_class_samples)
                for dname in datasets:
                    for subset_idxs, class_subset_idxs in zip([train_idxs[dname], valid_idxs[dname], test_idxs[dname]],
                                                              [class_train_idxs[dname], class_valid_idxs[dname], class_test_idxs[dname]]):
                        # idx-label pairs below are passed through to the sampler for label-specific indexing (if needed)
                        subset_idxs.extend(list(zip(class_subset_idxs, [class_name] * len(class_subset_idxs))))
                        backlist_idxs[dname] = np.append(backlist_idxs[dname].astype(np.int32), np.asarray(class_subset_idxs))
        else:  # no balancing to be done
            dataset_indices = {}
            for dataset_name in datasets:
                # note: all indices paired with 'None' below as class is ignored; used for compatibility with code above
                dataset_indices[dataset_name] = list(
                    zip(list(range(dataset_sizes[dataset_name])), [None] * dataset_sizes[dataset_name]))
            train_idxs, valid_idxs, test_idxs = self._get_raw_split(dataset_indices)
        return train_idxs, valid_idxs, test_idxs

    def create_loaders(self, datasets, train_idxs, valid_idxs, test_idxs):
        """Returns the data loaders for the train/valid/test sets based on a prior split.

        This function essentially takes the dataset parser interfaces and indices maps, and instantiates
        data loaders that are ready to produce samples for training or evaluation. Note that the dataset
        parsers will be deep-copied in each data loader, meaning that they should ideally not contain a
        persistent loading state or a large buffer.

        Args:
            datasets: the map of dataset parsers, where each has a name (key) and a parser (value).
            train_idxs: training data samples indices map.
            valid_idxs: validation data samples indices map.
            test_idxs: test data samples indices map.

        Returns:
            A three-element tuple containing the training, validation, and test data loaders, respectively.
            A loader is replaced by ``None`` when no sample is assigned to it.
        """
        loaders = []
        for idxs_map, (augs, augs_append), shuffle, scale, sampler, batch_size, collate_fn \
                in zip([train_idxs, valid_idxs, test_idxs],
                       [(self.train_augments, self.train_augments_append),
                        (self.valid_augments, self.valid_augments_append),
                        (self.test_augments, self.test_augments_append)],
                       [self.train_shuffle, self.valid_shuffle, self.test_shuffle],
                       [self.train_scale, self.valid_scale, self.test_scale],
                       [self.train_sampler, self.valid_sampler, self.test_sampler],
                       [self.train_batch_size, self.valid_batch_size, self.test_batch_size],
                       [self.train_collate_fn, self.valid_collate_fn, self.test_collate_fn]):
            loader_sample_idx_offset = 0
            loader_sample_classes = []
            loader_sample_idxs = []
            loader_datasets = []
            for dataset_name, sample_idxs in idxs_map.items():
                if not sample_idxs:
                    continue
                # todo: investigate need to copy at all?
                if not datasets[dataset_name].deepcopy:
                    dataset = copy.copy(datasets[dataset_name])
                else:
                    dataset = copy.deepcopy(datasets[dataset_name])
                if augs:
                    augs_copy = copy.deepcopy(augs)
                    if dataset.transforms is not None:
                        if augs_append:
                            dataset.transforms = thelper.transforms.Compose([dataset.transforms, augs_copy])
                        else:
                            dataset.transforms = thelper.transforms.Compose([augs_copy, dataset.transforms])
                    else:
                        dataset.transforms = augs_copy
                # values were paired in tuples earlier, 0=idx, 1=label; the indices are offset so
                # that they remain valid once the datasets of this loader are concatenated
                for sample_idx, sample_class in sample_idxs:
                    loader_sample_idxs.append(sample_idx + loader_sample_idx_offset)
                    loader_sample_classes.append(sample_class)
                loader_sample_idx_offset += len(dataset)
                loader_datasets.append(dataset)
            if len(loader_datasets) > 0:
                dataset = torch.utils.data.ConcatDataset(loader_datasets) if len(loader_datasets) > 1 else loader_datasets[0]
                if sampler is not None:
                    if isinstance(sampler, dict):
                        sampler_type = thelper.utils.get_key("type", sampler, msg="sampler config dict missing 'type' attribute")
                        sampler_type = thelper.utils.import_class(sampler_type)
                        sampler_params = thelper.utils.get_key_def(["params", "param", "parameters", "kwargs"], sampler, {})
                        sampler_pass_labels = thelper.utils.str2bool(thelper.utils.get_key_def("pass_labels", sampler, False))
                        sampler_pass_labels_param_name = thelper.utils.get_key_def("pass_labels_param_name", sampler, "labels")
                        if sampler_pass_labels:
                            sampler_params = {**sampler_params, sampler_pass_labels_param_name: loader_sample_classes}
                        # seeds and scale are only forwarded to samplers that declare such parameters
                        sampler_expected_params = thelper.utils.get_func_params(sampler_type)
                        if "seeds" in sampler_expected_params:
                            sampler_params = {**sampler_params, "seeds": self.seeds}
                        if "scale" in sampler_expected_params:
                            assert "scale" not in sampler_params, "specified scale in both sampler config and loader config"
                            sampler_params = {**sampler_params, "scale": scale}
                        else:
                            assert scale == 1.0, f"could not apply scale factor to sampler with type '{str(sampler_type)}'"
                        sampler = sampler_type(loader_sample_idxs, **sampler_params)
                    else:
                        assert scale == 1.0, f"could not apply scale factor to (pre-instantiated) sampler with type '{str(sampler)}'"
                    # NOTE(review): the nesting level of this check could not be recovered with certainty
                    # from the flattened original; it is assumed to cover both branches above
                    assert isinstance(sampler, torch.utils.data.sampler.Sampler), "invalid sampler type (should be torch-compatible)"
                else:
                    if shuffle:
                        sampler = thelper.data.SubsetRandomSampler(loader_sample_idxs, seeds=self.seeds, scale=scale)
                    else:
                        assert scale == 1.0, "sequential sampler currently does not handle scale changes (turn on shuffling)"
                        sampler = thelper.data.SubsetSequentialSampler(loader_sample_idxs)
                assert hasattr(sampler, "__len__")
                assert batch_size > 0
                loaders.append(DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          sampler=sampler,
                                          num_workers=self.workers,
                                          collate_fn=collate_fn,
                                          pin_memory=self.pin_memory,
                                          drop_last=self.drop_last,
                                          seeds=self.seeds))
            else:
                loaders.append(None)
        train_loader, valid_loader, test_loader = loaders
        logger.info("initialized loaders with batch counts:" +
                    (f"\n\ttrain = {len(train_loader)}" if train_loader else "") +
                    (f"\n\tvalid = {len(valid_loader)}" if valid_loader else "") +
                    (f"\n\ttest = {len(test_loader)}" if test_loader else ""))
        return train_loader, valid_loader, test_loader

    def get_base_transforms(self):
        """Returns the (global) sample transformation operations parsed in the data configuration."""
        return self.base_transforms