PyTorch implementations of BatchSampler that under/over sample according to a chosen parameter
alpha, in order to create a balanced training distribution.
The factory class constructs a pytorch BatchSampler to yield balanced samples from a
training distribution.
from pytorch_balanced_sampler.sampler import SamplerFactory
# which sample indices belong to each of 4 classes
class_idxs = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
[11, 12],
[13, 14, 15, 16, 17, 18, 19, 20]
]
batch_sampler = SamplerFactory().get(
class_idxs=class_idxs,
batch_size=32,
n_batches=250,
alpha=0.5,
kind='fixed'
)
dataset = Dataset( ... )
data_loader = DataLoader(dataset, batch_sampler=batch_sampler)
for data, target in data_loader:
# nice balanced batches!
...
Based on the choice of an alpha parameter in [0, 1] the sampler will adjust the sample
distribution to be between true distribution (alpha = 0), and a uniform distribution
(alpha = 1).
Overrepresented classes will be undersampled, and underrepresented classes oversampled. Here's an example from an imbalanced data distribution I was working with a while ago:
If you select kind='fixed', each batch generated will contain a consistent proportion of
classes. Eg. if we have 5 classes, we might receive batches like:
Batch: 0
Classes: [1, 0, 0, 0, 2, 4, 0, 2, 0, 0, 3, 2, 1, 0, 2, 0, 0, 3, 0, 0, 4, 4, 0, 2, 1, 3, 3, 1, 2, 0, 0, 4]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}
Batch: 1
Classes: [4, 1, 1, 2, 0, 0, 0, 4, 2, 4, 0, 3, 1, 3, 0, 0, 3, 2, 0, 2, 4, 2, 0, 0, 2, 3, 0, 1, 0, 0, 0, 0]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}
Batch: 2
Classes: [0, 4, 0, 0, 0, 3, 3, 2, 0, 4, 2, 3, 0, 3, 2, 0, 0, 1, 2, 2, 0, 1, 0, 0, 4, 0, 2, 1, 1, 4, 0, 0]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}
Note that the class counts are the same for each batch.
If you don't want to fix the number of each class in each batch, you can select kind='random',
which will use sampling with replacement. The samples will be weighted as to produce the target
class distribution on average.
pytorch_balanced_sampler was written by Karl Hornlund.
