Skip to content

Commit

Permalink
First, chunk of work.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed May 27, 2024
1 parent 22c1441 commit bf6132e
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 330 deletions.
7 changes: 0 additions & 7 deletions src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@
from .jax import grad, jacfwd, jacrev, jit


# In Jax `float32` is the main datatype, and they go to great lengths to avoid
# some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html).
# However, in this case we will have problems when we call the SDFG, for some reasons
# `CompiledSDFG` does not work in that case correctly, thus we enable it now globally.
_jax.config.update("jax_enable_x64", True)


__all__ = [
"__author__",
"__copyright__",
Expand Down
40 changes: 20 additions & 20 deletions src/jace/jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@
from jace.jax import stages


__all__ = [
"grad",
"jit",
"jacfwd",
"jacrev",
]


@overload
def jit(
fun: Literal[None] = None,
/,
sub_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> Callable[[Callable], stages.JaceWrapped]: ...

Expand All @@ -32,15 +40,15 @@ def jit(
def jit(
fun: Callable,
/,
sub_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaceWrapped: ...


def jit(
fun: Callable | None = None,
/,
sub_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaceWrapped | Callable[[Callable], stages.JaceWrapped]:
"""Jace's replacement for `jax.jit` (just-in-time) wrapper.
Expand All @@ -50,36 +58,28 @@ def jit(
In addition it accepts some Jace specific arguments.
Args:
sub_translators: Use these subtranslators for the lowering to DaCe.
primitive_translators: Use these primitive translators for the lowering to SDFG.
Notes:
If no subtranslators are specified then the ones that are currently active,
i.e. the output of `get_regsitered_primitive_translators()`, are used.
After construction changes to the passed `sub_translators` have no effect on the returned object.
If no translators are specified the currently ones currently inside the global registry are used.
After construction changes to the passed `primitive_translators` have no effect on the returned object.
"""
if kwargs:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments of 'jax.jit' are not yet supported by jace: {', '.join(kwargs.keys())}."
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
)

def wrapper(f: Callable) -> stages.JaceWrapped:
jace_wrapper = stages.JaceWrapped(
fun=f,
sub_translators=(
primitive_translators=(
translator.managing._PRIMITIVE_TRANSLATORS_DICT
if sub_translators is None
else sub_translators
if primitive_translators is None
else primitive_translators
),
jit_ops=kwargs,
jit_options=kwargs,
)
return functools.update_wrapper(jace_wrapper, f)

return wrapper if fun is None else wrapper(fun)


__all__ = [
"grad",
"jit",
"jacfwd",
"jacrev",
]
139 changes: 73 additions & 66 deletions src/jace/jax/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,47 +34,49 @@
from jace.jax import translation_cache as tcache
from jace.optimization import CompilerOptions
from jace.translator import post_translation as ptrans
from jace.util import dace_helper as jdace
from jace.util import dace_helper


class JaceWrapped(tcache.CachingStage):
class JaceWrapped(tcache.CachingStage["JaceLowered"]):
"""A function ready to be specialized, lowered, and compiled.
This class represents the output of functions such as `jace.jit()`.
Calling it results in jit (just-in-time) lowering, compilation, and execution.
It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.
You should not create `JaceWrapped` instances directly, instead you should use `jace.jit`.
Calling it results in jit (just-in-time) lowering, compilation and execution.
It is also possible to lower the function explicitly by calling `self.lower()`.
This function can be composed with other Jax transformations.
Todo:
- Handle pytrees.
- Handle all options to `jax.jit`.
Note:
The tracing of function will always happen with enabled `x64` mode, which is implicitly
and temporary activated during tracing. Furthermore, the disable JIT config flag is ignored.
"""

_fun: Callable
_sub_translators: dict[str, translator.PrimitiveTranslator]
_jit_ops: dict[str, Any]
_primitive_translators: dict[str, translator.PrimitiveTranslator]
_jit_options: dict[str, Any]

def __init__(
self,
fun: Callable,
sub_translators: Mapping[str, translator.PrimitiveTranslator],
jit_ops: Mapping[str, Any],
primitive_translators: Mapping[str, translator.PrimitiveTranslator],
jit_options: Mapping[str, Any],
) -> None:
"""Creates a wrapped jitable object of `fun`.
You should not create `JaceWrapped` instances directly, instead you should use `jace.jit`.
Args:
fun: The function that is wrapped.
sub_translators: The list of subtranslators that that should be used.
jit_ops: All options that we forward to `jax.jit`.
fun: The function that is wrapped.
primitive_translators: The list of subtranslators that that should be used.
jit_options: Options to influence the jit process.
"""
super().__init__()
# We have to shallow copy both the translator and the jit options.
# This prevents that any modifications affect `self`.
# Shallow is enough since the translators themselves are immutable.
self._sub_translators = dict(sub_translators)
self._jit_ops = dict(jit_ops)
self._primitive_translators = dict(primitive_translators)
self._jit_options = dict(jit_options)
self._fun = fun

def __call__(
Expand All @@ -84,35 +86,25 @@ def __call__(
) -> Any:
"""Executes the wrapped function, lowering and compiling as needed in one step."""

# TODO(phimuell): Handle the `disable_jit` context manager of Jax.

# This allows us to be composable with Jax transformations.
# If we are inside a traced context, then we forward the call to the wrapped function.
# This ensures that Jace is composable with Jax.
if util.is_tracing_ongoing(*args, **kwargs):
# TODO(phimuell): Handle the case of gradients:
# It seems that this one uses special tracers, since they can handle comparisons.
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-autodiff
return self._fun(*args, **kwargs)

# TODO(phimuell): Handle static arguments correctly
# https://jax.readthedocs.io/en/latest/aot.html#lowering-with-static-arguments
lowered = self.lower(*args, **kwargs)
compiled = lowered.compile()
return compiled(*args, **kwargs)

@tcache.cached_translation
@tcache.cached_transition
def lower(
self,
*args: Any,
**kwargs: Any,
) -> JaceLowered:
"""Lower this function explicitly for the given arguments.
Performs the first two steps of the AOT steps described above,
i.e. transformation into Jaxpr and then to SDFG.
The result is encapsulated into a `Lowered` object.
Todo:
- Handle pytrees.
Performs the first two steps of the AOT steps described above, i.e. transformation into
Jaxpr and then to SDFG. The result is encapsulated into a `Lowered` object.
"""
if len(kwargs) != 0:
raise NotImplementedError("Currently only positional arguments are supported.")
Expand All @@ -122,12 +114,19 @@ def lower(
if not all((not util.is_array(arg)) or arg.flags["C_CONTIGUOUS"] for arg in args):
raise NotImplementedError("Currently can not handle strides beside 'C_CONTIGUOUS'.")

jaxpr = _jax.make_jaxpr(self._fun)(*args)
driver = translator.JaxprTranslationDriver(sub_translators=self._sub_translators)
trans_sdfg: translator.TranslatedJaxprSDFG = driver.translate_jaxpr(jaxpr)
ptrans.postprocess_jaxpr_sdfg(tsdfg=trans_sdfg, fun=self.wrapped_fun)
# The `JaceLowered` assumes complete ownership of `trans_sdfg`!
return JaceLowered(trans_sdfg)
# In Jax `float32` is the main datatype, and they go to great lengths to avoid
# some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html).
# However, in this case we will have problems when we call the SDFG, for some reasons
# `CompiledSDFG` does not work in that case correctly, thus we enable it for the tracing.
with _jax.experimental.enable_x64():
driver = translator.JaxprTranslationDriver(
primitive_translators=self._primitive_translators
)
jaxpr = _jax.make_jaxpr(self._fun)(*args)
tsdfg: translator.TranslatedJaxprSDFG = driver.translate_jaxpr(jaxpr)
ptrans.postprocess_jaxpr_sdfg(tsdfg=tsdfg, fun=self.wrapped_fun)

return JaceLowered(tsdfg)

@property
def wrapped_fun(self) -> Callable:
Expand All @@ -137,42 +136,47 @@ def wrapped_fun(self) -> Callable:
def _make_call_description(
self,
*args: Any,
) -> tcache.CachedCallDescription:
) -> tcache.StageTransformationDescription:
"""This function computes the key for the `JaceWrapped.lower()` call.
Currently it is only able to handle positional argument and does not support static arguments.
The function will fully abstractify its input arguments.
This function is used by the cache to generate the key.
"""
fargs = tuple(tcache._AbstractCallArgument.from_value(x) for x in args)
return tcache.CachedCallDescription(stage_id=id(self), fargs=fargs)
call_args = tuple(tcache._AbstractCallArgument.from_value(x) for x in args)
return tcache.StageTransformationDescription(stage_id=id(self), call_args=call_args)


class JaceLowered(tcache.CachingStage["JaceCompiled"]):
"""Represents the original computation as an SDFG.
class JaceLowered(tcache.CachingStage):
"""Represents the original computation that was lowered to SDFG.
Although, `JaceWrapped` is composable with Jax transformations `JaceLowered` is not.
A user should never create such an object.
Todo:
- Handle pytrees.
"""

# `self` assumes complete ownership of the
_trans_sdfg: translator.TranslatedJaxprSDFG
_translated_sdfg: translator.TranslatedJaxprSDFG

def __init__(
self,
trans_sdfg: translator.TranslatedJaxprSDFG,
tsdfg: translator.TranslatedJaxprSDFG,
) -> None:
"""Constructs the lowered object."""
if not trans_sdfg.is_finalized:
"""Initialize the lowered object.
Args:
tsdfg: The lowered SDFG with metadata. Must be finalized.
Notes:
The passed `tsdfg` will be managed by `self`.
"""
if not tsdfg.is_finalized:
raise ValueError("The translated SDFG must be finalized.")
if trans_sdfg.inp_names is None:
raise ValueError("Input names must be defined.")
if trans_sdfg.out_names is None:
raise ValueError("Output names must be defined.")
super().__init__()
self._trans_sdfg = trans_sdfg
self._translated_sdfg = tsdfg

@tcache.cached_translation
@tcache.cached_transition
def compile(
self,
compiler_options: CompilerOptions | None = None,
Expand All @@ -192,11 +196,9 @@ def compile(
# We **must** deepcopy before we do any optimization.
# The reason is `self` is cached and assumed to be immutable.
# Since all optimizations works in place, we would violate this assumption.
tsdfg: translator.TranslatedJaxprSDFG = copy.deepcopy(self._trans_sdfg)
tsdfg: translator.TranslatedJaxprSDFG = copy.deepcopy(self._translated_sdfg)

# Must be the same as in `_make_call_description()`!
options = optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {})
optimization.jace_optimize(tsdfg=tsdfg, **options)
optimization.jace_optimize(tsdfg=tsdfg, **self._make_compiler_options(compiler_options))

return JaceCompiled(
csdfg=util.compile_jax_sdfg(tsdfg),
Expand All @@ -211,7 +213,7 @@ def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprS
It is important that modifying this object in any ways is considered an error.
"""
if (dialect is None) or (dialect.upper() == "SDFG"):
return self._trans_sdfg
return self._translated_sdfg
raise ValueError(f"Unknown dialect '{dialect}'.")

def as_html(self, filename: str | None = None) -> None:
Expand All @@ -231,17 +233,22 @@ def as_sdfg(self) -> dace.SDFG:
def _make_call_description(
self,
compiler_options: CompilerOptions | None = None,
) -> tcache.CachedCallDescription:
) -> tcache.StageTransformationDescription:
"""This function computes the key for the `self.compile()` call.
The function only get one argument that is either a `dict` or a `None`, where `None` means `use default argument.
The function will construct a concrete description of the call using `(name, value)` pairs.
This function is used by the cache.
"""
# Must be the same as in `compile()`!
options = optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {})
fargs = tuple(sorted(options.items(), key=lambda X: X[0]))
return tcache.CachedCallDescription(stage_id=id(self), fargs=fargs)
options = self._make_compiler_options(compiler_options)
call_args = tuple(sorted(options.items(), key=lambda X: X[0]))
return tcache.StageTransformationDescription(stage_id=id(self), call_args=call_args)

def _make_compiler_options(
self,
compiler_options: CompilerOptions | None,
) -> CompilerOptions:
return optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {})


class JaceCompiled:
Expand All @@ -251,13 +258,13 @@ class JaceCompiled:
- Handle pytrees.
"""

_csdfg: jdace.CompiledSDFG # The compiled SDFG object.
_csdfg: dace_helper.CompiledSDFG # The compiled SDFG object.
_inp_names: tuple[str, ...] # Name of all input arguments.
_out_names: tuple[str, ...] # Name of all output arguments.

def __init__(
self,
csdfg: jdace.CompiledSDFG,
csdfg: dace_helper.CompiledSDFG,
inp_names: Sequence[str],
out_names: Sequence[str],
) -> None:
Expand Down
Loading

0 comments on commit bf6132e

Please sign in to comment.