Skip to content

Commit

Permalink
Some changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Apr 19, 2024
1 parent 93cf776 commit 16f9adf
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
4 changes: 1 addition & 3 deletions src/jace/translator/jaxpr_translator_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,9 +597,7 @@ def _add_array(
)
alt_name = jutil.get_jax_var_name(arg)
if name_prefix is not None:
assert isinstance(name_prefix, str) and (
len(name_prefix) > 0
), f"Invalid 'name_prefix': '{name_prefix}'."
assert isinstance(name_prefix, str) and (len(name_prefix) > 0)
if alt_name is not None:
raise ValueError("Specified 'name_prefix' and 'alt_name' which is not possible.")

Expand Down
6 changes: 2 additions & 4 deletions src/jace/translator/sub_translators/alu_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,11 @@ def translate_jaxeqn(
inp_scalars = [len(Inp.aval.shape) == 0 for i, Inp in enumerate(eqn.invars)]
has_scalars_as_inputs = any(inp_scalars)
only_scalars_as_inputs = all(inp_scalars)
has_some_literals = any([x is None for x in in_var_names])
only_literals_as_inputs = all([x is None for x in in_var_names])
has_some_literals = any(x is None for x in in_var_names)
only_literals_as_inputs = all(x is None for x in in_var_names)
inps_same_shape = all(
[
eqn.invars[0].aval.shape == eqn.invars[i].aval.shape
for i in range(1, len(eqn.invars))
]
)

# We will now look which dimensions have to be broadcasted on which operator.
Expand Down

0 comments on commit 16f9adf

Please sign in to comment.