Skip to content

Commit

Permalink
Made the calling of teh compiled SDFG more uniform.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Oct 7, 2024
1 parent 3be9f36 commit 82c82ff
Showing 1 changed file with 68 additions and 13 deletions.
81 changes: 68 additions & 13 deletions src/jace/translated_jaxpr_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,18 @@ class CompiledJaxprSDFG:
def sdfg(self) -> dace.SDFG: # noqa: D102 [undocumented-public-method]
return self.compiled_sdfg.sdfg

def __call__(
def _construct_csdfg_args(
self,
flat_call_args: Sequence[Any],
) -> list[jax.Array]:
) -> dict[str, Any]:
"""
Run the compiled SDFG using the flattened input.
Create the calling arguments from `flat_call_args`.
The function will not perform flattening of its input nor unflattening of
the output.
The function will collect the already flattened arguments into a `dict`.
Furthermore, it will allocate the buffers that are used for the return values
and add them to the `dict` as well.
The `dict` can then be passed to `self._call_csdfg()` to invoke the compiled
SDFG.
Args:
flat_call_args: Flattened input arguments.
Expand All @@ -154,42 +157,94 @@ def __call__(
f"Expected {len(self.input_names)} flattened arguments, but got {len(flat_call_args)}."
)

sdfg_call_args: dict[str, Any] = {}
csdfg_call_args: dict[str, Any] = {}
for in_name, in_val in zip(self.input_names, flat_call_args):
# TODO(phimuell): Implement a stride matching process.
if util.is_jax_array(in_val):
if not util.is_fully_addressable(in_val):
raise ValueError(f"Passed a not fully addressable JAX array as '{in_name}'")
in_val = in_val.__array__() # noqa: PLW2901 [redefined-loop-name] # JAX arrays do not expose the __array_interface__.
sdfg_call_args[in_name] = in_val
csdfg_call_args[in_name] = in_val

# Allocate the output arrays.
# In DaCe the output arrays are created by the `CompiledSDFG` calls and all
# calls share the same arrays. In JaCe the output arrays are distinct.
arrays = self.sdfg.arrays
for output_name in self.output_names:
sdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name])
csdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name])

assert len(sdfg_call_args) == len(self.compiled_sdfg.argnames), (
assert len(csdfg_call_args) == len(self.compiled_sdfg.argnames), (
"Failed to construct the call arguments,"
f" expected {len(self.compiled_sdfg.argnames)} but got {len(flat_call_args)}."
f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(sdfg_call_args.keys())}"
f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(csdfg_call_args.keys())}"
)
return csdfg_call_args

def _call_csdfg(
self,
csdfg_call_args: dict[str, Any],
) -> None:
"""
Calls the underlying SDFG with the data in `csdfg_call_args`.
This will forward the arguments directly to the compiled SDFG object.
See `self._construct_csdfg_args()` for how to construct the `dict`
and `self._extract_return_values()` for how to get the output values back.
Args:
csdfg_call_args: The required arguments to call the compiled sdfg.
"""
assert len(csdfg_call_args) == len(self.compiled_sdfg.argnames)

# Calling the SDFG
with dace.config.temporary_config():
dace.Config.set("compiler", "allow_view_arguments", value=True)
self.compiled_sdfg(**sdfg_call_args)
self.compiled_sdfg(**csdfg_call_args)

def _extract_return_values(
self,
csdfg_call_args: dict[str, Any],
) -> list[jax.Array]:
"""
Extract the return values and return the flattened version.
JaCe allocates the buffer for the return value outside the SDFG and passes
them as arguments, see `self._construct_csdfg_args()` and `self._call_csdfg()`.
This function will extract these values and return them in the flattened order.
Furthermore, the buffer will be transferred to a `jax.Array` object.
Args:
csdfg_call_args: Collection of the arguments passed to the compiled SDFG.
Note:
After this function returns accessing any element in `csdfg_call_args`
is undefined behaviour.
"""
# DaCe writes the results either into CuPy or NumPy arrays. For compatibility
# with JAX we will now turn them into `jax.Array`s. Note that this is safe
# because we created these arrays in this function explicitly. Thus when
# this function ends, there is no writable reference to these arrays left.
return [
util.move_into_jax_array(sdfg_call_args[output_name])
util.move_into_jax_array(csdfg_call_args[output_name])
for output_name in self.output_names
]

def __call__(
self,
flat_call_args: Sequence[Any],
) -> list[jax.Array]:
"""
Run the compiled SDFG using the flattened input.
The function will not perform flattening of its input nor unflattening of
the output.
Args:
flat_call_args: Flattened input arguments.
"""
csdfg_call_args = self._construct_csdfg_args(flat_call_args)
self._call_csdfg(csdfg_call_args)
return self._extract_return_values(csdfg_call_args)


def compile_jaxpr_sdfg(tsdfg: TranslatedJaxprSDFG) -> dace_csdfg.CompiledJaxprSDFG:
"""Compile `tsdfg` and return a `CompiledJaxprSDFG` object with the result."""
Expand Down

0 comments on commit 82c82ff

Please sign in to comment.