
[Design discussion] Batches and resampling #97

Open · ablaom opened this issue Mar 2, 2021 · 8 comments
ablaom (Collaborator) commented Mar 2, 2021

There is a fundamental problem with the way we handle batches, at least as far as applications where the extra GPU speed gained from batching is important. The issue is the incompatibility of batching with observation resampling, as conventionally understood.

So, for example, if we are wrapping a model for cross-validation, the observations get split up multiple times into different test/train sets. At present, a "batch" is understood to consist of multiple "observations", which means that resampling an MLJFlux model breaks the batches, an expensive operation for large data.

I'm guessing this is a very familiar problem to people in deep learning, so I am copying some of them in for comment and will post a link on Slack. The solution I am considering for MLJ is to regard a "batch" of images as an unbreakable object that we consequently view as an observation, by definition. It would be natural to introduce a new parametric scientific type Batch{SomeAtomicScitype} to articulate a model's participation in this convention.
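
For concreteness, a minimal sketch of what such a convention could look like (illustrative only: Batch is not an existing scientific type, and the Scientific and Image stand-ins below are defined locally rather than imported from ScientificTypes.jl):

```julia
# Illustrative stand-ins for the scitype hierarchy, not the ScientificTypes.jl definitions.
abstract type Scientific end            # stand-in for the scitype root
abstract type Image <: Scientific end   # stand-in for an atomic image scitype

# The proposed parametric scitype: a batch of observations of scitype T,
# itself regarded as a single, unbreakable observation.
abstract type Batch{T<:Scientific} <: Scientific end

# Under this convention, a vector of image batches would have scitype
# AbstractVector{Batch{Image}}, and resampling (e.g. cross-validation)
# would select whole batches rather than splitting them.
```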

Thoughts anyone?

Some consequences of this breaking change would be:

  • batch_size disappears as a hyper-parameter of MLJFlux models, at least for ImageClassifier, but probably for all the models, for simplicity. Changing the batch size then becomes the responsibility of a pre-processing transformer external to the model (see the sketch after this list). I need to give some thought to transformers that reduce the number of observations, when inserted into MLJ pipelines (and learning networks, more generally). If that works, "smart" training of MLJ pipelines would mean no "re-batching" when retraining the composite model, unless the batch size changes, which is good.

  • with this change one could implement the reformat and selectrows (now the same as "select batches") methods that constitute buy-in to MLJ's new data front-end.
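
The external batching transformer mentioned in the first bullet might look something like this (a hypothetical sketch assuming MLJModelInterface's Static model type; Batcher and its implementation are made up for illustration):

```julia
using MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical pre-processing transformer that groups observations into batches.
mutable struct Batcher <: MMI.Static
    batch_size::Int
end

# For a `Static` model, `transform` receives no learned state; here it returns a
# vector whose elements are whole batches (each batch being a selection of rows).
function MMI.transform(b::Batcher, _, X)
    n = MMI.nrows(X)
    ranges = Iterators.partition(1:n, b.batch_size)
    return [MMI.selectrows(X, r) for r in ranges]
end
```

With batches treated as observations, the model-specific data front-end hooks mentioned in the second bullet (reformat(model, args...) and selectrows(model, I, data...)) could then simply pass batches through and index them, rather than re-collating raw observations on every resampling round.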

ablaom (Collaborator, Author) commented Mar 2, 2021

cc @lorenzoh @ToucheSir @ayush-1506

lorenzoh (Member) commented Mar 2, 2021

I'm not all too familiar with the way MLJFlux.jl currently handles this, but I'll try to give some comments and describe our current approach on the deep learning side.

In FastAI.jl, batching is mostly an implementation detail that is independent of the semantic transformations applied to observations. Cross-validation is rarely done since training time is usually fairly long already, but data containers are split into train/test sets and the training set is reshuffled every epoch. Both are done as lazy operations on the containers so they don't incur any performance problems. I'm not sure if that's what you mean, but I guess this does "break the batches", but many datasets (e.g. in computer vision) are larger than memory anyway and so each observation is loaded and batched every single epoch. Since reloading batches is unavoidable, doing it as fast as possible is handled by DataLoaders.jl with some of the important performance aspects explained here.
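
To illustrate the lazy workflow described above (a sketch only, not code from FastAI.jl; the exact argument forms of splitobs, shuffleobs and DataLoader may differ across versions of MLDataPattern.jl and DataLoaders.jl):

```julia
using MLDataPattern: splitobs, shuffleobs
using DataLoaders: DataLoader

data = rand(Float32, 28, 28, 1, 1_000)     # toy image container; last dim = observations
train, test = splitobs(data, at = 0.8)     # lazy views, no copying

for epoch in 1:10
    # re-shuffling each epoch is also lazy; batches are materialized on demand
    # (DataLoaders expects Julia to be started with multiple threads for
    # parallel, buffered loading)
    for batch in DataLoader(shuffleobs(train), 64)
        # train on `batch` ...
    end
end
```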

Regarding the interface used in MLJFlux.jl: I am curious whether it might make sense to separate an MLJFlux model into (a) the Flux model and (b) the training hyperparameters. This would allow factoring out common configuration parameters. I'm not sure it makes sense to make Batch a scientific type, as it does not have a clear semantic meaning and might better be treated as an implementation detail, especially if it is only relevant within MLJFlux.jl models. That said, batch size is a hyperparameter that affects regularization, so it should belong to the training-hyperparameter part of the model, imo.

darsnack (Member) commented Mar 2, 2021

Without speaking directly to MLJFlux, this is how eachbatch from MLDataPattern.jl works. If you consider the "data iterator" returned by itr = eachbatch(data, batch_size), then itr[1] (i.e. the first observation of this data iterator) is a whole batch. Contrast this with data[1] where data is the full un-batched dataset. In this case, data[1] is a single image. This matches your

> regard a "batch" of images as an unbreakable object that we consequently view as an observation, by definition

As a consequence, applying a resampling iterator on top of a batch iterator will resample by batches.
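
For illustration (a sketch using batchview, the indexable counterpart of eachbatch; toy data, and keyword names may vary across MLDataPattern.jl versions):

```julia
using MLDataPattern: batchview, kfolds

data = rand(Float32, 28, 28, 3, 128)   # 128 toy images; last dim = observations
bv   = batchview(data, size = 16)      # bv[1] is a whole batch of 16 images

# Resampling on top of the batch view selects whole batches:
for (train_batches, val_batches) in kfolds(bv, 4)
    # each fold's train/val sets are collections of batches; no batch is ever split
end
```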

In deep learning, batching is usually the last step before augmentation. You do resampling, shuffling, and splitting on the entire dataset, then batch each split, then augment each batch. k-fold-style splitting is not generally done, but if you were to do it, it would make more sense to me to do that on the dataset than on the batches.

ToucheSir (Member) commented:

I would also add that the type of data in a batch can be quite heterogeneous: think nested dicts, strings and whatnot. If MLJ(Flux) wants to handle that level of complexity, then it's worth making a distinction between batches pre- and post-collation. The former is an AoS-style collection of observations sharing the same structure, while the latter is an SoA-style structure that is only required because models expect to work with contiguous memory regions.
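
To make the distinction concrete, a toy illustration in plain Julia (no particular library assumed):

```julia
# Pre-collation: a collection of observations, each with the same (possibly
# nested) structure, e.g. (image, label) tuples.
samples = [(rand(Float32, 28, 28), rand(0:9)) for _ in 1:16]

# Post-collation: contiguous arrays that a model can consume directly.
function collate(samples)
    xs = cat([reshape(x, 28, 28, 1, 1) for (x, _) in samples]...; dims = 4)
    ys = [y for (_, y) in samples]
    return xs, ys
end

xbatch, ybatch = collate(samples)   # a 28×28×1×16 array and a 16-element label vector
```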

darsnack (Member) commented Mar 2, 2021

> If that works, "smart" training of MLJ pipelines would mean no "re-batching" when retraining the composite model, unless the batch size changes, which is good.

If I'm understanding this correctly, this means caching batches? That's something you can do (MLDataPattern.jl does), but I think as the dataset scales to DL sizes, re-batching is quite cheap (especially when the memory could be used for something else).

ablaom (Collaborator, Author) commented Mar 9, 2021

@darsnack @ToucheSir @lorenzoh

Thank you for your comments. These are very helpful. I will take a closer look at DataLoaders.jl, which would be nice to integrate if possible.

It appears I was mistaken in thinking that re-batching is a likely performance bottleneck, partly because resampling beyond a holdout set is not so common in DL. This is good to know.

I may misunderstand, but it seems there is not complete consensus here on where the responsibility for batching should lie. On the one hand, as @lorenzoh suggests, batching should be handled on the training side, not in data preparation:

> In FastAI.jl, batching is mostly an implementation detail that is independent of the semantic transformations applied to observations.
> ... batch size is a hyperparameter that affects regularization, so it should belong to the training-hyperparameter part of the model, imo.

On the other hand, there is the practice pointed out by @darsnack that augmentation is performed after batching (so each batch gets the same amount of augmentation?). This suggests batching is a pre-training data-preparation step, no?

> In deep learning, batching is usually the last step before augmentation. You do resampling, shuffling, and splitting on the entire dataset, then batch each split, then augment each batch.

In the current approach, MLJFlux treats batch_size as a "training" hyper-parameter, and augmenting data after batching would not be possible. Would this be a deal breaker for many DL applications?

@lorenzoh Regarding this comment:

> Regarding the interface used in MLJFlux.jl: I am curious whether it might make sense to separate an MLJFlux model into (a) the Flux model and (b) the training hyperparameters.

I think this is essentially already the case. An MLJ "model" is a struct with a bunch of "training parameters" as fields (regularisation parameters, batch size, optimiser choice, and so forth) plus a single field called builder that furnishes instructions on how to build the Flux "model" (new meaning) once the data has been inspected. The Flux model can be accessed using the generic fitted_params method for MLJ machines, and so will be accessible, for example, to callback functions in iteration control.
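
For concreteness, a rough sketch (toy data; the field returned by fitted_params is assumed here to be called chain):

```julia
using MLJ, MLJFlux, Flux

X = MLJ.table(rand(Float32, 100, 4))        # toy features
y = categorical(rand(["yes", "no"], 100))   # toy labels

# The MLJ "model": a struct of training hyper-parameters plus a `builder`
# holding the instructions for constructing the Flux model once data is seen.
clf = NeuralNetworkClassifier(
    builder    = MLJFlux.Short(n_hidden = 32),
    batch_size = 16,
    epochs     = 5,
    optimiser  = Flux.ADAM(1e-3),
)

mach = machine(clf, X, y)
fit!(mach, verbosity = 0)

# The Flux "model" (a Flux chain), recovered via the generic accessor:
fitted_params(mach).chain
```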

ToucheSir (Member) commented:

Augmentation can be done online or offline. Usually you would do offline augmentation either as a performance measure or as a way to increase the effective size of a dataset. The latter is possible to do in an online fashion as well, but it requires quite a bit more thought about how attributes like dataset length change as a result. In any case, MLJ(Flux) can probably just assume that the user will have done whatever offline augmentation is necessary already?

WRT per-sample vs per-batch, I think the vast majority of augmentations are applied per-sample and prior to collation. That is, even if the samples in each batch are determined ahead of time, the augmentation function will only see one sample at a time. This is why I noted the distinction between (1) a collection of samples and (2) a collated set of tensors, both of which unfortunately have been given the moniker "batch". Augmentation usually happens before (1) or between (1) and (2).

darsnack (Member) commented:

Yes, to be clear, I didn't mean that augmentation happens on the full batch. It is still a per-sample operation. There are multiple ways to apply an augmentation. It can be distinct from training, where an augmentation f is applied as f(getobs(data, i)) for all i (equivalent to "before (1)" above); this can be something like normalizing the data by centering it, or resizing images. Or you could apply f per-sample within getobs(eachbatch(data, batch_size), i) for all i, which is more like "between (1) and (2)"; the augmentation is still per-sample within a batch. An example of this might be random noise or an adversarial perturbation; in the case of the latter, postponing the augmentation is desirable so it can run on the GPU.
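
For illustration, a tiny plain-Julia sketch of the two placements (no data-iteration library assumed; f is just toy additive noise). As noted below, the two are semantically equivalent; only where the work happens differs:

```julia
f(x) = x .+ 0.01f0 .* randn(Float32, size(x))    # toy per-sample augmentation

samples = [rand(Float32, 28, 28) for _ in 1:256]
batch_indices = Iterators.partition(1:256, 32)   # which samples land in which batch

# (a) augment each sample as it is fetched, before it joins a batch:
batches_a = [[f(samples[i]) for i in idxs] for idxs in batch_indices]

# (b) augment within an already-formed batch; still a per-sample operation:
batches_b = [map(f, samples[idxs]) for idxs in batch_indices]
```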

Still, in both cases, batching is part of training (so I agree with @lorenzoh).

You'll notice that if the augmentations are per-sample, then it doesn't really matter semantically whether this is done when you fetch sample i prior to batching, or whether it is done on sample j within a batch. To me, it's a trade-off between flexibility and being opinionated. At the most flexible end of the spectrum, an augmentation is just applying a function to a collection of samples. If a batch behaves like any collection of samples, then you can give the user the choice of when to augment. On the other hand, you can make a distinction between data pre- and post-batching, and only allow augmentation to one but not the other. This has the advantage of being able to encapsulate the augmentation strategy as part of the "method" used to solve the problem. That's opinionated, but users might appreciate that.
