Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[replica-parallel] Add replica-parallel saving #1320

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gspschmid
Copy link
Contributor

@gspschmid gspschmid commented Nov 11, 2024

(Follow-up to #1319)

Adds "replica-parallel" saving in which each replica of a shard saves an equally-sized slice. In effect, this drives down the time spent on saving checkpoints as data-parallelism increases.

Motivation: Depending on their sharding and replication, JAX arrays may consist of multiple shards. In case of replication each shard carries a distinct replica_id, distinguishing the copies of the same logical shard from one another. Orbax's current behavior is to save the same replica_id-copy for all shards of all arrays ("single-replica" saving). In the presence of replication this is suboptimal, since the work could be parallelized across all replicas.

This PR adapts ArrayHandler to operate in "replica-parallel" mode (use_replica_parallel=True). In this mode we determine the first axis for which an array's shards are evenly-divisible by the number of replicas R. If such an axis exists, we will then assign each replica to save one R-th of the overall shard. Otherwise we fall back to "single-replica" saving.


While this patch is mostly relevant to large-scale training across many hosts, the following self-contained example illustrates the difference between existing "single-replica" saving and "replica-parallel" saving:
https://gist.github.com/gspschmid/41a78da35c0b14fafaff7bed3d52c5bc

For simplicity, I ran this example in a single process controlling 8 GPUs to save a fully-replicated array of ~4.3GBs. Attached are the output and profile for the last iteration.

Single-replica (Orbax v0.9.0):

=== iteration 5/5 ===
[gc] 0.91s
[save (offload)] 0.09s
[save (persist)] 5.42s
image

Replica-parallel (this PR):

=== iteration 5/5 ===
[gc] 1.01s
[save (offload)] 0.05s
[save (persist)] 2.27s
image

A few observations:

  • As expected, replica-parallel distributes the D2H transfers across all 8 devices. Since I am running only a single process the persistence step is still serialized; with multiple processes this would be parallelized as well.
  • Since I am running only a single process on a single host the absolute speed-up of the D2H (transfer) step is small, around 50ms/save. Note that the D2H already takes up relatively little time (~100ms), and we "merely" reduce it by 2x due to contention for bus bandwidth. Note that in a multi-host setting we should experience close to perfect scaling, as each host offloads to local host memory.
  • Surprisingly, overall save time (including the commit step of tensorstore) is reduced by 2x. I am not sure why, but suspect that this is related to us effectively splitting each existing ts[index].write(...) into 8 smaller ones. The microbenchmark uses tensorstore's memory:// backend, so it remains to be seen whether this improvement holds up in realistic use cases.
  • Garbage collection has an outsize impact on this microbenchmark (regardless of whether we use replica-parallel or not). I had not noticed this effect in earlier rounds of benchmarking Orbax and tensorstore, so I am wondering whether this might be a recent regression.

@gspschmid
Copy link
Contributor Author

@cpgaffney1

@cpgaffney1
Copy link
Collaborator

Apologies I did not get a chance to fully review this today, will take a closer look tomorrow.

By garbage collection you are referring to gc.collect(), right? (not the Orbax CheckpointManager garbage collection of old steps, that should not be happening here)

It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.

@gspschmid
Copy link
Contributor Author

By garbage collection you are referring to gc.collect(), right? (not the Orbax CheckpointManager garbage collection of old steps, that should not be happening here)

Correct

It would be very helpful to see a small example like this run on multiple processes, since that's where the true value of the feature would lie.

Agreed, will try to get to this later this week.

except StopIteration:
return None
local_shape = tuple(
axis_size // (replica_count if axis_index == axis else 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect reference to axis_size and axis_index will raise type checker error since they may not be initialized.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How so? They are both in scope as part of the generator expression? I.e., this statement is equivalent to

gen = (
    axis_size // (replica_count if axis_index == axis else 1)
    for axis_index, axis_size in enumerate(shard0.data.shape)
)
local_shape = tuple(gen)

axis = next(
axis_index
for axis_index, axis_size in enumerate(shard0.data.shape)
if replica_count > 1 and axis_size % replica_count == 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems fairly easy to construct examples for which this approach does not work. e.g. let's say there are 8 devices, global array of (11, 4, 7), where the middle axis is sharded 4 ways, giving 2 replicas per shard, shape (11, 1, 7). replica_count does not evenly divide any dimension so we would have to fall back on the old approach.

I guess we are just assuming that realistic use cases will not use weird dimensions like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, that's a limitation of the current approach.

I haven't looked into it too deeply, but in principle it seems like we just allow the last replica of each shard to write a slice that's not exactly of local_shape but shorter along the chosen axis. We'd then choose the slicing axis in a way that prefers an evenly-divisible axis, but may fall back to a non-evenly-divisible one.

We'd need to check how much that degrades tensorstore write/read performance. (Perhaps you've experimented with that before in the context of fragments?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've never experimented with uneven fragments when writing, so I also don't know about what the performance impacts would be, if any. I agree this can be a follow-up. Maybe worth leaving a small TODO with a note about this aspect.

@gspschmid
Copy link
Contributor Author

Rebased on the newer version of #1319 and addressed comments -- will rebase once more once #1319 is merged and ping you.

@gspschmid
Copy link
Contributor Author

Ran this in a single-host, but multi-process setting today: gist

Timings end up very similar to the single-process example above, i.e. ~100ms/save with baseline Orbax (single-replica) and ~50ms for replica-parallel.

Note that the adapted microbenchmark only creates a single CheckpointManager and doesn't operate on a new jax.Array for every iteration. That seems to have reduced the GC overhead significantly (only ~30ms now).

Besides that, I'm noticing that we're spending a fair amount of time in the various sync_global_devicescalls -- with replica-parallel these end up taking ~50% of the exposed save time (~30ms):
image

(I still intend to run a real benchmark on a cluster, but this might have to wait til next week.)

@cpgaffney1
Copy link
Collaborator

Please rebase onto head and I will take a final look at this CL before merging internally.

On the sync_global_devices question, sometimes it looks like the barrier is taking a long time even though it is really just a non-primary process waiting while the leader does real work. I'm not familiar enough with your profiler to know whether that is the case.

We can take a look internally at this, since it's not the first time the possibility has been raised. Intuitively though there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices (one major advantage of using single-controller model). There's a number of things we could move to background threads to minimize the overall number of barriers though, like directory creations - that is something we're starting to work on.

@gspschmid
Copy link
Contributor Author

On the sync_global_devices question, sometimes it looks like the barrier is taking a long time even though it is really just a non-primary process waiting while the leader does real work. I'm not familiar enough with your profiler to know whether that is the case.

The top (i.e. highlighted) process in the screenshot above should be the primary (Xla:#global=0).

Intuitively though there is some floor on the amount of time barrier syncs can take, which should scale with the number of devices

There's a number of things we could move to background threads to minimize the overall number of barriers though, like directory creations - that is something we're starting to work on.

Agreed, and indeed it seems like the lowest hanging fruit might be to elide some of these barriers -- based on the above profile there are five per save. So it sounds great that you're already looking into that!

Please rebase onto head and I will take a final look at this CL before merging internally.

Will do (likely on Monday)! Thanks again for helping push this through :-)

@gspschmid
Copy link
Contributor Author

Rebased on main, PTAL, @cpgaffney1 !

@cpgaffney1
Copy link
Collaborator

LGTM - if I don't finish merging today, will finish tomorrow.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants