Skip to content

Commit

Permalink
[mosaic] Directly build IR in _device_id_to_logical, rather than usin…
Browse files Browse the repository at this point in the history
…g lower_fun.

This is just as simple and faster.

PiperOrigin-RevId: 690196495
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 26, 2024
1 parent 2b01aff commit 6f3c012
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2801,16 +2801,15 @@ def _device_id_to_logical(
# Mesh means we are passed the mesh coordinates for the device
device_ids = tree_util.tree_leaves(device_id)
mesh_strides = ctx.lowering_context.mesh_context.mesh_strides
def _linearize_mesh_indices(*indices):
return sum(a * b for a, b in zip(indices, mesh_strides))
lower_ctx = LoweringRuleContext(
lowering_context=ctx.lowering_context,
avals_in=[pallas_core.index_map_grid_aval] * len(device_ids),
avals_out=[pallas_core.index_map_grid_aval],
block_shapes=(None,) * len(device_ids),

i32 = ir.IntegerType.get_signless(32)
return functools.reduce(
arith.addi,
(
arith.muli(a, arith.constant(i32, b))
for a, b in zip(device_ids, mesh_strides)
),
)
return lower_fun(_linearize_mesh_indices, multiple_results=False)(
lower_ctx, *device_ids)
elif device_id_type is tpu_primitives.DeviceIdType.LOGICAL:
return device_id
raise NotImplementedError(f"Unsupported device id type: {device_id_type}")
Expand Down

0 comments on commit 6f3c012

Please sign in to comment.