From bf6132ef20907936787b5f93d8c9d22b20cc7cfc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 27 May 2024 15:57:33 +0200 Subject: [PATCH] First, chunk of work. --- src/jace/__init__.py | 7 - src/jace/jax/api.py | 40 ++-- src/jace/jax/stages.py | 139 ++++++------ src/jace/jax/translation_cache.py | 149 ++++++------- .../translator/jaxpr_translator_driver.py | 200 ++++++++---------- src/jace/translator/primitive_translator.py | 74 +++---- tests/test_jaxpr_translator_driver.py | 8 +- 7 files changed, 287 insertions(+), 330 deletions(-) diff --git a/src/jace/__init__.py b/src/jace/__init__.py index aad3265..05d9632 100644 --- a/src/jace/__init__.py +++ b/src/jace/__init__.py @@ -17,13 +17,6 @@ from .jax import grad, jacfwd, jacrev, jit -# In Jax `float32` is the main datatype, and they go to great lengths to avoid -# some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html). -# However, in this case we will have problems when we call the SDFG, for some reasons -# `CompiledSDFG` does not work in that case correctly, thus we enable it now globally. -_jax.config.update("jax_enable_x64", True) - - __all__ = [ "__author__", "__copyright__", diff --git a/src/jace/jax/api.py b/src/jace/jax/api.py index a46702b..a214f5a 100644 --- a/src/jace/jax/api.py +++ b/src/jace/jax/api.py @@ -19,11 +19,19 @@ from jace.jax import stages +__all__ = [ + "grad", + "jit", + "jacfwd", + "jacrev", +] + + @overload def jit( fun: Literal[None] = None, /, - sub_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, + primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, **kwargs: Any, ) -> Callable[[Callable], stages.JaceWrapped]: ... @@ -32,7 +40,7 @@ def jit( def jit( fun: Callable, /, - sub_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, + primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, **kwargs: Any, ) -> stages.JaceWrapped: ... @@ -40,7 +48,7 @@ def jit( def jit( fun: Callable | None = None, /, - sub_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, + primitive_translators: Mapping[str, translator.PrimitiveTranslator] | None = None, **kwargs: Any, ) -> stages.JaceWrapped | Callable[[Callable], stages.JaceWrapped]: """Jace's replacement for `jax.jit` (just-in-time) wrapper. @@ -50,36 +58,28 @@ def jit( In addition it accepts some Jace specific arguments. Args: - sub_translators: Use these subtranslators for the lowering to DaCe. + primitive_translators: Use these primitive translators for the lowering to SDFG. Notes: - If no subtranslators are specified then the ones that are currently active, - i.e. the output of `get_regsitered_primitive_translators()`, are used. - After construction changes to the passed `sub_translators` have no effect on the returned object. + If no translators are specified the currently ones currently inside the global registry are used. + After construction changes to the passed `primitive_translators` have no effect on the returned object. """ if kwargs: + # TODO(phimuell): Add proper name verification and exception type. raise NotImplementedError( - f"The following arguments of 'jax.jit' are not yet supported by jace: {', '.join(kwargs.keys())}." + f"The following arguments to 'jace.jit' are not yet supported: {', '.join(kwargs)}." ) def wrapper(f: Callable) -> stages.JaceWrapped: jace_wrapper = stages.JaceWrapped( fun=f, - sub_translators=( + primitive_translators=( translator.managing._PRIMITIVE_TRANSLATORS_DICT - if sub_translators is None - else sub_translators + if primitive_translators is None + else primitive_translators ), - jit_ops=kwargs, + jit_options=kwargs, ) return functools.update_wrapper(jace_wrapper, f) return wrapper if fun is None else wrapper(fun) - - -__all__ = [ - "grad", - "jit", - "jacfwd", - "jacrev", -] diff --git a/src/jace/jax/stages.py b/src/jace/jax/stages.py index 5ad907e..9a0218f 100644 --- a/src/jace/jax/stages.py +++ b/src/jace/jax/stages.py @@ -34,47 +34,49 @@ from jace.jax import translation_cache as tcache from jace.optimization import CompilerOptions from jace.translator import post_translation as ptrans -from jace.util import dace_helper as jdace +from jace.util import dace_helper -class JaceWrapped(tcache.CachingStage): +class JaceWrapped(tcache.CachingStage["JaceLowered"]): """A function ready to be specialized, lowered, and compiled. This class represents the output of functions such as `jace.jit()`. - Calling it results in jit (just-in-time) lowering, compilation, and execution. - It can also be explicitly lowered prior to compilation, and the result compiled prior to execution. - - You should not create `JaceWrapped` instances directly, instead you should use `jace.jit`. + Calling it results in jit (just-in-time) lowering, compilation and execution. + It is also possible to lower the function explicitly by calling `self.lower()`. + This function can be composed with other Jax transformations. Todo: - Handle pytrees. + - Handle all options to `jax.jit`. + + Note: + The tracing of function will always happen with enabled `x64` mode, which is implicitly + and temporary activated during tracing. Furthermore, the disable JIT config flag is ignored. """ _fun: Callable - _sub_translators: dict[str, translator.PrimitiveTranslator] - _jit_ops: dict[str, Any] + _primitive_translators: dict[str, translator.PrimitiveTranslator] + _jit_options: dict[str, Any] def __init__( self, fun: Callable, - sub_translators: Mapping[str, translator.PrimitiveTranslator], - jit_ops: Mapping[str, Any], + primitive_translators: Mapping[str, translator.PrimitiveTranslator], + jit_options: Mapping[str, Any], ) -> None: """Creates a wrapped jitable object of `fun`. - You should not create `JaceWrapped` instances directly, instead you should use `jace.jit`. - Args: - fun: The function that is wrapped. - sub_translators: The list of subtranslators that that should be used. - jit_ops: All options that we forward to `jax.jit`. + fun: The function that is wrapped. + primitive_translators: The list of subtranslators that that should be used. + jit_options: Options to influence the jit process. """ super().__init__() # We have to shallow copy both the translator and the jit options. # This prevents that any modifications affect `self`. # Shallow is enough since the translators themselves are immutable. - self._sub_translators = dict(sub_translators) - self._jit_ops = dict(jit_ops) + self._primitive_translators = dict(primitive_translators) + self._jit_options = dict(jit_options) self._fun = fun def __call__( @@ -84,22 +86,16 @@ def __call__( ) -> Any: """Executes the wrapped function, lowering and compiling as needed in one step.""" - # TODO(phimuell): Handle the `disable_jit` context manager of Jax. - - # This allows us to be composable with Jax transformations. + # If we are inside a traced context, then we forward the call to the wrapped function. + # This ensures that Jace is composable with Jax. if util.is_tracing_ongoing(*args, **kwargs): - # TODO(phimuell): Handle the case of gradients: - # It seems that this one uses special tracers, since they can handle comparisons. - # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-autodiff return self._fun(*args, **kwargs) - # TODO(phimuell): Handle static arguments correctly - # https://jax.readthedocs.io/en/latest/aot.html#lowering-with-static-arguments lowered = self.lower(*args, **kwargs) compiled = lowered.compile() return compiled(*args, **kwargs) - @tcache.cached_translation + @tcache.cached_transition def lower( self, *args: Any, @@ -107,12 +103,8 @@ def lower( ) -> JaceLowered: """Lower this function explicitly for the given arguments. - Performs the first two steps of the AOT steps described above, - i.e. transformation into Jaxpr and then to SDFG. - The result is encapsulated into a `Lowered` object. - - Todo: - - Handle pytrees. + Performs the first two steps of the AOT steps described above, i.e. transformation into + Jaxpr and then to SDFG. The result is encapsulated into a `Lowered` object. """ if len(kwargs) != 0: raise NotImplementedError("Currently only positional arguments are supported.") @@ -122,12 +114,19 @@ def lower( if not all((not util.is_array(arg)) or arg.flags["C_CONTIGUOUS"] for arg in args): raise NotImplementedError("Currently can not handle strides beside 'C_CONTIGUOUS'.") - jaxpr = _jax.make_jaxpr(self._fun)(*args) - driver = translator.JaxprTranslationDriver(sub_translators=self._sub_translators) - trans_sdfg: translator.TranslatedJaxprSDFG = driver.translate_jaxpr(jaxpr) - ptrans.postprocess_jaxpr_sdfg(tsdfg=trans_sdfg, fun=self.wrapped_fun) - # The `JaceLowered` assumes complete ownership of `trans_sdfg`! - return JaceLowered(trans_sdfg) + # In Jax `float32` is the main datatype, and they go to great lengths to avoid + # some aggressive [type promotion](https://jax.readthedocs.io/en/latest/type_promotion.html). + # However, in this case we will have problems when we call the SDFG, for some reasons + # `CompiledSDFG` does not work in that case correctly, thus we enable it for the tracing. + with _jax.experimental.enable_x64(): + driver = translator.JaxprTranslationDriver( + primitive_translators=self._primitive_translators + ) + jaxpr = _jax.make_jaxpr(self._fun)(*args) + tsdfg: translator.TranslatedJaxprSDFG = driver.translate_jaxpr(jaxpr) + ptrans.postprocess_jaxpr_sdfg(tsdfg=tsdfg, fun=self.wrapped_fun) + + return JaceLowered(tsdfg) @property def wrapped_fun(self) -> Callable: @@ -137,42 +136,47 @@ def wrapped_fun(self) -> Callable: def _make_call_description( self, *args: Any, - ) -> tcache.CachedCallDescription: + ) -> tcache.StageTransformationDescription: """This function computes the key for the `JaceWrapped.lower()` call. Currently it is only able to handle positional argument and does not support static arguments. The function will fully abstractify its input arguments. This function is used by the cache to generate the key. """ - fargs = tuple(tcache._AbstractCallArgument.from_value(x) for x in args) - return tcache.CachedCallDescription(stage_id=id(self), fargs=fargs) + call_args = tuple(tcache._AbstractCallArgument.from_value(x) for x in args) + return tcache.StageTransformationDescription(stage_id=id(self), call_args=call_args) + +class JaceLowered(tcache.CachingStage["JaceCompiled"]): + """Represents the original computation as an SDFG. -class JaceLowered(tcache.CachingStage): - """Represents the original computation that was lowered to SDFG. + Although, `JaceWrapped` is composable with Jax transformations `JaceLowered` is not. + A user should never create such an object. Todo: - Handle pytrees. """ - # `self` assumes complete ownership of the - _trans_sdfg: translator.TranslatedJaxprSDFG + _translated_sdfg: translator.TranslatedJaxprSDFG def __init__( self, - trans_sdfg: translator.TranslatedJaxprSDFG, + tsdfg: translator.TranslatedJaxprSDFG, ) -> None: - """Constructs the lowered object.""" - if not trans_sdfg.is_finalized: + """Initialize the lowered object. + + Args: + tsdfg: The lowered SDFG with metadata. Must be finalized. + + Notes: + The passed `tsdfg` will be managed by `self`. + """ + if not tsdfg.is_finalized: raise ValueError("The translated SDFG must be finalized.") - if trans_sdfg.inp_names is None: - raise ValueError("Input names must be defined.") - if trans_sdfg.out_names is None: - raise ValueError("Output names must be defined.") super().__init__() - self._trans_sdfg = trans_sdfg + self._translated_sdfg = tsdfg - @tcache.cached_translation + @tcache.cached_transition def compile( self, compiler_options: CompilerOptions | None = None, @@ -192,11 +196,9 @@ def compile( # We **must** deepcopy before we do any optimization. # The reason is `self` is cached and assumed to be immutable. # Since all optimizations works in place, we would violate this assumption. - tsdfg: translator.TranslatedJaxprSDFG = copy.deepcopy(self._trans_sdfg) + tsdfg: translator.TranslatedJaxprSDFG = copy.deepcopy(self._translated_sdfg) - # Must be the same as in `_make_call_description()`! - options = optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {}) - optimization.jace_optimize(tsdfg=tsdfg, **options) + optimization.jace_optimize(tsdfg=tsdfg, **self._make_compiler_options(compiler_options)) return JaceCompiled( csdfg=util.compile_jax_sdfg(tsdfg), @@ -211,7 +213,7 @@ def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprS It is important that modifying this object in any ways is considered an error. """ if (dialect is None) or (dialect.upper() == "SDFG"): - return self._trans_sdfg + return self._translated_sdfg raise ValueError(f"Unknown dialect '{dialect}'.") def as_html(self, filename: str | None = None) -> None: @@ -231,17 +233,22 @@ def as_sdfg(self) -> dace.SDFG: def _make_call_description( self, compiler_options: CompilerOptions | None = None, - ) -> tcache.CachedCallDescription: + ) -> tcache.StageTransformationDescription: """This function computes the key for the `self.compile()` call. The function only get one argument that is either a `dict` or a `None`, where `None` means `use default argument. The function will construct a concrete description of the call using `(name, value)` pairs. This function is used by the cache. """ - # Must be the same as in `compile()`! - options = optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {}) - fargs = tuple(sorted(options.items(), key=lambda X: X[0])) - return tcache.CachedCallDescription(stage_id=id(self), fargs=fargs) + options = self._make_compiler_options(compiler_options) + call_args = tuple(sorted(options.items(), key=lambda X: X[0])) + return tcache.StageTransformationDescription(stage_id=id(self), call_args=call_args) + + def _make_compiler_options( + self, + compiler_options: CompilerOptions | None, + ) -> CompilerOptions: + return optimization.DEFAULT_OPTIMIZATIONS | (compiler_options or {}) class JaceCompiled: @@ -251,13 +258,13 @@ class JaceCompiled: - Handle pytrees. """ - _csdfg: jdace.CompiledSDFG # The compiled SDFG object. + _csdfg: dace_helper.CompiledSDFG # The compiled SDFG object. _inp_names: tuple[str, ...] # Name of all input arguments. _out_names: tuple[str, ...] # Name of all output arguments. def __init__( self, - csdfg: jdace.CompiledSDFG, + csdfg: dace_helper.CompiledSDFG, inp_names: Sequence[str], out_names: Sequence[str], ) -> None: diff --git a/src/jace/jax/translation_cache.py b/src/jace/jax/translation_cache.py index 5bcb948..b7c92b3 100644 --- a/src/jace/jax/translation_cache.py +++ b/src/jace/jax/translation_cache.py @@ -21,7 +21,14 @@ import dataclasses import functools from collections.abc import Callable, Hashable -from typing import TYPE_CHECKING, Any, Final, TypeAlias +from typing import ( + TYPE_CHECKING, + Any, + Generic, + TypeAlias, + TypeVar, + cast, +) import dace from jax import core as jax_core @@ -32,24 +39,24 @@ if TYPE_CHECKING: from jace.jax import stages -# This is the default cache size we are using -_DEF_CACHE_SIZE: Final[int] = 256 +#: Caches used to store the state transition. +#: The states are on a per stage and not per instant basis. +_TRANSLATION_CACHES: dict[type[CachingStage], StageCache] = {} -# This are the caches that we are using. -_TRANSLATION_CACHES: dict[type[CachingStage], TranslationCache] = {} +Stage_ = TypeVar("Stage_", bound="stages.Stage") -class CachingStage: +class CachingStage(Generic[Stage_]): """Annotates a stage whose transition to the next one is cacheable. - This transitions are mainly `JaceWrapped.lower()` and `JaceLowered.compile()` calls. - To make a stage cacheable annotate the transition function with the `@cached_translation` decorator. + To make a transition function cacheable it must be annotated by the + `@cached_transition` decorator. Todo: - Make a generic to indicate what the result stage is. """ - _cache: TranslationCache + _cache: StageCache[Stage_] def __init__(self) -> None: self._cache = get_cache(self) @@ -59,14 +66,17 @@ def _make_call_description( self: CachingStage, *args: Any, **kwargs: Any, - ) -> CachedCallDescription: + ) -> StageTransformationDescription: """Generates the key that is used to store/locate the call in the cache.""" ... -def cached_translation( - action: Callable[..., stages.Stage], -) -> Callable: +Action_T = TypeVar("Action_T", bound=Callable[..., Any]) + + +def cached_transition( + action: Action_T, +) -> Action_T: """Decorator for making the transition function of the stage cacheable. The decorator will call the annotated function only if the call is not stored inside the cache. @@ -79,18 +89,15 @@ def _action_wrapper( self: CachingStage, *args: Any, **kwargs: Any, - ) -> stages.Stage: - # Get the abstract description of the call, that is used as key. - key: CachedCallDescription = self._make_call_description(*args, **kwargs) + ): + key: StageTransformationDescription = self._make_call_description(*args, **kwargs) if key in self._cache: return self._cache[key] - - # We must actually perform the call next_stage: stages.Stage = action(self, *args, **kwargs) self._cache[key] = next_stage return next_stage - return _action_wrapper + return cast(Action_T, _action_wrapper) def clear_translation_cache() -> None: @@ -100,13 +107,12 @@ def clear_translation_cache() -> None: def get_cache( stage: CachingStage, -) -> TranslationCache: +) -> StageCache: """Returns the cache that is used for `stage`.""" - # The caches are per stage and not per instance basis - tstage = type(stage) - if tstage not in _TRANSLATION_CACHES: - _TRANSLATION_CACHES[tstage] = TranslationCache(size=_DEF_CACHE_SIZE) - return _TRANSLATION_CACHES[tstage] + stage_type = type(stage) + if stage_type not in _TRANSLATION_CACHES: + _TRANSLATION_CACHES[stage_type] = StageCache() + return _TRANSLATION_CACHES[stage_type] @dataclasses.dataclass(frozen=True) @@ -116,6 +122,12 @@ class _AbstractCallArgument: It is used as part of the key in the cache. It represents the structure of the argument, i.e. its shape, type and so on, but nots its value. To construct it you should use the `from_value()` class function which interfere the characteristics from a value. + + Attributes: + shape: In case of an array its shape, in case of a scalar the empty tuple. + dtype: The DaCe type of the argument. + strides: The strides of the argument, or `None` if they are unknown or a scalar. + storage: The storage type where the argument is stored. """ shape: tuple[int, ...] @@ -172,80 +184,65 @@ def from_value( @dataclasses.dataclass(frozen=True) -class CachedCallDescription: - """Represents the full structure of a call in the cache as a key. - - This class is the return type of the `CachingStage._make_call_description()` function, - which is used by the `@cached_translation` decorator to compute a key of transition. - This allows to either retrieve or then store the result of the actual call in the cache. - - The actual key is composed of two parts, first the "origin of the call". - For this we just use the address of the stage object we are caching and hope that the - address is not reused for another stag anytime soon. - - The second part is of the key are a description of the actual arguments, see `CallArgsDescription` type alias. - For this the `_make_call_description()` method of the stage is used. - The arguments can be described in two different ways: - - Abstract description: In this way, the actual value of the argument is irrelevant, - only the structure of them are important, this is similar to the tracers used in Jax. - - Concrete description: Here one caches on the actual value of the argument, - which is similar to static arguments in Jax. - The only restriction is that they are hash able. +class StageTransformationDescription: + """Represents the call to a state transformation function. + + State transition functions are annotated with `@cached_transition` and stored inside a cache. + This class serves as a key inside this cache and is generated by `CachingStage._make_call_description()`. + The actual key is consists of two parts. + + Attributes: + stage_id: Origin of the call, for which the id of the stage object should be used. + call_args: Description of the arguments of the call. There are two ways to describe + the arguments: + - Abstract description: In this way, the actual value of the argument is irrelevant, + only the structure of them are important, similar to the tracers used in Jax. + - Concrete description: Here one caches on the actual value of the argument. + The only requirement is that they can be hashed. Notes: The base assumption is that the stages are immutable. Todo: - pytrees. - - Turn the references into week references, Jax does this and I am sure there is a reason for it. """ stage_id: int - fargs: CallArgsDescription + call_args: CallArgsDescription -class TranslationCache: - """The cache object used to cache the stage transitions. +class StageCache(Generic[Stage_]): + """LRU cache that is used to cache the stage transitions, i.e. lowering and compiling, in Jace. Notes: - The most recently used entry is at the end of the `OrderedDict`, because it puts new entries there. + The most recently used entry is at the end of the `OrderedDict`. """ - __slots__ = ("_memory", "_size") - - _memory: collections.OrderedDict[CachedCallDescription, stages.Stage] + _memory: collections.OrderedDict[StageTransformationDescription, Stage_] _size: int def __init__( self, - size: int, + size: int = 256, ) -> None: - """Creates a cache instance of size. + """Creates a LRU cache with `size` many entries. - The cache will have size `size` and use `key` as key function. + Args: + size: Number of entries the cache holds, defaults to 256. """ - if size <= 0: - raise ValueError(f"Invalid cache size of '{size}'") self._memory = collections.OrderedDict() self._size = size def __contains__( self, - key: CachedCallDescription, + key: StageTransformationDescription, ) -> bool: - """Check if `self` have a record of `key`.""" return key in self._memory def __getitem__( self, - key: CachedCallDescription, - ) -> stages.Stage: - """Get the next stage associated with `key`. - - Notes: - It is an error if `key` does not exist. - This function will mark `key` as most recently used. - """ + key: StageTransformationDescription, + ) -> Stage_: if key not in self: raise KeyError(f"Key '{key}' is unknown.") self._memory.move_to_end(key, last=True) @@ -253,25 +250,20 @@ def __getitem__( def __setitem__( self, - key: CachedCallDescription, - res: stages.Stage, - ) -> TranslationCache: - """Adds or update `key` to map to `res`.""" + key: StageTransformationDescription, + res: Stage_, + ) -> None: if key in self: - # `key` is known, so move it to the end and update the mapped value. self._memory.move_to_end(key, last=True) self._memory[key] = res - else: - # `key` is not known so we have to add it - while len(self._memory) >= self._size: + if len(self._memory) == self._size: self.popitem(None) self._memory[key] = res - return self def popitem( self, - key: CachedCallDescription | None, + key: StageTransformationDescription | None, ) -> None: """Evict `key` from `self`. @@ -286,5 +278,4 @@ def popitem( self._memory.popitem(last=False) def __repr__(self) -> str: - """Textual representation for debugging.""" - return f"TranslationCache({len(self._memory)} / {self._size} || {', '.join( '[' + repr(k) + ']' for k in self._memory)})" + return f"StageCache({len(self._memory)} / {self._size} || {', '.join( '[' + repr(k) + ']' for k in self._memory)})" diff --git a/src/jace/translator/jaxpr_translator_driver.py b/src/jace/translator/jaxpr_translator_driver.py index 508322c..aa794b6 100644 --- a/src/jace/translator/jaxpr_translator_driver.py +++ b/src/jace/translator/jaxpr_translator_driver.py @@ -33,62 +33,62 @@ class JaxprTranslationDriver: - the `arg_names` parameter is not set. For these reasons the SDFG is not directly usable, and further manipulations have to be performed. - Especially, DaCe's validation function will fail and it is unable to be processed by the optimization pipeline. - For more information also see `jace.translator.post_translation` module. + Especially, DaCe's validation function will fail and it is unable to be processed by the + optimization pipeline. For more information also see `jace.translator.post_translation` module. - The idea of the translator is extremely simple. - Since Jaxpr is a list consisting of more or less simple instructions/equations, they get processed one after the other. - Each equation is translated into its own state that is appended to the SDFG, thus the SDFG is a long list of states. - In certain cases it might be that an equation needs more states, but this is an exception. + The idea of the translator is extremely simple. Since Jaxpr is a list consisting of more or less + simple instructions/equations, they get processed one after the other. + Each equation is translated into its own state that is appended to the SDFG, thus the SDFG is a + long list of states. In certain cases it might be that an equation needs more states, + but this is an exception. The actual translation of the equation is not handled by the driver. - Instead the request is forwarded to a `PrimitiveTranslator` object, also known as subtranslator. - This is a highly specialized object that is able to handle one kind of primitive. - For more information on the subtranslators see the documentation of `PrimitiveTranslator`. + Instead the request is forwarded to a `PrimitiveTranslator` object, known as primitive translator + or subtranslator. This is a highly specialized object that is able to handle one kind of primitive. + For more information on them see the documentation of `PrimitiveTranslator`. - To start a translation the `translate_jaxpr()` function should be called, if this happens it is said that the driver has an ongoing translation. - If `translate_jaxpr()` is called on a driver that has an ongoing translation, a new translation context will be set up. + To start a translation the `translate_jaxpr()` function should be called, if this happens it is + said that the driver has an ongoing translation. If `translate_jaxpr()` is called on a driver + that has an ongoing translation, a new translation context will be set up. Thus the driver will then translate the supplied (nested) Jaxpr and return the result. However, this will have no influence on the translation process that is already going. Notes: After the main translation has been performed the translator object can be used again. - Currently the driver will generate only Array as SDFG variables, however, this is a temporary solution, see `add_array()`. + Currently the driver will generate only Array as SDFG variables, however, this is a + temporary solution, see `add_array()`. """ - __slots__ = ("_ctx_stack", "_sub_translators", "_jax_name_map") + __slots__ = ("_ctx_stack", "_primitive_translators", "_jax_name_map") - _sub_translators: Mapping[str, translator.PrimitiveTranslatorCallable] + _primitive_translators: Mapping[str, translator.PrimitiveTranslatorCallable] _jax_name_map: dict[jax_core.Var | util.JaCeVar, str] _ctx_stack: list[translator.TranslatedJaxprSDFG] def __init__( self, - sub_translators: Mapping[str, translator.PrimitiveTranslatorCallable], + primitive_translators: Mapping[str, translator.PrimitiveTranslatorCallable], ) -> None: - """Creates the driver. + """Creates the driver ready for translation. Args: - sub_translators: Use these subtranslators to perform the translation. + primitive_translators: Primitive to use during the translation. - Notes: - `sub_translators` is not copied, however, the user has to guarantee, that it does not change during the lifetime of `self`. + Note: + The primitive translators are not copied, thus the user has to ensure that the passed mapping + does not change during the translation. """ - - # Maps the name of a Jax primitive to the primitive translator that should be used. - # Note that the subtranslator is only required to be a callable, and immutable. - # User has to ensure that it does not change. - self._sub_translators = sub_translators + # Maps name of primitives to the associated translator. + self._primitive_translators = primitive_translators # Maps Jax variables to the name of its SDFG equivalent. - # Note that it is shared among all translation contexts. - # This is done to create consistency between SDFG variables - # and the names used pretty printed Jaxprs. + # Shared between all translation contexts, to ensure consecutive + # variable naming as seen as in a pretty printed Jaxpr. + # Will be cleared by `_clear_translation_ctx()` at the end of the translation. self._jax_name_map = {} - # Context stack and current context. - # If it is empty, then no translation process is in process. - # If there is one entry, `self` is the root translator. + # Stack of all context, to handle nested Jaxpr instances. + # The first one, i.e. index 0, is known as head translator. self._ctx_stack = [] def translate_jaxpr( @@ -99,30 +99,21 @@ def translate_jaxpr( ) -> translator.TranslatedJaxprSDFG: """Perform the translation of a Jaxpr into a SDFG. - In case this function is called and `self` has an ongoing translation process, a new translation context will be created. - This means the Jaxpr will be translated independently from the previous one. + In case this function is called and `self` has an ongoing translation process, a new + translation context will be created. This means the Jaxpr will be translated independently + from the previous one. Returns: The function will translate the passed Jaxpr object into an SDFG in canonical form. - This SDFG together with additional meta data, that is needed for further processing is encapsulated inside a `TranslatedJaxprSDFG` object. + This SDFG together with additional meta data, that is needed for further processing + is encapsulated inside a `TranslatedJaxprSDFG` object. Args: name: Use this name for the SDFG instead some generated one. """ - import jax as _jax if len(jaxpr.effects) != 0: raise NotImplementedError("'Jaxpr' with side effects are not supported.") - if not _jax.config.read("jax_enable_x64"): - # NOTE: What is interesting here is, that the SDFG can be called, but the result is garbage. - # Beside that I think it should not work, I think it should not even call, - # because of a mismatch in data types. - # However, If we work with Jax arrays themselves, it should technically work. - # But currently the best we can do, is forbid it! - raise NotImplementedError( - "You have disabled 'x64' support in Jax, which interferes with the calling of the SDFG. " - "SDFG generated in this way will fail to call." - ) # NOTE: If `self` is already allocated, i.e. has an ongoing translation process, # the `_allocate_translation_ctx()` function will start a new context. @@ -136,11 +127,8 @@ def translate_jaxpr( jaxpr=jaxpr, ) self._create_initial_input(jaxpr=jaxpr) - # Note that `self` and `jsdfg` still share the same underlying memory, i.e. context. - jsdfg: translator.TranslatedJaxprSDFG = self._translate_jaxpr_internal(jaxpr) - self._clear_translation_ctx() - return jsdfg + return self._translate_jaxpr_internal(jaxpr) def append_new_state( self, @@ -257,7 +245,7 @@ def map_jax_var_to_sdfg( def sdfg(self) -> dace.SDFG: """Returns the SDFG that is currently constructed. - If you want access to the arrays of the SDFG use `self.arrays()`/`self.get_array()`. + If you want access to the arrays of the SDFG use `self.arrays`/`self.get_array()`. """ return self._ctx.sdfg @@ -288,8 +276,8 @@ def add_jax_name_mapping( ) -> JaxprTranslationDriver: """Creates a new mapping between `jax_var` to `sdfg_name`. - If the mapping already exists an error will be generated. - This function is not able to delete a variable mapping that was established before, for this use TBA. + If the mapping already exists an error will be generated. This function is not + able to delete a variable mapping that was established before, for this use TBA. Args: jax_var: The Jax variable. @@ -299,7 +287,8 @@ def add_jax_name_mapping( if jax_var in self._jax_name_map: raise ValueError( - f"Tried to create the mapping '{jax_var} -> {sdfg_name}', but the variable is already mapped." + f"Cannot change the mapping of '{jax_var}' from" + f" '{self.map_jax_var_to_sdfg(jax_var)}' to '{sdfg_name}'." ) if sdfg_name not in self._ctx.sdfg.arrays: raise KeyError(f"Mapping '{jax_var} -> {sdfg_name}': SDFG target unknown.") @@ -320,9 +309,10 @@ def add_array( The SDFG object is always created as a transient. - By default the function will use `jace.util.propose_jax_name()` to derive the name that should be used. - However, by passing a `JaCeVar` with a name it is possible to suggest a specific name. - In addition it is possible to specify `name_prefix` to prefix name that would be used. + By default the function will use `jace.util.propose_jax_name()` to derive + the name that should be used. However, by passing a `JaCeVar` with a name it + is possible to suggest a specific name. In addition it is possible to specify + `name_prefix` to prefix name that would be used. The function will not update the internal variable mapping. If this is desired one can set `update_var_mapping`, for forcing this. @@ -333,12 +323,11 @@ def add_array( update_var_mapping: Update the internal variable mapping; by default `False`. Notes: - Currently the function will always create an Array, even if the Jax variable refers to a scalar. - This is done to work around some difficulties with scalar return values and so on. - This issue should actually handled in the post processing stage, but currently it is not. - However, from a point of building an SDFG manually, there is no difference between a Scalar and an Array. - According to the dace developer, the majority of the backend, i.e. optimization pipeline, should be handle to handle it. - But there are some special parts that might explicitly want a scalar, it also might block certain compiler optimization. + As a temporary fix for handling scalar return values, the function will always + generate arrays, even if `arg` is a scalar. + According to the dace developer, the majority of the backend, i.e. optimization + pipeline, should be handle to handle it. But there are some special parts that + might explicitly want a scalar, it also might block certain compiler optimization. """ shape: tuple[int | dace.symbol | str, ...] = util.get_jax_var_shape(arg) dtype: dace.typeclass = util.get_jax_var_dtype(arg) @@ -415,12 +404,12 @@ def create_jax_var_list( # type: ignore[misc] If no corresponding SDFG variable is known the function will create one using `add_array()`. By setting `prevent_creation` the function will not create any new SDFG variables, - if no corresponding SDFG variable exists an error is generated. - By setting `only_creation` the function will only create new SDFG variables, - if a variable already have a corresponding SDFG variable an error will be created. + if no corresponding SDFG variable exists an error is generated. By setting `only_creation` + the function will only create new SDFG variables, if a variable already have a + corresponding SDFG variable an error will be created. - By default literals cause an error. - However, by setting `handle_literals` to `True` literals will will be included in the output with the value `None`. + By default literals cause an error. However, by setting `handle_literals` to `True` + literals will will be included in the output with the value `None`. Args: jax_var_list: The list of Jax variables that should be transformed to SDFG names. @@ -518,8 +507,8 @@ def _allocate_translation_ctx( ) -> JaxprTranslationDriver: """This function allocates and initialize the members of the translation context of `self`. - If this function is called and `self` is already allocated, the function will create a new context, - allowing the driver to handle nested Jaxpr. + If this function is called and `self` is already allocated, the function will create a + new context, allowing the driver to handle nested Jaxpr. The first context that is created is also known as root translator. Args: @@ -534,10 +523,6 @@ def _allocate_translation_ctx( ) ) - if self.is_root_translator(): - # In the future we will populate the generate state here, i.e. if we are on GPU or not and so on. - assert len(self._jax_name_map) == 0 - return self @property @@ -546,15 +531,13 @@ def _ctx(self) -> translator.TranslatedJaxprSDFG: assert len(self._ctx_stack) != 0, "No context is active." return self._ctx_stack[-1] - def _clear_translation_ctx(self) -> JaxprTranslationDriver: - """This function deallocate the currently active translation context of `self`. + def _clear_translation_ctx(self) -> translator.TranslatedJaxprSDFG | None: + """Remove the current active context from `self` and returns its state. - Notes: - While it is allowed for outside code to call this function explicit it is is most likely an error. - If `self` is not allocated this function acts as a noops. + If `self` is not allocated it will return `None`. """ if not self.is_allocated(): - return self + return None if self.is_root_translator(): # The translation as a whole has finished, so restore the driver, @@ -562,8 +545,7 @@ def _clear_translation_ctx(self) -> JaxprTranslationDriver: self._jax_name_map = {} # Remove the current head stack. - _ = self._ctx_stack.pop() - return self + return self._ctx_stack.pop() def _translate_single_eqn( self, @@ -573,17 +555,13 @@ def _translate_single_eqn( To do this the function will perform the following steps: - Assemble the in and output variables. - - Select the appropriate subtranslator to use. + - Select the appropriate primitive translator to use. - Create a new empty state terminal state. - - Call the subtranslator to perform the translation inside the new state. + - Call the primitive translator to perform the translation inside the new state. Returns: The SDFG names that were used as input and output are returned. The inputs might contain `None` which indicates that that input was a Jax Literal. - - Notes: - The equation, `eqn` must come from the unclosed jaxpr instance. - The function will perform some consistency checking after the subtranslator was called. """ if len(eqn.effects) != 0: raise NotImplementedError(f"Equation '{eqn}' has side effects.") @@ -603,24 +581,22 @@ def _translate_single_eqn( update_var_mapping=True, ) - # Find the subtranslator - prim_name: str = eqn.primitive.name - if prim_name not in self._sub_translators: - raise NotImplementedError(f"No subtranslators known to handle '{prim_name}'.") - subtranslator = self._sub_translators[prim_name] + pname: str = eqn.primitive.name + if pname not in self._primitive_translators: + raise NotImplementedError(f"No translator known to handle '{pname}'.") + ptranslator = self._primitive_translators[pname] # Create the state into which the equation should be translated - last_term_state: dace.SDFGState = self._terminal_sdfg_state # noqa: F841 # Will be used later eqn_state = self.append_new_state( - label=f"{eqn.primitive.name}_{'_'.join(out_var_names)}", + label=f"{pname}_{'_'.join(out_var_names)}", prev_state=None, # forces terminal state to use ) # Now perform the actual translation of the equation. - new_sdfg_term_state = subtranslator( + new_sdfg_term_state = ptranslator( driver=self, in_var_names=in_var_names, - out_var_names=out_var_names, # Might be modified by the subtranslator! + out_var_names=out_var_names, # Might be modified by the translator! eqn=eqn, eqn_state=eqn_state, ) @@ -631,8 +607,8 @@ def _translate_single_eqn( raise RuntimeError("Inconsistent terminal state was detected.") new_sdfg_term_state = eqn_state - # In case a subtranslator decided to not use the variables we created for it, which is allowed - # but it must update the `out_var_names` list correctly, we will now verify this. + # In case a translator decided to not use the variables we created for it, which is + # allowed but it must update the `out_var_names` list correctly, we will now verify this. for expectedSDFGName, jax_var in zip(out_var_names, eqn.outvars, strict=True): mapped_sdfg_name = self.map_jax_var_to_sdfg(jax_var) if mapped_sdfg_name != expectedSDFGName: @@ -654,8 +630,9 @@ def _translate_jaxpr_internal( """Performs the actual translation of the Jaxpr into an SDFG. The function assumes that the context is allocated as well as the initial variables. - The function will return the internal state of `self` encapsulated inside a `TranslatedJaxprSDFG` object. - However, it will not deallocate the translation context, thus `self` and the return value share the same memory. + The function will return the internal state of `self` encapsulated inside a + `TranslatedJaxprSDFG` object. + The function will also deallocate the current context upon return. Args: jaxpr: The Jaxpr to translate. @@ -678,10 +655,10 @@ def _translate_jaxpr_internal( if nb_translated_eqn == 0: out_var_names = self._handle_null_jaxpr(jaxpr) - # Set the output names inside the context. + # Proper output names in the context. self._ctx.out_names = tuple(out_var_names) - return self._ctx + return cast("translator.TranslatedJaxprSDFG", self._clear_translation_ctx()) def _handle_null_jaxpr( self, @@ -689,9 +666,10 @@ def _handle_null_jaxpr( ) -> Sequence[str]: """This function is called in case a `Jaxpr` with zero equations is encountered. - A function with zero equation might still have output, in which case an input is copied to an output. - This function will handle the copying from the input into the corresponding output variable. - It is important that the function will remove the input and output variables from the internal mapping. + A function with zero equation might still have output, in which case an input is copied + to an output. This function will handle the copying from the input into the corresponding + output variable. It is important that the function will remove the variables that are used + as input and output from the mapping. Returns: The function returns a list denoting the SDFG variables that refers to the output. @@ -710,9 +688,9 @@ def _handle_null_jaxpr( # If we are here then we are dealing with a nested SDFG/Jaxpr, that has output. # Because an input also serves as output, the nested SDFG will have a connector for the - # input and one for the output, but both with the same name. - # This will make node validation fail. - # We have to work around this by introducing some fake copies, which will be removed by DaCe later. + # input and one for the output, but both with the same name. This will make node + # validation fail. We have to work around this by introducing some fake copies, which + # will be removed by DaCe later. for jax_out_var in jaxpr.jaxpr.outvars: # Since the output is also used as an input the variable mapping must be already known. sdfg_in_name: str = self.map_jax_var_to_sdfg(jax_out_var) @@ -735,11 +713,11 @@ def _handle_null_jaxpr( data=dace.Memlet.from_array(sdfg_in_name, self.get_array(sdfg_in_name)), ) - # A Jax variable now has, in some sense, two SDFG equivalent, the input, that was previously created by - # `self._create_initial_input()` and the `sdfg_out_name` we just created. - # But we can not add this to the mapping, because of this situation we will now remove the variable from the mapping all together. + # A Jax variable now has, in some sense, two SDFG equivalent, the input, that + # was previously created by `self._create_initial_input()` and the `sdfg_out_name` + # we just created. But we can not add this to the mapping, because of this situation + # we will now remove the variable from the mapping all together. # I am open for different approaches. - # Note that input variables that are not used as outputs, will remain in the mapping. self._jax_name_map.pop(jax_out_var) return tuple(out_var_names) diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/primitive_translator.py index df52f90..8e18a27 100644 --- a/src/jace/translator/primitive_translator.py +++ b/src/jace/translator/primitive_translator.py @@ -4,14 +4,7 @@ # All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause -"""Contains the interface for all primitive subtranslators. - -Note the name of this file is because it has to be the first that is imported in the `__init__.py` file. -If not, we would get a cyclic import error. -However, all attempts to prevent ruff from mindlessly (rule abiding) destroying this orders failed. -Thus the name was changed to enforce this. -If you have the solution, feel free to implement it. -""" +"""Contains the interface for all primitive translators.""" from __future__ import annotations @@ -31,9 +24,6 @@ class PrimitiveTranslatorCallable(Protocol): """Callable version of the primitive translators. Used for type annotation purposes, classes should be derived from `PrimitiveTranslator` instead. - - Todo: - - This split information `__call__()` should be documented in `PrimitiveTranslator` instead and not here. """ __slots__ = () @@ -51,33 +41,30 @@ def __call__( Before the driver calls this function it will perform the following preparatory tasks: - - It will allocate the SDFG variables that are used as outputs. - Their names will be passed through the `out_var_names` argument, - in the same order as `eqn.outvars`. - - It will collect the names of the SDFG variables that are used as input - and place them in `in_var_names`, in the same order as `eqn.invars`. - If an input argument refers to a literal no SDFG variable is created - for it and `None` is passed to indicate this. - - The subtranslator will create variables that are used as output. - They are passed as `out_var_names`, same order as in the equation. - - The driver will create a new terminal state and pass it as - `eqn_state` argument. This state is guaranteed to be empty and - `translator.terminal_sdfg_state is eqn_state` holds. - - Then the subtranslator is called. - Usually a subtranslator should construct the dataflow graph inside `eqn_state`. - It is allowed that the subtranslators creates more states if needed, but this state machinery - has to have a single terminal state, which must be returned and reachable from `eqn_state`. - If the function returns `None` the driver will assume that subtranslator was able to - fully construct the dataflow graph within `eqn_state`. - - While a subtranslator is forbidden from meddling with the input variables mentioned in - `in_var_names` in any way, it is allowed to modify the output variables. - For example it could create a new SDFG variable, with different strides. - But in that case the subtranslator must update the internal mapping of the driver TBA HOW, - and modify the mapping specified by `out_var_names`. - However, the subtranslator is allowed to create internal temporary variables. - It just have to ensure that no name collision will occur, a way to do this is to use a passed variable name as prefix. + - It will allocate the SDFG variables that are used as outputs. Their names will be passed + through the `out_var_names` argument, in the same order as `eqn.outvars`. + - It will collect the names of the SDFG variables that are used as input and place them in + `in_var_names`, in the same order as `eqn.invars`. If an input argument refers to a + literal no SDFG variable is created for it and `None` is passed to indicate this. + - The driver will create variables that are used as output. They are passed as + `out_var_names`, same order as in the equation. + - The driver will create a new terminal state and pass it as `eqn_state` argument. This + state is guaranteed to be empty and `translator.terminal_sdfg_state is eqn_state` holds. + + Then the primitive translator is called. + Usually a primitive translator should construct the dataflow graph inside `eqn_state`. + It is allowed that the primitive translators creates more states if needed, but this + state machinery has to have a single terminal state, which must be returned and reachable + from `eqn_state`. If the function returns `None` the driver will assume that primitive + translator was able to fully construct the dataflow graph within `eqn_state`. + + While a primitive translator is forbidden from meddling with the input variables mentioned + in `in_var_names` in any way, it is allowed to modify the output variables. For example + it could create a new SDFG variable, with different strides. But in that case the primitive + translator must update the internal mapping of the driver TBA HOW, and modify the mapping + specified by `out_var_names`. However, the subtranslator is allowed to create internal + temporary variables. It just have to ensure that no name collision will occur, a way to + do this is to use a passed variable name as prefix. Args: driver: The driver object of the translation. @@ -94,16 +81,17 @@ def __call__( @runtime_checkable class PrimitiveTranslator(PrimitiveTranslatorCallable, Protocol): - """Interface for all Jax primitive translators, also known as subtranslator, that are implemented as class. + """Interface for all Jax primitive translators. A translator for a primitive translates a single equation of a Jaxpr into its SDFG equivalent. For satisfying this interface a concrete implementation must be immutable after construction. - Subtranslators are simple, but highly specialized objects that are only able to perform the translation of a single primitive. - The overall translation process itself is managed by a driver object, which also owns and manage the subtranslators. - In the end this implements the delegation pattern. + Primitive translators are simple, but highly specialized objects that are only able to perform + the translation of a single primitive. The overall translation process itself is managed by a + driver object, which also owns and manage the primitive translators. In the end this implements + the delegation pattern. - You can use `jace.translator.add_subtranslator()` to register your translator to Jace. + You can use `jace.translator.register_primitive_translator()` to register your translator to Jace. """ __slots__ = () diff --git a/tests/test_jaxpr_translator_driver.py b/tests/test_jaxpr_translator_driver.py index 96a7419..c896f8a 100644 --- a/tests/test_jaxpr_translator_driver.py +++ b/tests/test_jaxpr_translator_driver.py @@ -42,7 +42,7 @@ def translation_driver(): """Returns an allocated driver instance.""" name = "fixture_driver" driver = translator.JaxprTranslationDriver( - sub_translators=translator.get_regsitered_primitive_translators() + primitive_translators=translator.get_regsitered_primitive_translators() ) driver._allocate_translation_ctx(name=name) return driver @@ -54,7 +54,7 @@ def test_driver_alloc() -> None: Does not use the fixture because it does it on its own. """ driver = translator.JaxprTranslationDriver( - sub_translators=translator.get_regsitered_primitive_translators() + primitive_translators=translator.get_regsitered_primitive_translators() ) assert not driver.is_allocated(), "Driver was created allocated." assert len(driver._ctx_stack) == 0 @@ -221,7 +221,7 @@ def test_driver_nested(translation_driver: translator.JaxprTranslationDriver) -> with pytest.raises( expected_exception=ValueError, match=re.escape( - f"Tried to create the mapping '{array1} -> {name_1}', but the variable is already mapped." + f"Cannot change the mapping of '{array1}' from '{name_1}' to '{name_1}'." ), ): _ = translation_driver.add_array(array1, update_var_mapping=True) @@ -307,7 +307,7 @@ def test_driver_variable_multiple_variables( with pytest.raises( expected_exception=ValueError, match=re.escape( - f"Tried to create the mapping '{array1} -> {prefix_expected_name}', but the variable is already mapped." + f"Cannot change the mapping of '{array1}' from '{translation_driver.map_jax_var_to_sdfg(array1)}' to '{prefix_expected_name}'." ), ): _ = translation_driver.add_array(array1, update_var_mapping=True, name_prefix=prefix)