feat: Added Auto Opt, GPU and jax.Array (#26)
This commit addresses the following issues:
- It adds an implementation of the auto optimizer (#24).
  The current implementation is unlikely to stay, since it essentially
  uses DaCe's version, which is known to have problems with JaCe's SDFGs.
- It allows running computations on the GPU (#25); see the usage sketch below.
  The user still has to request the GPU backend explicitly, whereas JAX
  performs an auto detection.
- Instead of returning NumPy arrays, JaCe now returns `jax.Array` objects (#22).
  This goes in tandem with a reworking of the type annotations, which were
  wrong before (and cannot be made fully precise).
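
As a quick usage sketch (hedged: `jace.jit`, the `backend` option, and the `jax.Array` return type are taken from the diff below; the wrapped function and the array values are purely illustrative):

    import jax.numpy as jnp
    import jace

    @jace.jit(backend="gpu")  # defaults to "cpu" when omitted
    def scaled_sum(a, b):
        return jnp.sin(a) + 2.0 * b

    x = jnp.ones((10,))
    r = scaled_sum(x, x)  # now a `jax.Array`, no longer a NumPy array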
philip-paul-mueller authored Oct 4, 2024
1 parent 0a9f361 commit 3be9f36
Showing 10 changed files with 223 additions and 72 deletions.
pyproject.toml (3 changes: 2 additions & 1 deletion)
@@ -25,7 +25,7 @@ classifiers = [
 ]
 dependencies = [
     "dace>=0.16",
-    "jax[cpu]>=0.4.24",
+    "jax[cpu]>=0.4.33",
     "numpy>=1.26.0",
 ]
 description = "JAX jit using DaCe (Data Centric Parallel Programming)"
@@ -103,6 +103,7 @@ module = [
     "dace.*",
     "jax.*",
     "jaxlib.*",
+    "cupy.*",
 ]

 # -- pytest --
src/jace/__init__.py (6 changes: 6 additions & 0 deletions)
@@ -9,12 +9,18 @@

 from __future__ import annotations

+import jax
+
 import jace.translator.primitive_translators as _  # noqa: F401 [unused-import]  # Needed to populate the internal translator registry.

 from .__about__ import __author__, __copyright__, __license__, __version__, __version_info__
 from .api import grad, jacfwd, jacrev, jit


+if jax.version._version_as_tuple(jax.__version__) < (0, 4, 33):
+    raise ImportError(f"Require at least JAX version '0.4.33', but found '{jax.__version__}'.")
+
+
 __all__ = [
     "__author__",
     "__copyright__",
src/jace/api.py (46 changes: 27 additions & 19 deletions)
@@ -11,55 +11,60 @@

 import functools
 from collections.abc import Callable, Mapping
-from typing import Literal, ParamSpec, TypedDict, TypeVar, overload
+from typing import Any, Final, Literal, ParamSpec, TypedDict, overload

 from jax import grad, jacfwd, jacrev
 from typing_extensions import Unpack

-from jace import stages, translator
+from jace import stages, translator, util


-__all__ = ["JITOptions", "grad", "jacfwd", "jacrev", "jit"]
+__all__ = ["DEFAULT_BACKEND", "JITOptions", "grad", "jacfwd", "jacrev", "jit"]

 _P = ParamSpec("_P")
-_R = TypeVar("_R")
+
+DEFAULT_BACKEND: Final[str] = "cpu"


 class JITOptions(TypedDict, total=False):
     """
     All known options to `jace.jit` that influence tracing.

-    Note:
-        Currently there are no known options, but essentially it is a subset of some
-        of the options that are supported by `jax.jit` together with some additional
-        JaCe specific ones.
+    Not all arguments that are supported by `jax.jit()` are also supported by
+    `jace.jit`. Furthermore, some additional ones might be supported.
+
+    Args:
+        backend: Target platform for which DaCe should generate code. Supported values
+            are `'cpu'` or `'gpu'`.
     """

+    backend: Literal["cpu", "gpu"]
+

 @overload
 def jit(
     fun: Literal[None] = None,
     /,
     primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
     **kwargs: Unpack[JITOptions],
-) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]]: ...
+) -> Callable[[Callable[_P, Any]], stages.JaCeWrapped[_P]]: ...


 @overload
 def jit(
-    fun: Callable[_P, _R],
+    fun: Callable[_P, Any],
     /,
     primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
     **kwargs: Unpack[JITOptions],
-) -> stages.JaCeWrapped[_P, _R]: ...
+) -> stages.JaCeWrapped[_P]: ...


 def jit(
-    fun: Callable[_P, _R] | None = None,
+    fun: Callable[_P, Any] | None = None,
     /,
     primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None,
     **kwargs: Unpack[JITOptions],
-) -> Callable[[Callable[_P, _R]], stages.JaCeWrapped[_P, _R]] | stages.JaCeWrapped[_P, _R]:
+) -> Callable[[Callable[_P, Any]], stages.JaCeWrapped[_P]] | stages.JaCeWrapped[_P]:
     """
     JaCe's replacement for `jax.jit` (just-in-time) wrapper.
@@ -72,18 +77,20 @@ def jit(
         fun: Function to wrap.
         primitive_translators: Use these primitive translators for the lowering to SDFG.
             If not specified the translators in the global registry are used.
-        kwargs: Jit arguments.
+        kwargs: JIT arguments, see `JITOptions` for more.

     Note:
         This function is the only valid way to obtain a JaCe computation.
     """
-    if kwargs:
-        # TODO(phimuell): Add proper name verification and exception type.
-        raise NotImplementedError(
-            f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}."
+    not_supported_jit_keys = kwargs.keys() - JITOptions.__annotations__.keys()
+    if not_supported_jit_keys:
+        raise ValueError(
+            f"The following arguments to 'jace.jit' are not supported: {', '.join(not_supported_jit_keys)}."
         )
+    if kwargs.get("backend", DEFAULT_BACKEND).lower() not in {"cpu", "gpu"}:
+        raise ValueError(f"The backend '{kwargs['backend']}' is not supported.")

-    def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]:
+    def wrapper(f: Callable[_P, Any]) -> stages.JaCeWrapped[_P]:
         jace_wrapper = stages.JaCeWrapped(
             fun=f,
             primitive_translators=(
@@ -92,6 +99,7 @@ def wrapper(f: Callable[_P, _R]) -> stages.JaCeWrapped[_P, _R]:
                 else primitive_translators
             ),
             jit_options=kwargs,
+            device=util.to_device_type(kwargs.get("backend", DEFAULT_BACKEND)),
         )
         functools.update_wrapper(jace_wrapper, f)
         return jace_wrapper
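
The argument check above leans on the fact that a `TypedDict` exposes its declared option names via `__annotations__`. A minimal self-contained sketch of the same validation pattern (the names `Options` and `check` are illustrative, not part of JaCe):

    from typing import Literal, TypedDict

    class Options(TypedDict, total=False):
        backend: Literal["cpu", "gpu"]

    def check(**kwargs: object) -> None:
        # Keys that were passed but never declared on the TypedDict.
        unknown = kwargs.keys() - Options.__annotations__.keys()
        if unknown:
            raise ValueError(f"Unsupported arguments: {', '.join(sorted(unknown))}.")

    check(backend="gpu")   # passes silently
    check(backand="gpu")   # raises ValueError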
src/jace/optimization.py (68 changes: 50 additions & 18 deletions)
@@ -16,6 +16,8 @@

 from typing import TYPE_CHECKING, Final, TypedDict

+import dace
+from dace.transformation.auto import auto_optimize as dace_autoopt
 from typing_extensions import Unpack

@@ -24,15 +26,17 @@


 DEFAULT_OPTIMIZATIONS: Final[CompilerOptions] = {
-    "auto_optimize": True,
+    "auto_optimize": False,
     "simplify": True,
-    "persistent_transients": True,
+    "validate": True,
+    "validate_all": False,
 }

 NO_OPTIMIZATIONS: Final[CompilerOptions] = {
     "auto_optimize": False,
     "simplify": False,
-    "persistent_transients": False,
+    "validate": True,
+    "validate_all": False,
 }

@@ -49,10 +53,15 @@ class CompilerOptions(TypedDict, total=False):

     auto_optimize: bool
     simplify: bool
-    persistent_transients: bool
+    validate: bool
+    validate_all: bool


-def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None:  # noqa: D417 [undocumented-param]
+def jace_optimize(  # noqa: D417 [undocumented-param] # `kwargs` is not documented.
+    tsdfg: tjsdfg.TranslatedJaxprSDFG,
+    device: dace.DeviceType,
+    **kwargs: Unpack[CompilerOptions],
+) -> None:
     """
     Performs optimization of the translated SDFG _in place_.
@@ -62,22 +71,45 @@ def jace_optimize(tsdfg: tjsdfg.TranslatedJaxprSDFG, **kwargs: Unpack[CompilerOptions]) -> None:
     Args:
         tsdfg: The translated SDFG that should be optimized.
+        device: The device on which the SDFG will run.
         simplify: Run the simplification pipeline.
-        auto_optimize: Run the auto optimization pipeline (currently does nothing)
-        persistent_transients: Set the allocation lifetime of (non register) transients
-            in the SDFG to `AllocationLifetime.Persistent`, i.e. keep them allocated
-            between different invocations.
+        auto_optimize: Run the auto optimization pipeline.
+        validate: Perform validation of the SDFG at the end.
+        validate_all: Perform validation after each substep.
+
+    Note:
+        Currently DaCe's auto optimization pipeline is used when auto optimize is
+        enabled. However, this might change in the future. Because DaCe's auto
+        optimizer is considered unstable it must be explicitly enabled.
     """
-    # TODO(phimuell): Implement the functionality.
-    # Currently this function exists primarily for the sake of existing.
-
-    simplify = kwargs.get("simplify", False)
-    auto_optimize = kwargs.get("auto_optimize", False)
+    assert device in {dace.DeviceType.CPU, dace.DeviceType.GPU}
+
+    # If an argument is not specified then we consider it disabled.
+    kwargs = {**NO_OPTIMIZATIONS, **kwargs}
+    simplify = kwargs["simplify"]
+    auto_optimize = kwargs["auto_optimize"]
+    validate = kwargs["validate"]
+    validate_all = kwargs["validate_all"]

     if simplify:
-        tsdfg.sdfg.simplify()
+        tsdfg.sdfg.simplify(
+            validate=validate,
+            validate_all=validate_all,
+        )

+    if device == dace.DeviceType.GPU:
+        tsdfg.sdfg.apply_gpu_transformations(
+            validate=validate,
+            validate_all=validate_all,
+            simplify=True,
+        )
+
     if auto_optimize:
-        pass
-
-    tsdfg.validate()
+        dace_autoopt.auto_optimize(
+            sdfg=tsdfg.sdfg,
+            device=device,
+            validate=validate,
+            validate_all=validate_all,
+        )
+
+    if validate or validate_all:
+        tsdfg.validate()
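
The new body establishes a fixed pipeline order: simplification, GPU transformations (when targeting the GPU), DaCe's auto optimizer (only when explicitly enabled), and a final validation. A minimal sketch of driving it by hand, assuming `tsdfg` was produced by JaCe's lowering machinery (everything except `jace_optimize` and its options is illustrative):

    import dace
    from jace.optimization import jace_optimize

    # `tsdfg` is assumed to be a TranslatedJaxprSDFG obtained from JaCe's
    # lowering stage. Unspecified options count as disabled, so be explicit.
    jace_optimize(
        tsdfg,
        device=dace.DeviceType.GPU,  # also triggers apply_gpu_transformations()
        simplify=True,               # run DaCe's simplification pipeline first
        auto_optimize=False,         # DaCe's auto optimizer stays opt-in (unstable)
        validate=True,               # validate the final SDFG
    )

Passing `**DEFAULT_OPTIMIZATIONS` instead of individual flags reproduces the package defaults, which now keep `auto_optimize` disabled.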

