[RFC] Online and offline adaptive batching in torchtune. #2199
About: #2191
Enable adaptive batching in torchtune
Intuition
It is useful to be able to find the maximum batch size that does not cause OOM on the available hardware. It can also be interesting to increase the batch size gradually during training. Both techniques are examples of adaptive batching.
What are we currently doing?
We currently have no approach to adaptive batching in either the offline paradigm (set the batch size before training) or the online paradigm (update it during training).
How will it look for the user?
I think the best design at this point is to add a new parameter to the recipe. For the offline approach, we might just add a field like this:

```yaml
use_maximum_batch_size: True
```

It will be exposed in every config with the default value set to `False`.

In the case of the online approach, we want something more flexible. Since no standard non-empirical method for adaptive batching exists, we can put the way the batch size is increased into the users' hands, in the following way. Let's add a non-required parameter in the recipe:
```yaml
batch_increase_function: ...
```
Here a user provides a function that encodes both the condition for increasing the batch size and the amount by which it should be increased. We also run an offline procedure before training to find an upper bound on the batch size, to handle the case where the user's increase function would otherwise cause OOM. By definition, `batch_increase_function` accepts two arguments, `epoch` and `current_batch_size`, and returns the amount by which to increase the batch size. An example is sketched below, together with the check the recipe would run on each epoch.
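A minimal sketch of what this could look like. None of these names come from torchtune: `batch_increase_function` stands for the user-supplied policy, `maybe_increase_batch_size` for the per-epoch check in the recipe, and `max_batch_size` for the bound found by the offline procedure.

```python
def batch_increase_function(epoch: int, current_batch_size: int) -> int:
    """Example user policy: double the batch size every second epoch."""
    if epoch > 0 and epoch % 2 == 0:
        return current_batch_size  # increase by the current size, i.e. double it
    return 0  # no change this epoch


def maybe_increase_batch_size(epoch: int, current_batch_size: int, max_batch_size: int) -> int:
    """The check the recipe would run once per epoch."""
    increase = batch_increase_function(epoch, current_batch_size)
    # Clamp to the bound found by the offline procedure so that the user's
    # policy can never push the batch size into OOM territory.
    return min(current_batch_size + max(increase, 0), max_batch_size)


# In the recipe: if the returned size differs from the current one,
# rebuild the dataloader with the new batch size.
```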
On some ways of adapting batching
Online
Currently, only empirical methods have shown efficiency in real tasks. As for "clever", non-empirical approaches, no published work has demonstrated better performance than empirical batch-size increasing.
Basically, non-empirical approaches rely on properties of the optimizer (for example, a stochastic approximation of an upper bound on the Bregman distance), and adding them would require thoughtful consideration.
Offline
The main approach is based on the idea of "decrease until OOM". The general pipeline looks like this:
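A minimal sketch of this pipeline, assuming a hypothetical `try_step(batch_size)` helper that builds a batch of the requested size and runs one forward/backward pass (neither the helper nor the function names below exist in torchtune):

```python
import torch


def try_step(batch_size: int) -> None:
    """Hypothetical helper: run one forward/backward pass with `batch_size`,
    raising torch.cuda.OutOfMemoryError if the batch does not fit."""
    ...


def find_max_batch_size_linear(start: int) -> int:
    """Start from an optimistic guess and decrease until one step succeeds."""
    batch_size = start
    while batch_size >= 1:
        try:
            try_step(batch_size)
            return batch_size  # first size that does not OOM
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            batch_size -= 1  # could also decrease by a larger step
    raise RuntimeError("even batch size 1 does not fit in memory")
```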
This approach is fine but not really optimal; a better way is to binary-search on the answer:
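A sketch of that variant, reusing the same hypothetical `try_step` helper (and import) from the previous snippet; `hi` is an optimistic upper bound such as the dataset length:

```python
def find_max_batch_size_binary(hi: int) -> int:
    """Binary search for the largest batch size that does not OOM."""
    lo, best = 1, 0
    while lo <= hi:
        mid = (lo + hi) // 2
        try:
            try_step(mid)
            best = mid       # mid fits, try something larger
            lo = mid + 1
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            hi = mid - 1     # mid does not fit, try something smaller
    return best
```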
The most interesting approach is an OOM-less attempt to predict the maximum batch size. We can probably compute a rough estimate in the following way:
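One possible sketch, under the assumption that we can afford a single forward/backward pass with batch size 1 to measure per-sample memory (the function and its arguments are illustrative, not torchtune API):

```python
import torch


def estimate_max_batch_size(model, sample_batch, loss_fn) -> int:
    """Roughly estimate the largest batch size from one single-sample pass."""
    free_before, _total = torch.cuda.mem_get_info()
    torch.cuda.reset_peak_memory_stats()
    loss = loss_fn(model(sample_batch))  # forward/backward with batch size 1
    loss.backward()
    per_sample = torch.cuda.max_memory_allocated()
    # per_sample also counts weights and gradients, so the estimate is
    # conservative; keep some headroom for fragmentation and temporaries.
    return max(1, int(0.9 * free_before / per_sample))
```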
There are other ways to get this estimate as well, for example:
`(vram - model_size) / (forward_backward_size)`
For all of these methods, we round the result down to the closest lower power of two.
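For reference, a one-line sketch of that final rounding step (the helper name is made up):

```python
def round_down_to_power_of_two(n: int) -> int:
    """Round a batch-size estimate down to the closest lower power of two."""
    return 1 << (max(n, 1).bit_length() - 1)


round_down_to_power_of_two(37)  # -> 32
```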