Skip to content

Commit

Permalink
[Pallas] Fix shard_axis in dma_start interpret mode rule.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703192497
  • Loading branch information
justinjfu authored and hawkinsp committed Dec 5, 2024
1 parent 7e6620a commit 259194a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals,
if device_id_len > 1 or len(nonempty_axes) > 1:
raise NotImplementedError("Meshes with more than 1 named dimension not "
"implemented in dma_start_p")
shard_axis = nonempty_axes[0].name
shard_axis = nonempty_axes[0]
my_axis = jax.lax.axis_index(shard_axis)
else:
raise ValueError(f"Unknown device_id_type: {device_id_type}")
Expand Down

0 comments on commit 259194a

Please sign in to comment.