Skip to content

Commit

Permalink
Integrated part of teh development branch into the new PR branch.
Browse files Browse the repository at this point in the history
Essentially it is a partially copy of `src/jace` from the development branch (027ae35) to this branch.
However, the translators were not copied, thus they are still WIP mode.
Furthermore, all changes to the test were made such that they pass, i.e. they are in WIP mode.
  • Loading branch information
philip-paul-mueller committed Jun 18, 2024
1 parent 1ebbcf3 commit 7cdb5f5
Show file tree
Hide file tree
Showing 20 changed files with 1,087 additions and 1,132 deletions.
3 changes: 3 additions & 0 deletions src/jace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
from .api import grad, jacfwd, jacrev, jit
from .translated_jaxpr_sdfg import CompiledJaxprSDFG, TranslatedJaxprSDFG


__all__ = [
"CompiledJaxprSDFG",
"TranslatedJaxprSDFG",
"__author__",
"__copyright__",
"__license__",
Expand Down
59 changes: 42 additions & 17 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Any, Literal, overload
import inspect
from typing import TYPE_CHECKING, Literal, ParamSpec, TypedDict, TypeVar, overload

from jax import grad, jacfwd, jacrev
from typing_extensions import Unpack

from jace import stages, translator

Expand All @@ -21,57 +23,80 @@
from collections.abc import Callable, Mapping


__all__ = ["grad", "jacfwd", "jacrev", "jit"]
__all__ = ["JitOptions", "grad", "jacfwd", "jacrev", "jit"]

# Used for type annotation, see the notes in `jace.stages` for more.
_P = ParamSpec("_P")
_RetrunType = TypeVar("_RetrunType")


class JitOptions(TypedDict, total=False):
"""
All known options to `jace.jit` that influence tracing.
Notes:
Currently there are no known options, but essentially it is a subset of some
of the options that are supported by `jax.jit` together with some additional
JaCe specific ones.
"""


@overload
def jit(
fun: Literal[None] = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> Callable[[Callable], stages.JaCeWrapped]: ...
**kwargs: Unpack[JitOptions],
) -> Callable[[Callable[_P, _RetrunType]], stages.JaCeWrapped[_P, _RetrunType]]: ...


@overload
def jit(
fun: Callable,
fun: Callable[_P, _RetrunType],
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped: ...
**kwargs: Unpack[JitOptions],
) -> stages.JaCeWrapped[_P, _RetrunType]: ...


def jit(
fun: Callable | None = None,
fun: Callable[_P, _RetrunType] | None = None,
/,
primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
**kwargs: Any,
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]:
**kwargs: Unpack[JitOptions],
) -> (
Callable[[Callable[_P, _RetrunType]], stages.JaCeWrapped[_P, _RetrunType]]
| stages.JaCeWrapped[_P, _RetrunType]
):
"""
JaCe's replacement for `jax.jit` (just-in-time) wrapper.
It works the same way as `jax.jit` does, but instead of using XLA the
computation is lowered to DaCe. In addition it accepts some JaCe specific
arguments.
It works the same way as `jax.jit` does, but instead of lowering the
computation to XLA, it is lowered to DaCe.
The function supports a subset of the arguments that are accepted by `jax.jit()`,
currently none, and some JaCe specific ones.
Args:
fun: Function to wrap.
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
kwargs: Jit arguments.
Notes:
After constructions any change to `primitive_translators` has no effect.
Note:
This function is the only valid way to obtain a JaCe computation.
"""
if kwargs:
# TODO(phimuell): Add proper name verification and exception type.
raise NotImplementedError(
f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
)

def wrapper(f: Callable) -> stages.JaCeWrapped:
# TODO(egparedes): Improve typing.
def wrapper(f: Callable[_P, _RetrunType]) -> stages.JaCeWrapped[_P, _RetrunType]:
if any(
param.default is not param.empty for param in inspect.signature(f).parameters.values()
):
raise NotImplementedError("Default values are not yet supported.")

jace_wrapper = stages.JaCeWrapped(
fun=f,
primitive_translators=(
Expand Down
26 changes: 17 additions & 9 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
#
# SPDX-License-Identifier: BSD-3-Clause

"""
JaCe specific optimizations.
Currently just a dummy exists for the sake of providing a callable function.
"""
"""JaCe specific optimizations."""

from __future__ import annotations

Expand All @@ -19,7 +15,7 @@


if TYPE_CHECKING:
from jace import translator
import jace


class CompilerOptions(TypedDict, total=False):
Expand All @@ -35,15 +31,24 @@ class CompilerOptions(TypedDict, total=False):

auto_optimize: bool
simplify: bool
persistent: bool


# TODO(phimuell): Add a context manager to modify the default.
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": True, "simplify": True}
DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": True,
"simplify": True,
"persistent": True,
}

NO_OPTIMIZATIONS: Final[CompilerOptions] = {"auto_optimize": False, "simplify": False}
NO_OPTIMIZATIONS: Final[CompilerOptions] = {
"auto_optimize": False,
"simplify": False,
"persistent": False,
}


def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
def jace_optimize(tsdfg: jace.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None: # noqa: D417 # Missing description for kwargs
"""
Performs optimization of the translated SDFG _in place_.
Expand All @@ -55,6 +60,9 @@ def jace_optimize(tsdfg: translator.TranslatedJaxprSDFG, **kwargs: Unpack[Compil
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
persistent: Make the memory allocation persistent, i.e. allocate the
transients only once at the beginning and then reuse the memory across
the lifetime of the SDFG.
"""
# Currently this function exists primarily for the same of existing.

Expand Down
Loading

0 comments on commit 7cdb5f5

Please sign in to comment.