[replica-parallel] Add replica-parallel saving #1320
Conversation
Apologies, I did not get a chance to fully review this today; will take a closer look tomorrow. By garbage collection you are referring to the overhead from creating a new jax.Array for every iteration? It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.
Correct
Agreed, will try to get to this later this week.
```python
except StopIteration:
  return None
local_shape = tuple(
    axis_size // (replica_count if axis_index == axis else 1)
```
I suspect the references to axis_size and axis_index will raise a type checker error since they may not be initialized.
How so? They are both in scope as part of the generator expression. I.e., this statement is equivalent to:

```python
gen = (
    axis_size // (replica_count if axis_index == axis else 1)
    for axis_index, axis_size in enumerate(shard0.data.shape)
)
local_shape = tuple(gen)
```
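As a sanity check, a standalone version runs fine (the shape here is a hypothetical stand-in for `shard0.data.shape`):

```python
# Standalone sanity check (hypothetical stand-in shape): axis_index and
# axis_size are bound by the generator expression itself, not the outer scope.
shape = (4, 8, 2)
replica_count = 2
axis = 1

local_shape = tuple(
    axis_size // (replica_count if axis_index == axis else 1)
    for axis_index, axis_size in enumerate(shape)
)
print(local_shape)  # (4, 4, 2)
```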
```python
axis = next(
    axis_index
    for axis_index, axis_size in enumerate(shard0.data.shape)
    if replica_count > 1 and axis_size % replica_count == 0
```
It seems fairly easy to construct examples for which this approach does not work. E.g., say there are 8 devices and a global array of shape (11, 4, 7) whose middle axis is sharded 4 ways, giving 2 replicas per shard of shape (11, 1, 7). replica_count does not evenly divide any dimension, so we would have to fall back on the old approach (see the snippet below).

I guess we are just assuming that realistic use cases will not use weird dimensions like this.
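To make it concrete, here is that example run through the same selection logic (a hypothetical standalone snippet):

```python
# Hypothetical standalone version of the axis selection above, applied to the
# (11, 4, 7) example: each shard has shape (11, 1, 7) and 2 replicas, and no
# shard axis is evenly divisible by replica_count.
shard_shape = (11, 1, 7)
replica_count = 2

try:
  axis = next(
      axis_index
      for axis_index, axis_size in enumerate(shard_shape)
      if replica_count > 1 and axis_size % replica_count == 0
  )
  print(f'replica-parallel along axis {axis}')
except StopIteration:
  # No evenly-divisible axis: fall back to single-replica saving.
  print('falling back to single-replica saving')
```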
Indeed, that's a limitation of the current approach.

I haven't looked into it too deeply, but in principle it seems like we could just allow the last replica of each shard to write a slice that's not exactly of local_shape but shorter along the chosen axis. We'd then choose the slicing axis in a way that prefers an evenly-divisible axis, but may fall back to a non-evenly-divisible one.

We'd need to check how much that degrades tensorstore write/read performance. (Perhaps you've experimented with that before in the context of fragments?)
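Roughly something like this (an illustrative sketch only, not part of this PR):

```python
import math

# Illustrative sketch (not implemented in this PR): split a shard along `axis`
# into replica_count slices, letting the last replica take a shorter remainder.
def uneven_slice(shard_shape, axis, replica_count, replica_id):
  step = math.ceil(shard_shape[axis] / replica_count)  # size of the full slices
  start = replica_id * step
  stop = min(start + step, shard_shape[axis])
  return slice(start, stop)

# (11, 1, 7) split 2 ways along axis 0: slices [0:6] and [6:11].
print([uneven_slice((11, 1, 7), 0, 2, r) for r in range(2)])
```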
We've never experimented with uneven fragments when writing, so I also don't know what the performance impacts would be, if any. I agree this can be a follow-up. Maybe worth leaving a small TODO with a note about this aspect.
Force-pushed from a1d3cb2 to 893532d.
Ran this in a single-host, but multi-process setting today: gist

Timings end up very similar to the single-process example above, i.e. ~100ms/save with baseline Orbax (single-replica) and ~50ms for replica-parallel.

Note that the adapted microbenchmark only creates a single CheckpointManager and doesn't operate on a new jax.Array for every iteration. That seems to have reduced the GC overhead significantly (only ~30ms now). Besides that, I'm noticing that we're spending a fair amount of time in the various barrier syncs.

(I still intend to run a real benchmark on a cluster, but this might have to wait til next week.)
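For reference, the adapted benchmark loop is structured roughly like this (a paraphrase of the gist, not verbatim; the path and shape are made up):

```python
# Rough paraphrase of the microbenchmark loop (see gist for the real thing):
# one CheckpointManager and one jax.Array reused across iterations, which is
# what cut down the per-iteration GC overhead.
import time
import jax.numpy as jnp
import orbax.checkpoint as ocp

mngr = ocp.CheckpointManager('/tmp/replica_parallel_bench')  # made-up path
array = jnp.ones((1024, 1024, 512))  # stand-in for the sharded test array

for step in range(10):
  start = time.perf_counter()
  mngr.save(step, args=ocp.args.StandardSave({'x': array}))
  mngr.wait_until_finished()  # block so we time the full save
  print(f'step {step}: {time.perf_counter() - start:.3f}s')
```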
Please rebase onto head and I will take a final look at this CL before merging internally.

On the barrier syncs: we can take a look internally at this, since it's not the first time the possibility has been raised. Intuitively though there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices (one major advantage of using a single-controller model). There are a number of things we could move to background threads to minimize the overall number of barriers though, like directory creations - that is something we're starting to work on.
The top (i.e. highlighted) process in the screenshot above should be the primary process.
Agreed, and indeed it seems like the lowest hanging fruit might be to elide some of these barriers -- based on the above profile there are five per save. So it sounds great that you're already looking into that!
Will do (likely on Monday)! Thanks again for helping push this through :-)
Force-pushed from 893532d to 08dffea.
Rebased on main, PTAL @cpgaffney1!
LGTM - if I don't finish merging today, will finish tomorrow.
(Follow-up to #1319)
Adds "replica-parallel" saving in which each replica of a shard saves an equally-sized slice. In effect, this drives down the time spent on saving checkpoints as data-parallelism increases.
Motivation: Depending on their sharding and replication, JAX arrays may consist of multiple shards. In case of replication each shard carries a distinct `replica_id`, distinguishing the copies of the same logical shard from one another. Orbax's current behavior is to save the same `replica_id` copy for all shards of all arrays ("single-replica" saving). In the presence of replication this is suboptimal, since the work could be parallelized across all replicas.

This PR adapts `ArrayHandler` to operate in "replica-parallel" mode (`use_replica_parallel=True`). In this mode we determine the first axis for which an array's shards are evenly divisible by the number of replicas `R`. If such an axis exists, we then assign each replica to save one R-th of the overall shard. Otherwise we fall back to "single-replica" saving.

While this patch is mostly relevant to large-scale training across many hosts, the following self-contained example illustrates the difference between existing "single-replica" saving and "replica-parallel" saving:
https://gist.github.com/gspschmid/41a78da35c0b14fafaff7bed3d52c5bc
For simplicity, I ran this example in a single process controlling 8 GPUs to save a fully-replicated array of ~4.3GB. Attached are the output and profile for the last iteration.
Single-replica (Orbax v0.9.0):
Replica-parallel (this PR):
A few observations:

- Replica-parallel saving splits the one large `ts[index].write(...)` per array into 8 smaller ones.
- The microbenchmark uses tensorstore's `memory://` backend, so it remains to be seen whether this improvement holds up in realistic use cases.
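To make the slicing scheme concrete, here is a hypothetical sketch of how replica-parallel picks an axis and assigns each replica its slice (the helper name and shapes are made up; the actual logic lives in replica_slices.py):

```python
# Hypothetical sketch of replica-parallel slice assignment (names made up):
# pick the first axis evenly divisible by replica_count and give each replica
# one R-th of the shard along that axis; None means single-replica fallback.
def replica_parallel_index(shard_shape, replica_count, replica_id):
  axis = next(
      (i for i, n in enumerate(shard_shape)
       if replica_count > 1 and n % replica_count == 0),
      None,
  )
  if axis is None:
    return None  # fall back to single-replica saving
  size = shard_shape[axis] // replica_count
  index = [slice(None)] * len(shard_shape)
  index[axis] = slice(replica_id * size, (replica_id + 1) * size)
  return tuple(index)

# Fully-replicated shard of shape (8, 512, 512) with 8 replicas: each replica
# writes a distinct (1, 512, 512) slice.
for r in range(8):
  print(r, replica_parallel_index((8, 512, 512), 8, r))
```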