
# PyTorch Balanced Sampler

PyTorch implementations of `BatchSampler` that under- and over-sample according to a chosen parameter `alpha`, in order to create a balanced training distribution.


## Usage

### SamplerFactory

The factory class constructs a PyTorch `BatchSampler` that yields balanced samples from a training distribution.

```python
from torch.utils.data import DataLoader

from pytorch_balanced_sampler.sampler import SamplerFactory

# which sample indices belong to each of the 4 classes
class_idxs = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12],
    [13, 14, 15, 16, 17, 18, 19, 20]
]

batch_sampler = SamplerFactory().get(
    class_idxs=class_idxs,
    batch_size=32,
    n_batches=250,
    alpha=0.5,
    kind='fixed'
)

dataset = Dataset( ... )
data_loader = DataLoader(dataset, batch_sampler=batch_sampler)

for data, target in data_loader:
    # nice balanced batches!
    ...
```

## Class Balancing

Based on the choice of an `alpha` parameter in `[0, 1]`, the sampler will adjust the sample distribution to lie between the true distribution (`alpha = 0`) and a uniform distribution (`alpha = 1`).
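
As a concrete illustration, here is a minimal sketch of that interpolation, using the 4-class example above. It mirrors the behaviour described here; it is not the library's internal code:

```python
import numpy as np

# Sizes of the 4 classes from the usage example above.
class_counts = np.array([5, 5, 2, 8])

true_dist = class_counts / class_counts.sum()            # [0.25, 0.25, 0.1, 0.4]
uniform_dist = np.full(len(class_counts), 1 / len(class_counts))

# alpha interpolates between the true distribution (alpha=0)
# and a uniform one (alpha=1).
alpha = 0.5
target_dist = alpha * uniform_dist + (1 - alpha) * true_dist
print(target_dist)  # [0.25, 0.25, 0.175, 0.325]
```

At `alpha = 0.5`, each class's target probability sits halfway between its true frequency and `1 / n_classes`.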

Overrepresented classes will be undersampled, and underrepresented classes oversampled. Here's an example from an imbalanced data distribution I was working with a while ago:

![Sample distributions on imbalanced 2015 data](./resources/sample-distributions-2015-data.png)

### Fixed Batch Distributions

If you select `kind='fixed'`, each batch generated will contain a consistent proportion of classes. E.g. if we have 5 classes, we might receive batches like:

```
Batch: 0
Classes: [1, 0, 0, 0, 2, 4, 0, 2, 0, 0, 3, 2, 1, 0, 2, 0, 0, 3, 0, 0, 4, 4, 0, 2, 1, 3, 3, 1, 2, 0, 0, 4]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}

Batch: 1
Classes: [4, 1, 1, 2, 0, 0, 0, 4, 2, 4, 0, 3, 1, 3, 0, 0, 3, 2, 0, 2, 4, 2, 0, 0, 2, 3, 0, 1, 0, 0, 0, 0]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}

Batch: 2
Classes: [0, 4, 0, 0, 0, 3, 3, 2, 0, 4, 2, 3, 0, 3, 2, 0, 0, 1, 2, 2, 0, 1, 0, 0, 4, 0, 2, 1, 1, 4, 0, 0]
Counts: {0: 14, 1: 4, 2: 6, 3: 4, 4: 4}
```

Note that the class counts are the same for each batch.
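
The README doesn't spell out how the target distribution is rounded into whole per-batch counts. One plausible scheme (a hypothetical sketch, not necessarily the library's exact rounding) is to floor each class's share of the batch and hand the leftover slots to the classes with the largest fractional parts:

```python
import numpy as np

# Hypothetical rounding scheme: floor each class's share of the batch,
# then give leftover slots to the classes with the largest fractional parts.
def fixed_counts(target_dist, batch_size):
    raw = target_dist * batch_size
    counts = np.floor(raw).astype(int)
    remainder = batch_size - counts.sum()
    for idx in np.argsort(-(raw - counts))[:remainder]:
        counts[idx] += 1
    return counts

# Using the target distribution from the alpha sketch above:
print(fixed_counts(np.array([0.25, 0.25, 0.175, 0.325]), 32))
# -> [ 8  8  6 10], which sums to the batch size of 32
```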

### Random Batch Distributions

If you don't want to fix the number of each class in each batch, you can select `kind='random'`, which uses sampling with replacement. Samples are weighted so as to produce the target class distribution on average.
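
A minimal sketch of what that weighted sampling could look like (hypothetical, mirroring the described behaviour rather than the library's internals): each class's target probability is split evenly among its samples, so batches match the target distribution in expectation.

```python
import numpy as np

# Class label of each sample (4-class example above) and the target
# distribution from the alpha sketch.
labels = np.array([0]*5 + [1]*5 + [2]*2 + [3]*8)
target_dist = np.array([0.25, 0.25, 0.175, 0.325])

# Split each class's target probability evenly among its samples,
# so the per-sample weights sum to 1 by construction.
class_counts = np.bincount(labels)
weights = target_dist[labels] / class_counts[labels]

# Draw one batch with replacement; class frequencies match the
# target distribution on average.
batch = np.random.choice(len(labels), size=32, replace=True, p=weights)
```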

## Authors

`pytorch_balanced_sampler` was written by Karl Hornlund.
