From 82c82fffdbed296a2ec808857897f23f30a9b31b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 7 Oct 2024 16:07:36 +0200 Subject: [PATCH] Made the calling of teh compiled SDFG more uniform. --- src/jace/translated_jaxpr_sdfg.py | 81 ++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 13 deletions(-) diff --git a/src/jace/translated_jaxpr_sdfg.py b/src/jace/translated_jaxpr_sdfg.py index 2347fdb..40bbd78 100644 --- a/src/jace/translated_jaxpr_sdfg.py +++ b/src/jace/translated_jaxpr_sdfg.py @@ -136,15 +136,18 @@ class CompiledJaxprSDFG: def sdfg(self) -> dace.SDFG: # noqa: D102 [undocumented-public-method] return self.compiled_sdfg.sdfg - def __call__( + def _construct_csdfg_args( self, flat_call_args: Sequence[Any], - ) -> list[jax.Array]: + ) -> dict[str, Any]: """ - Run the compiled SDFG using the flattened input. + Create the calling arguments from `flat_call_args`. - The function will not perform flattening of its input nor unflattening of - the output. + The function will collect the already flattened arguments into a `dict`. + Furthermore, it will allocate the buffers that are used for the return values + and add them to the `dict` as well. + The `dict` can then be passed to `self._call_csdfg()` to invoke the compiled + SDFG. Args: flat_call_args: Flattened input arguments. @@ -154,42 +157,94 @@ def __call__( f"Expected {len(self.input_names)} flattened arguments, but got {len(flat_call_args)}." ) - sdfg_call_args: dict[str, Any] = {} + csdfg_call_args: dict[str, Any] = {} for in_name, in_val in zip(self.input_names, flat_call_args): # TODO(phimuell): Implement a stride matching process. if util.is_jax_array(in_val): if not util.is_fully_addressable(in_val): raise ValueError(f"Passed a not fully addressable JAX array as '{in_name}'") in_val = in_val.__array__() # noqa: PLW2901 [redefined-loop-name] # JAX arrays do not expose the __array_interface__. - sdfg_call_args[in_name] = in_val + csdfg_call_args[in_name] = in_val # Allocate the output arrays. # In DaCe the output arrays are created by the `CompiledSDFG` calls and all # calls share the same arrays. In JaCe the output arrays are distinct. arrays = self.sdfg.arrays for output_name in self.output_names: - sdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name]) + csdfg_call_args[output_name] = dace_data.make_array_from_descriptor(arrays[output_name]) - assert len(sdfg_call_args) == len(self.compiled_sdfg.argnames), ( + assert len(csdfg_call_args) == len(self.compiled_sdfg.argnames), ( "Failed to construct the call arguments," f" expected {len(self.compiled_sdfg.argnames)} but got {len(flat_call_args)}." - f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(sdfg_call_args.keys())}" + f"\nExpected: {self.compiled_sdfg.argnames}\nGot: {list(csdfg_call_args.keys())}" ) + return csdfg_call_args + + def _call_csdfg( + self, + csdfg_call_args: dict[str, Any], + ) -> None: + """ + Calls the underlying SDFG with the data in `csdfg_call_args`. + + This will forward the arguments directly to the compiled SDFG object. + See `self._construct_csdfg_args()` for how to construct the `dict` + and `self._extract_return_values()` for how to get the output values back. + + Args: + csdfg_call_args: The required arguments to call the compiled sdfg. + """ + assert len(csdfg_call_args) == len(self.compiled_sdfg.argnames) - # Calling the SDFG with dace.config.temporary_config(): dace.Config.set("compiler", "allow_view_arguments", value=True) - self.compiled_sdfg(**sdfg_call_args) + self.compiled_sdfg(**csdfg_call_args) + + def _extract_return_values( + self, + csdfg_call_args: dict[str, Any], + ) -> list[jax.Array]: + """ + Extract the return values and return the flattened version. + + JaCe allocates the buffer for the return value outside the SDFG and passes + them as arguments, see `self._construct_csdfg_args()` and `self._call_csdfg()`. + This function will extract these values and return them in the flattened order. + Furthermore, the buffer will be transferred to a `jax.Array` object. + Args: + csdfg_call_args: Collection of the arguments passed to the compiled SDFG. + + Note: + After this function returns accessing any element in `csdfg_call_args` + is undefined behaviour. + """ # DaCe writes the results either into CuPy or NumPy arrays. For compatibility # with JAX we will now turn them into `jax.Array`s. Note that this is safe # because we created these arrays in this function explicitly. Thus when # this function ends, there is no writable reference to these arrays left. return [ - util.move_into_jax_array(sdfg_call_args[output_name]) + util.move_into_jax_array(csdfg_call_args[output_name]) for output_name in self.output_names ] + def __call__( + self, + flat_call_args: Sequence[Any], + ) -> list[jax.Array]: + """ + Run the compiled SDFG using the flattened input. + + The function will not perform flattening of its input nor unflattening of + the output. + + Args: + flat_call_args: Flattened input arguments. + """ + csdfg_call_args = self._construct_csdfg_args(flat_call_args) + self._call_csdfg(csdfg_call_args) + return self._extract_return_values(csdfg_call_args) + def compile_jaxpr_sdfg(tsdfg: TranslatedJaxprSDFG) -> dace_csdfg.CompiledJaxprSDFG: """Compile `tsdfg` and return a `CompiledJaxprSDFG` object with the result."""