From d2bb87bc690a5f02e2e8038c57315c3bca29f1ee Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 19 Apr 2024 09:59:04 +0200 Subject: [PATCH] Fixed an issue. --- .../jace_subtranslator_interface.py | 2 +- .../translator/jaxpr_translator_driver.py | 23 +++++++++---------- .../sub_translators/alu_translator.py | 3 +-- src/jace/util/traits.py | 9 ++------ 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/jace/translator/jace_subtranslator_interface.py b/src/jace/translator/jace_subtranslator_interface.py index 3d6f53b..cf9420c 100644 --- a/src/jace/translator/jace_subtranslator_interface.py +++ b/src/jace/translator/jace_subtranslator_interface.py @@ -57,7 +57,7 @@ def __init__( self, *args, **kwargs, - ): + ) -> None: """Initialize the interface. It is required that subclasses calls this method during initialization. diff --git a/src/jace/translator/jaxpr_translator_driver.py b/src/jace/translator/jaxpr_translator_driver.py index d7e9cdd..c3c266f 100644 --- a/src/jace/translator/jaxpr_translator_driver.py +++ b/src/jace/translator/jaxpr_translator_driver.py @@ -389,15 +389,14 @@ def is_allocated(self) -> bool: small_ctx: Sequence[Any] = [ getattr(self, x) for x in self.__shared_slots__ if x != "_reserved_names" ] - if all((x is None) for x in small_ctx): + if all((x is not None) for x in small_ctx): if self._reserved_names is None: raise RuntimeError( "Invalid allocation state: All context variables except the reserved name list are allocated." ) return True - elif all((x is not None) for x in small_ctx): + elif all((x is None) for x in small_ctx): return False - raise RuntimeError("Invalid allocation state: Translation context is mixed allocated.") def is_head_translator(self) -> bool: @@ -416,7 +415,7 @@ def same_family( They belong to the same family if they descend from the same head translator. """ if not isinstance(other, JaxprTranslationDriver): - return NotImplemented + return NotImplemented # type: ignore[unreachable] if all(getattr(self, x) is getattr(self, x) for x in self.__shared_slots__): assert (self if (self._rev_idx < other._rev_idx) else other).is_allocated() return True @@ -597,7 +596,8 @@ def _add_array( ) alt_name = jutil.get_jax_var_name(arg) if name_prefix is not None: - assert isinstance(name_prefix, str) and (len(name_prefix) > 0) + assert isinstance(name_prefix, str) + assert len(name_prefix) > 0 if alt_name is not None: raise ValueError("Specified 'name_prefix' and 'alt_name' which is not possible.") @@ -879,12 +879,10 @@ def _allocate_translation_ctx( # Handle the `reserved_names` argument as described above. # This is essentially needed that children works properly. if self._reserved_names is None: - self._reserved_names = set() - elif isinstance(self._reserved_names, set): - assert not self.is_head_translator() - assert all(isinstance(x, str) for x in self._reserved_names) + self._reserved_names = set() # type: ignore[unreachable] else: raise RuntimeError("The reserved names are allocated incorrectly.") + assert all(isinstance(x, str) for x in self._reserved_names) # type: ignore[unreachable] self._add_reserved_names(reserved_names) return self @@ -900,7 +898,7 @@ def _init_sub_translators( """ if isinstance(self._sub_translators, dict): raise RuntimeError("Tried to allocate the internal subtranslators twice.") - assert self._sub_translators is None + assert self._sub_translators is None # type: ignore[unreachable] # We might get arguments that starts with an underscore, which are not meant for the subtranslators. kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} @@ -1020,7 +1018,8 @@ def _translate_single_eqn( While `jaxpr` must be the closed version, `eqn` must come from the unclosed version. The function will also perform some consistency checking. """ - assert isinstance(eqn, jcore.JaxprEqn) and isinstance(jaxpr, jcore.ClosedJaxpr) + assert isinstance(eqn, jcore.JaxprEqn) + assert isinstance(jaxpr, jcore.ClosedJaxpr) if len(eqn.effects) != 0: raise NotImplementedError(f"Equation '{eqn}' had side effects.") @@ -1191,7 +1190,7 @@ def _handle_null_jaxpr( self.map_jax_var_to_sdfg(jax_out_var) for jax_out_var in jaxpr.jaxpr.outvars ) raise NotImplementedError("Please test me.") - return self + return self # type: ignore[unreachable] # reminder # assert self._term_sdfg_state is self._init_sdfg_state assert len(self._sdfg_in_names) > 0 diff --git a/src/jace/translator/sub_translators/alu_translator.py b/src/jace/translator/sub_translators/alu_translator.py index 5594497..c088fa9 100644 --- a/src/jace/translator/sub_translators/alu_translator.py +++ b/src/jace/translator/sub_translators/alu_translator.py @@ -156,8 +156,7 @@ def translate_jaxeqn( has_some_literals = any(x is None for x in in_var_names) only_literals_as_inputs = all(x is None for x in in_var_names) inps_same_shape = all( - eqn.invars[0].aval.shape == eqn.invars[i].aval.shape - for i in range(1, len(eqn.invars)) + eqn.invars[0].aval.shape == eqn.invars[i].aval.shape for i in range(1, len(eqn.invars)) ) # We will now look which dimensions have to be broadcasted on which operator. diff --git a/src/jace/util/traits.py b/src/jace/util/traits.py index ccdf19a..9d87803 100644 --- a/src/jace/util/traits.py +++ b/src/jace/util/traits.py @@ -25,21 +25,16 @@ def is_str( """ if len(args) == 0: return False - elif allow_empty: for x in args: if not isinstance(x, str): return False # Not a string - # end for(x): else: for x in args: if not isinstance(x, str): - return False # Not a string + return False if len(x) == 0: - return False # A string but empty; and check enabled - # end for(x): - # end if: - + return False return True