"""Samplers module.
This module contains classes used for raw dataset rebalancing or augmentation.
All samplers here should aim to be compatible with PyTorch's sampling interface
(torch.utils.data.sampler.Sampler) so that they can be instantiated at runtime
through a configuration file and used as the input of a data loader.
"""
import collections
import copy
import logging
import numpy as np
import torch
import torch.utils.data.sampler
import thelper.data.utils
# Module-level logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
class WeightedSubsetRandomSampler(torch.utils.data.sampler.Sampler):
    r"""Provides a rebalanced list of sample indices to use in a data loader.

    Given a list of sample indices and the corresponding list of class labels, this sampler
    will produce a new list of indices that rebalances the distribution of samples according
    to a specified strategy. It can also optionally scale the dataset's total sample count to
    avoid undersampling large classes as smaller ones get bigger.

    The currently implemented strategies are:

    * ``random``: will return a list of randomly picked samples where each sample is weighted by
      the inverse of its class size, so that all classes are equally likely to be picked. This
      sampling is done with replacement, meaning that each index is picked independently of the
      already-picked ones.
    * ``uniform``: will rebalance the dataset by normalizing the sample count of all classes,
      oversampling and undersampling as required to distribute all samples equally. All
      removed or duplicated samples are selected randomly without replacement whenever possible.
    * ``root``: will rebalance the dataset by normalizing class weight using an n-th degree
      root. More specifically, for a list of initial class weights
      :math:`W^0=\{w_1^0, w_2^0, ... w_n^0\}`, we compute the adjusted weight :math:`w_i` of
      each class via:

      .. math::

        w_i = \frac{\sqrt[\leftroot{-1}\uproot{3}n]{w_i^0}}{\sum_j\sqrt[\leftroot{-1}\uproot{3}n]{w_j^0}}

      Then, according to the new distribution of weights, all classes are oversampled and
      undersampled as required to reobtain the dataset's total sample count (which may be
      scaled). All removed or duplicated samples are selected randomly without replacement
      whenever possible.

    Note that with the ``root`` strategy, if a very large root degree ``n`` is used, this
    strategy is equivalent to ``uniform``. If the degree is one, the original weights will
    be used for sampling. The ``root`` strategy essentially provides a flexible solution to
    rebalance very uneven label sets where uniform over/undersampling would be too aggressive.

    By default, this interface will try to keep the dataset size constant and balance oversampling
    with undersampling. If undersampling is undesired, the user can increase the total dataset
    size via a scale factor. Finally, note that the rebalanced list of indices is generated by
    this interface every time the ``__iter__`` function is called, meaning two consecutive lists
    might not contain the exact same indices.

    Example configuration file::

        # ...
        # the sampler is defined inside the 'loaders' field
        "loaders": {
            # ...
            # this field is completely optional, and can be omitted entirely
            "sampler": {
                # the type of the sampler we want to instantiate
                "type": "thelper.data.samplers.WeightedSubsetRandomSampler",
                # the parameters passed to the sampler's constructor
                "params": {
                    "stype": "root3",
                    "scale": 1.2
                },
                # specifies whether the sampler should receive class labels
                "pass_labels": true
            },
            # ...
        },
        # ...

    Attributes:
        nb_samples: total number of samples to rebalance (i.e. scaled size of original dataset).
        label_groups: map that splits all samples indices into groups based on labels (empty
            when ``nb_samples`` is zero).
        stype: name of the rebalancing strategy to use.
        indices: copy of the original list of sample indices provided in the constructor.
        sample_weights: list of weights used for random sampling (``random`` strategy only).
        label_counts: number of samples in each class for the ``uniform`` and ``root`` strategies.
        seeds: dictionary of seeds to use when initializing RNG state.
        epoch: epoch number used to reinitialize the RNG to an epoch-specific state.

    .. seealso::
        | :func:`thelper.data.utils.create_loaders`
        | :func:`thelper.data.utils.get_class_weights`
    """

    def __init__(self, indices, labels, stype="uniform", scale=1.0, seeds=None, epoch=0):
        """Receives sample indices, labels, rebalancing strategy, and dataset scaling factor.

        This function will validate all input arguments, parse and categorize samples according to
        labels, initialize rebalancing parameters, and determine sample counts for each valid class.

        Note that an empty list of indices is an acceptable input; the resulting object will also
        create an empty list of samples when ``__iter__`` is called.

        Args:
            indices: list of integers representing the indices of samples of interest in the dataset.
            labels: list of labels tied to the list of indices (must be the same length).
            stype: rebalancing strategy given as a string. Should be either "random", "uniform", or
                "rootX", where the 'X' is the degree to use in the root computation (float).
            scale: non-negative scaling factor (int or float) used to increase/decrease the final
                number of sample indices to generate while rebalancing.
            seeds: dictionary of seeds to use when initializing RNG state.
            epoch: epoch number used to reinitialize the RNG to an epoch-specific state.
        """
        super().__init__(None)
        assert isinstance(indices, (list, np.ndarray)) and isinstance(labels, (list, np.ndarray)), \
            "expected indices and labels to be provided as lists"
        assert len(indices) == len(labels), "mismatched indices/labels list sizes"
        # configuration files may yield integer scales (e.g. "scale": 2), so accept any real number
        assert isinstance(scale, (int, float)) and not isinstance(scale, bool) and scale >= 0, \
            "invalid scale parameter; should be non-negative"
        self.seeds = {}
        if seeds is not None:
            assert isinstance(seeds, dict), "unexpected seed pack type"
            self.seeds = seeds
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch index value"
        self.epoch = epoch
        self.nb_samples = int(round(len(indices) * scale))
        # documented attributes always exist, even when nothing will be sampled
        self.stype = stype
        self.indices = copy.deepcopy(indices)
        self.label_groups = {}
        if self.nb_samples > 0:
            for index, label in zip(indices, labels):
                self.label_groups.setdefault(label, []).append(index)
            assert isinstance(stype, str) and (stype in ["uniform", "random"] or "root" in stype), \
                "unexpected sampling type"
            if stype == "random":
                # inverse class-size weighting: every class carries the same total sampling mass
                self.sample_weights = [1.0 / len(self.label_groups[label]) for label in labels]
            else:
                weights = thelper.data.utils.get_class_weights(self.label_groups, stype, invmax=False)
                self.label_counts = {label: int(self.nb_samples * weights[label])
                                     for label in self.label_groups}
                # flooring above may leave some samples unassigned; the (first) largest class absorbs them
                largest_label = max(self.label_groups, key=lambda lbl: len(self.label_groups[lbl]))
                leftover = self.nb_samples - sum(self.label_counts.values())
                if leftover != 0:
                    self.label_counts[largest_label] += leftover

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to offset the RNG state for sampling."""
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch index value"
        self.epoch = epoch

    def __iter__(self):
        """Returns an iterator over the list of rebalanced sample indices to load.

        Note that the indices are repicked every time this function is called, meaning that samples
        eliminated due to undersampling (or duplicated due to oversampling) might not receive the same
        treatment twice.

        If a ``"torch"`` seed is available, the torch RNG is reseeded with ``seed + epoch`` before
        sampling, and its previous state is restored afterwards (even if sampling fails). The epoch
        counter is incremented after each successful call.
        """
        if self.nb_samples == 0:
            self.epoch += 1
            return iter([])
        rng_state = None
        if "torch" in self.seeds:
            rng_state = torch.random.get_rng_state()
            torch.random.manual_seed(self.seeds["torch"] + self.epoch)
        try:
            assert self.stype in ["random", "uniform"] or "root" in self.stype, "invalid stype"
            if self.stype == "random":
                picks = torch.multinomial(torch.FloatTensor(self.sample_weights),
                                          self.nb_samples, replacement=True)
                result = [self.indices[idx] for idx in picks.tolist()]
            else:  # "uniform" or "rootX"
                selected = []
                for label, count in self.label_counts.items():
                    group = self.label_groups[label]
                    # draw full permutations until the class quota is met (oversampling), or only
                    # part of one permutation (undersampling)
                    while count > 0:
                        perm = torch.randperm(len(group)).tolist()
                        selected.extend(group[i] for i in perm[:count])
                        count -= len(group)
                assert len(selected) == self.nb_samples, "messed up something internally..."
                result = [selected[i] for i in torch.randperm(len(selected)).tolist()]
        finally:
            if rng_state is not None:
                torch.random.set_rng_state(rng_state)
        self.epoch += 1
        return iter(result)

    def __len__(self):
        """Returns the number of sample indices that will be generated by this interface.

        This number is the scaled size of the originally provided sample indices list.
        """
        return self.nb_samples
class SubsetRandomSampler(torch.utils.data.sampler.Sampler):
    r"""Samples elements randomly from a given list of indices.

    Indices are drawn from random permutations of the provided list, so with a scale of one or less
    no index is repeated. With a scale greater than one, the list is traversed several times
    (duplicating samples), and the final list is shuffled. This specialization also handles
    seeding based on the epoch number.

    Arguments:
        indices (list): a list of indices
        seeds (dict): dictionary of seeds to use when initializing RNG state.
        epoch (int): epoch number used to reinitialize the RNG to an epoch-specific state.
        scale (float): non-negative scaling factor (int or float) used to increase/decrease the
            final number of samples.
    """

    def __init__(self, indices, seeds=None, epoch=0, scale=1.0):
        super().__init__(indices)
        self.seeds = {}
        if seeds is not None:
            assert isinstance(seeds, dict), "unexpected seed pack type"
            self.seeds = seeds
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch index value"
        self.epoch = epoch
        self.indices = indices
        # configuration files may yield integer scales (e.g. "scale": 2), so accept any real number
        assert isinstance(scale, (int, float)) and not isinstance(scale, bool) and scale >= 0, \
            "invalid scale parameter; should be non-negative"
        self.num_samples = int(round(len(self.indices) * scale))

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to offset the RNG state for sampling."""
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch index value"
        self.epoch = epoch

    def __iter__(self):
        """Returns an iterator over ``num_samples`` randomly picked sample indices.

        If a ``"torch"`` seed is available, the torch RNG is reseeded with ``seed + epoch`` before
        sampling, and its previous state is restored afterwards (even if sampling fails). The epoch
        counter is incremented after each successful call.
        """
        rng_state = None
        if "torch" in self.seeds:
            rng_state = torch.random.get_rng_state()
            torch.random.manual_seed(self.seeds["torch"] + self.epoch)
        try:
            picked = []
            nb_indices = len(self.indices)
            remaining = self.num_samples
            # each pass takes (part of) a fresh permutation until the requested count is reached
            while remaining > 0:
                perm = torch.randperm(nb_indices).tolist()
                picked.extend(self.indices[i] for i in perm[:remaining])
                remaining -= nb_indices
            result = [picked[i] for i in torch.randperm(len(picked)).tolist()]
        finally:
            if rng_state is not None:
                torch.random.set_rng_state(rng_state)
        self.epoch += 1
        return iter(result)

    def __len__(self):
        """Returns the number of sample indices generated per iteration (scaled list size)."""
        return self.num_samples
class SubsetSequentialSampler(torch.utils.data.sampler.Sampler):
    r"""Yields the provided sample indices in their original order, on every pass.

    Arguments:
        indices (list): the sample indices to iterate over.
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __len__(self):
        # one output per provided index; no scaling or rebalancing happens here
        return len(self.indices)

    def __iter__(self):
        # the generator's source iterator is created immediately, just like iter(self.indices)
        return (index for index in self.indices)
class FixedWeightSubsetSampler(torch.utils.data.sampler.Sampler):
    r"""Provides a rebalanced list of sample indices to use in a data loader.

    Given a list of sample indices and the corresponding list of class labels, this sampler
    will produce a new list of indices that rebalances the distribution of samples according
    to a provided map of per-class weights. Each class is resampled to
    ``round(weight * class_size)`` indices (undersampling for weights below one, oversampling
    for weights above one), and the resulting list is shuffled so that classes are interleaved.

    Example configuration file::

        # ...
        # the sampler is defined inside the 'loaders' field
        "loaders": {
            # ...
            # this field is completely optional, and can be omitted entirely
            "sampler": {
                # the type of the sampler we want to instantiate
                "type": "thelper.data.samplers.FixedWeightSubsetSampler",
                # the parameters passed to the sampler's constructor
                "params": {
                    "weights": {
                        # the weights must be provided using class name pairs
                        "class_A": 0.1,
                        "class_B": 5.0,
                        "class_C": 1.0,
                        # ...
                    }
                },
                # specifies whether the sampler should receive class labels
                "pass_labels": true
            },
            # ...
        },
        # ...

    Attributes:
        nb_samples: total number of sample indices produced per iteration (sum of the per-class counts).
        weights: weight map to use for sampling each class.
        label_groups: map that splits all sample indices into groups based on labels.
        class_sample_counts: number of indices drawn from each class per iteration.
        seeds: dictionary of seeds to use when initializing RNG state.
        epoch: epoch number used to reinitialize the RNG to an epoch-specific state.

    .. seealso::
        | :func:`thelper.data.utils.create_loaders`
        | :func:`thelper.data.utils.get_class_weights`
    """

    def __init__(self, indices, labels, weights, seeds=None, epoch=0):
        """Receives sample indices, labels, and per-class sampling weights.

        This function will validate all input arguments, group samples according to their labels,
        and determine the number of samples to draw for each class.

        Note that an empty list of indices is an acceptable input; the resulting object will also
        create an empty list of samples when ``__iter__`` is called.

        Args:
            indices: list of integers representing the indices of samples of interest in the dataset.
            labels: list of labels tied to the list of indices (must be the same length).
            weights: map of non-negative sampling weights, with one entry per label.
            seeds: dictionary of seeds to use when initializing RNG state.
            epoch: epoch number used to reinitialize the RNG to an epoch-specific state.

        Raises:
            KeyError: if a label found in ``labels`` has no entry in ``weights``.
        """
        super().__init__(None)
        assert isinstance(indices, (list, np.ndarray)) and isinstance(labels, (list, np.ndarray)), \
            "expected indices and labels to be provided as lists"
        assert len(indices) == len(labels), "mismatched indices/labels list sizes"
        self.seeds = {}
        if seeds is not None:
            assert isinstance(seeds, dict), "unexpected seed pack type"
            self.seeds = seeds
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch index value"
        self.epoch = epoch
        assert isinstance(weights, (dict, collections.OrderedDict)), "invalid weights map type"
        assert all(weight >= 0 for weight in weights.values()), "weights must all be non-negative"
        self.weights = weights
        self.label_groups = {}
        for index, label in zip(indices, labels):
            self.label_groups.setdefault(label, []).append(index)
        unweighted = [label for label in self.label_groups if label not in self.weights]
        if unweighted:
            raise KeyError("missing sampling weight for label(s): {}".format(unweighted))
        self.class_sample_counts = {label: int(round(self.weights[label] * len(group)))
                                    for label, group in self.label_groups.items()}
        self.nb_samples = sum(self.class_sample_counts.values())

    def set_epoch(self, epoch=0):
        """Sets the current epoch number in order to offset the RNG state for sampling."""
        assert isinstance(epoch, int) and epoch >= 0, "invalid epoch index value"
        self.epoch = epoch

    def __iter__(self):
        """Returns an iterator over the shuffled list of rebalanced sample indices to load.

        Note that the indices are repicked every time this function is called, meaning that samples
        eliminated due to undersampling (or duplicated due to oversampling) might not receive the same
        treatment twice.

        If a ``"torch"`` seed is available, the torch RNG is reseeded with ``seed + epoch`` before
        sampling, and its previous state is restored afterwards (even if sampling fails). The epoch
        counter is incremented after each successful call.
        """
        if self.nb_samples == 0:
            self.epoch += 1
            return iter([])
        rng_state = None
        if "torch" in self.seeds:
            rng_state = torch.random.get_rng_state()
            torch.random.manual_seed(self.seeds["torch"] + self.epoch)
        try:
            picked = []
            for label, group in self.label_groups.items():
                remaining = self.class_sample_counts[label]
                # draw (part of) fresh permutations until this class' quota is met
                while remaining > 0:
                    take = min(remaining, len(group))
                    picked += [group[idx] for idx in torch.randperm(len(group))[:take].tolist()]
                    remaining -= take
            assert len(picked) == self.nb_samples
            # interleave classes; otherwise a data loader would produce single-class batches
            result = [picked[idx] for idx in torch.randperm(len(picked)).tolist()]
        finally:
            if rng_state is not None:
                torch.random.set_rng_state(rng_state)
        self.epoch += 1
        return iter(result)

    def __len__(self):
        """Returns the number of sample indices that will be generated by this interface.

        This number is the sum of the per-class sample counts derived from the weight map.
        """
        return self.nb_samples