Skip to content

Commit

Permalink
Fixed an issue.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Apr 19, 2024
1 parent 16f9adf commit d2bb87b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/jace/translator/jace_subtranslator_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self,
*args,
**kwargs,
):
) -> None:
"""Initialize the interface.
It is required that subclasses calls this method during initialization.
Expand Down
23 changes: 11 additions & 12 deletions src/jace/translator/jaxpr_translator_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,15 +389,14 @@ def is_allocated(self) -> bool:
small_ctx: Sequence[Any] = [
getattr(self, x) for x in self.__shared_slots__ if x != "_reserved_names"
]
if all((x is None) for x in small_ctx):
if all((x is not None) for x in small_ctx):
if self._reserved_names is None:
raise RuntimeError(
"Invalid allocation state: All context variables except the reserved name list are allocated."
)
return True
elif all((x is not None) for x in small_ctx):
elif all((x is None) for x in small_ctx):
return False

raise RuntimeError("Invalid allocation state: Translation context is mixed allocated.")

def is_head_translator(self) -> bool:
Expand All @@ -416,7 +415,7 @@ def same_family(
They belong to the same family if they descend from the same head translator.
"""
if not isinstance(other, JaxprTranslationDriver):
return NotImplemented
return NotImplemented # type: ignore[unreachable]
if all(getattr(self, x) is getattr(self, x) for x in self.__shared_slots__):
assert (self if (self._rev_idx < other._rev_idx) else other).is_allocated()
return True
Expand Down Expand Up @@ -597,7 +596,8 @@ def _add_array(
)
alt_name = jutil.get_jax_var_name(arg)
if name_prefix is not None:
assert isinstance(name_prefix, str) and (len(name_prefix) > 0)
assert isinstance(name_prefix, str)
assert len(name_prefix) > 0
if alt_name is not None:
raise ValueError("Specified 'name_prefix' and 'alt_name' which is not possible.")

Expand Down Expand Up @@ -879,12 +879,10 @@ def _allocate_translation_ctx(
# Handle the `reserved_names` argument as described above.
# This is essentially needed that children works properly.
if self._reserved_names is None:
self._reserved_names = set()
elif isinstance(self._reserved_names, set):
assert not self.is_head_translator()
assert all(isinstance(x, str) for x in self._reserved_names)
self._reserved_names = set() # type: ignore[unreachable]
else:
raise RuntimeError("The reserved names are allocated incorrectly.")
assert all(isinstance(x, str) for x in self._reserved_names) # type: ignore[unreachable]
self._add_reserved_names(reserved_names)

return self
Expand All @@ -900,7 +898,7 @@ def _init_sub_translators(
"""
if isinstance(self._sub_translators, dict):
raise RuntimeError("Tried to allocate the internal subtranslators twice.")
assert self._sub_translators is None
assert self._sub_translators is None # type: ignore[unreachable]

# We might get arguments that starts with an underscore, which are not meant for the subtranslators.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
Expand Down Expand Up @@ -1020,7 +1018,8 @@ def _translate_single_eqn(
While `jaxpr` must be the closed version, `eqn` must come from the unclosed version.
The function will also perform some consistency checking.
"""
assert isinstance(eqn, jcore.JaxprEqn) and isinstance(jaxpr, jcore.ClosedJaxpr)
assert isinstance(eqn, jcore.JaxprEqn)
assert isinstance(jaxpr, jcore.ClosedJaxpr)

if len(eqn.effects) != 0:
raise NotImplementedError(f"Equation '{eqn}' had side effects.")
Expand Down Expand Up @@ -1191,7 +1190,7 @@ def _handle_null_jaxpr(
self.map_jax_var_to_sdfg(jax_out_var) for jax_out_var in jaxpr.jaxpr.outvars
)
raise NotImplementedError("Please test me.")
return self
return self # type: ignore[unreachable] # reminder
#
assert self._term_sdfg_state is self._init_sdfg_state
assert len(self._sdfg_in_names) > 0
Expand Down
3 changes: 1 addition & 2 deletions src/jace/translator/sub_translators/alu_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,7 @@ def translate_jaxeqn(
has_some_literals = any(x is None for x in in_var_names)
only_literals_as_inputs = all(x is None for x in in_var_names)
inps_same_shape = all(
eqn.invars[0].aval.shape == eqn.invars[i].aval.shape
for i in range(1, len(eqn.invars))
eqn.invars[0].aval.shape == eqn.invars[i].aval.shape for i in range(1, len(eqn.invars))
)

# We will now look which dimensions have to be broadcasted on which operator.
Expand Down
9 changes: 2 additions & 7 deletions src/jace/util/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,16 @@ def is_str(
"""
if len(args) == 0:
return False

elif allow_empty:
for x in args:
if not isinstance(x, str):
return False # Not a string
# end for(x):
else:
for x in args:
if not isinstance(x, str):
return False # Not a string
return False
if len(x) == 0:
return False # A string but empty; and check enabled
# end for(x):
# end if:

return False
return True


Expand Down

0 comments on commit d2bb87b

Please sign in to comment.