torchelastic requires you to implement a `State` object and a `train_step`
function. For details on what these are, refer to the overview of how
torchelastic works. While going through the sections below, refer to the
imagenet example for more complete implementation details.
The `State` object has two categories of methods that need to be implemented:
synchronization and persistence.
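As a sketch, a `State` implementation has the following overall shape. Everything here is illustrative: it assumes the base class is exposed as `torchelastic.State`, the `sync(world_size, rank)` signature is an assumption, and the constructor arguments and fields are hypothetical. Each method is discussed in the sections that follow.

```python
import torch
import torchelastic


class MyState(torchelastic.State):
    """Sketch of a State implementation; fields are illustrative."""

    def __init__(self, dataset_path="/tmp/data"):
        # All workers construct the state with identical arguments.
        self.dataset_path = dataset_path
        self.start_index = 0  # last known position in the data
        self.model = torch.nn.Linear(10, 10)  # stand-in model

    # -- synchronization --
    def sync(self, world_size, rank):  # signature is an assumption
        ...

    # -- persistence: rollbacks --
    def capture_snapshot(self):
        ...

    def apply_snapshot(self, snapshot):
        ...

    # -- persistence: checkpoints --
    def save(self, stream):
        ...

    def load(self, stream):
        ...
```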
Let's take a look at synchronization first. The `sync` method is responsible for
ensuring that all workers get a consistent view of the `state`. It is called at
startup as well as on each event that potentially leaves the workers out of sync,
for instance, on membership changes and rollback events. torchelastic relies on
the `sync()` method for `state` recovery from surviving workers: when there are
membership changes, either due to worker failure or elasticity, the new workers
receive the most up-to-date `state` from one of the surviving workers, usually
the one that has the most recent `state`. We call this worker the most tenured
worker.
Things you should consider doing in `sync` are (see the sketch after this list):

- Broadcasting global parameters/data from a particular worker (e.g. rank 0).
- (re)Initializing data loaders based on markers (e.g. the last known start index).
- (re)Initializing the model.
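A `sync` along those lines might look like the sketch below. The `start_index` field and `_make_data_loader` helper are hypothetical (carried over from the earlier sketch), and broadcasting via `torch.distributed` assumes a process group has already been initialized.

```python
import torch
import torch.distributed as dist
import torchelastic


class MyState(torchelastic.State):
    ...  # __init__ and other methods as sketched above

    def sync(self, world_size, rank):
        # Broadcast the last known start index from rank 0 so every
        # worker resumes from the same position in the data.
        marker = torch.tensor([self.start_index], dtype=torch.long)
        dist.broadcast(marker, src=0)
        self.start_index = int(marker.item())

        # (re)Initialize the data loader from the synchronized marker;
        # _make_data_loader is a hypothetical helper.
        self.data_loader = self._make_data_loader(
            self.dataset_path, world_size, rank, self.start_index
        )

        # (re)Initialize the model: broadcast parameters from rank 0
        # so all workers hold identical weights.
        for param in self.model.parameters():
            dist.broadcast(param.data, src=0)
```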
IMPORTANT: `state.sync()` is not meant for synchronizing steps in training. For instance, you should not be synchronizing weights here (e.g. all-reducing model weights for synchronous SGD). These types of collective operations belong in the `train_step`.
All workers initially create the `state` object with the same constructor arguments.
We refer to this initial state as `S_0` and assume that any worker is able to create
`S_0` without needing any assistance from torchelastic. Essentially, `S_0` is the bootstrap
state. This concept will become important in the next sections when talking about
state persistence (rollbacks and checkpoints).
NOTE: You do not have to implement the following methods if you do not want rollbacks on failed `train_step` calls.
torchelastic has the ability to roll back the `state` if a `train_step` fails to
execute successfully, which may leave the `state` object partially
updated. It relies on properly implemented `capture_snapshot()`
and `apply_snapshot()` methods on the `state` to ensure that the `state`
is restored to what it was before the faulty `train_step`.
The `capture_snapshot()` method, as the name implies, takes a snapshot of the `state`
and returns the information necessary to restore the `state` object. You may return
any object from `capture_snapshot()` so long as you can use it in the
`apply_snapshot(snapshot)` method. A possible implementation of a rollback is:
```python
snapshot = state.capture_snapshot()
try:
    train_step(state)
except RuntimeError:
    state.apply_snapshot(snapshot)
    state.sync()
```
NOTE: Since certain fields of the `state` may need to be re-initialized, torchelastic calls the `sync()` method after applying the snapshot. For instance, data loaders may need to be restarted since their iterators may end up in a corrupted state when the `train_step` does not exit successfully.
Notice that the apply method is called on the existing `state` object. This implies
that an efficient implementation of `capture_snapshot` should only return mutable,
stateful data. Immutable fields, or fields that can be derived from other member
variables or restored in the `sync` method, need not be included in the snapshot.
By default the `capture_snapshot()` method returns `None` and the `apply_snapshot()`
method is a `pass`, which essentially means "rollback not supported".
IMPORTANT: The `apply_snapshot` method should make no assumptions about which `state` object it is called on (e.g. the values of its member variables). That is, applying a `snapshot` to any state followed by `state.sync()` should effectively restore the state object to the point at which the corresponding `capture_snapshot` was called. A good rule of thumb is that `apply_snapshot` should act more like a `set` method than an `update` method.
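As an illustration, a snapshot pair along these lines might look like the sketch below; the fields are hypothetical, carried over from the earlier sketches. `apply_snapshot` overwrites every snapshotted field outright (set semantics) and leaves derived fields, such as the data loader, to be rebuilt by `sync`.

```python
import copy

import torchelastic


class MyState(torchelastic.State):
    ...  # other methods as sketched above

    def capture_snapshot(self):
        # Return only mutable, stateful data; derived fields such as
        # the data loader are rebuilt by sync() instead.
        return {
            "model": copy.deepcopy(self.model.state_dict()),
            "start_index": self.start_index,
        }

    def apply_snapshot(self, snapshot):
        # "set" semantics: overwrite the fields unconditionally,
        # assuming nothing about this object's current values.
        self.model.load_state_dict(snapshot["model"])
        self.start_index = snapshot["start_index"]
```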
NOTE: You do not have to implement the following methods if you do not plan on using checkpointing.
Much like `capture_snapshot` and `apply_snapshot`, the `save` and `load` methods
form a pair. They are responsible for persisting and restoring the `state` object
to and from a stream, which is a file-like object compatible with `torch.save`.
torchelastic relies on these methods to provide checkpoint functionality for your job.
We encourage users to use the `torch.save` and `torch.load` methods when implementing
the `save` and `load` methods of their `state` class.
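For example, a `save`/`load` pair using `torch.save` and `torch.load` might look like the sketch below; the fields are the hypothetical ones from the earlier sketches.

```python
import torch
import torchelastic


class MyState(torchelastic.State):
    ...  # other methods as sketched above

    def save(self, stream):
        # Persist the mutable training state with torch.save.
        torch.save(
            {
                "model": self.model.state_dict(),
                "start_index": self.start_index,
            },
            stream,
        )

    def load(self, stream):
        # Restore from the stream; derived fields (e.g. data loaders)
        # are rebuilt by sync().
        data = torch.load(stream)
        self.model.load_state_dict(data["model"])
        self.start_index = data["start_index"]
```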
NOTE: The default implementations of `save` and `load` use `capture_snapshot` and `apply_snapshot`.
The `train_step` is a function that takes the `state` as its single argument
and carries out a partition of the overall training job.
This is your unit of work, and it is up to you to define what a unit is.
When deciding what your unit of work should be, keep in mind the following:

- Rollbacks and checkpoints are done at `train_step` granularity. This means
  that torchelastic can only recover to the last successful `train_step`; any
  failures during the `train_step` itself are not recoverable.
- A `train_step` iteration in the `train_loop` has overhead due to the work
  that goes into ensuring that your job is fault-tolerant and elastic. How much
  overhead depends on your configuration for rollbacks and checkpoints, as well
  as how expensive your `capture_snapshot`, `apply_snapshot`, `save`, and
  `load` functions are.
In most cases, your job naturally lends itself to an obvious `train_step`. The most
canonical one for many training jobs is to map the processing of a mini-batch of
training data to a `train_step` (a sketch follows below). There is a trade-off to
be made between how much work you are willing to lose versus how much overhead you
want to pay for that security.
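A mini-batch `train_step` in that style might look like the following sketch. The `state` fields (`data_loader_iter`, `loss_fn`, `optimizer`) are hypothetical, carried over from the earlier sketches, and the gradient all-reduce stands in for whatever synchronous-SGD collective your job uses; recall that such collectives belong here, not in `sync`.

```python
import torch.distributed as dist


def train_step(state):
    # Process one mini-batch per step: this is the unit of work at
    # which torchelastic performs rollbacks and checkpoints.
    inputs, targets = next(state.data_loader_iter)

    state.optimizer.zero_grad()
    outputs = state.model(inputs)
    loss = state.loss_fn(outputs, targets)
    loss.backward()

    # Synchronous SGD: average gradients across workers. Collective
    # operations like this belong in train_step, not in sync().
    world_size = dist.get_world_size()
    for param in state.model.parameters():
        dist.all_reduce(param.grad)
        param.grad /= world_size

    state.optimizer.step()

    # Advance the marker that sync() uses to (re)initialize loaders.
    state.start_index += 1
```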
Now that you have `state` and `train_step` implementations, all that remains
is to bring everything together and implement a `main` that will execute your
training. Your script should initialize torchelastic's `coordinator`, create
your `state` object, and call the `train_loop`. Below is a simple example:
```python
import torchelastic
from torchelastic.p2p import CoordinatorP2P

if __name__ == "__main__":
    min_workers = 1
    max_workers = 1
    run_id = 1234
    etcd_endpoint = "localhost:2379"

    state = MyState()
    coordinator = CoordinatorP2P(
        c10d_backend="gloo",
        init_method=f"etcd://{etcd_endpoint}/{run_id}?min_workers={min_workers}&max_workers={max_workers}",
        max_num_trainers=max_workers,
        process_group_timeout=60000,
    )
    torchelastic.train(coordinator, train_step, state)
```
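Here `MyState` and `train_step` are your implementations from the sections above. Assuming an etcd server is reachable at `etcd_endpoint`, every worker runs this same script: workers rendezvous through etcd under `run_id`, and the `min_workers`/`max_workers` settings in the `init_method` bound how many participants the job admits.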
See also the metrics, checkpoint, and rendezvous documentation.