Implement multi-output Elemwise in Numba via guvectorize #1271

Draft: wants to merge 2 commits into base: main (changes shown from 1 commit)
52 changes: 51 additions & 1 deletion aesara/link/numba/dispatch/elemwise.py
@@ -162,6 +162,53 @@ def create_vectorize_func(
return elemwise_fn


def create_guvectorize_func(
scalar_op_fn: Callable,
node: Apply,
identity: Optional[Any] = None,
**kwargs,
) -> Callable:
r"""Create a guvectorized Numba function from a `Apply`\s Python function."""

signature_ = create_numba_signature(node, force_scalar=False)
signature = [(*signature_.args, *signature_.return_type.types)]

target = (
getattr(node.tag, "numba__vectorize_target", None)
or config.numba__vectorize_target
)

layout = f"{','.join(('()',) * len(node.inputs))}->{','.join(('()',) * len(node.outputs))}"
print(f"{signature=}, {layout=}")
numba_guvectorized_fn = numba.guvectorize(
signature,
layout,
identity=identity,
target=target,
fastmath=config.numba__fastmath,
)

input_names = [f"i{i}" for i in range(len(node.inputs))]
output_names = [f"o{i}" for i in range(len(node.outputs))]
gu_fn_name = "gu_func"
Contributor Author:
I'm not familiar with the auto-naming strategy we have going on with Numba; are there any developer docs I can use as a reference?

Member:
We should only need to make sure that the names we generate are fixed and don't clobber each other. As long as you're generating all the names yourself, everything should be fine; it's usually when you're using unknown names provided by something else that problems start to arise.

On a very related note, if we want a certain type of caching to work (e.g. the kind that's based on hashes of source code), we'll need to clean up some old code that uses Variable.name, and anything else that could differ between equivalent graphs. Since most of the unique name-based code was used to avoid Variable.name issues, we can probably drop all of it now. In summary, unique naming might be useful for debugging and readability, but it's not necessary and it can negatively affect caching, so don't worry about it.
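As an illustration (not from the PR): source-hash caching only works if equivalent graphs produce byte-identical generated source, which is why generated names should be deterministic rather than derived from things like Variable.name.

import hashlib

def cache_key(src: str) -> str:
    # Two equivalent graphs must generate the same source to share a cache entry.
    return hashlib.sha256(src.encode()).hexdigest()

src_a = "def gu_func(i0, o0):\n    o0[...] = scalar_op_fn(i0)\n"
src_b = "def gu_func(i0, o0):\n    o0[...] = scalar_op_fn(i0)\n"
assert cache_key(src_a) == cache_key(src_b)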


gu_fn_src = f"""
def {gu_fn_name}({', '.join(input_names)}, {', '.join(output_names)}):
for i in range({input_names[0]}.shape[0]):
{'[i], '.join(output_names)}[i] = scalar_op_fn({'[i], '.join(input_names)}[i])
"""
print(gu_fn_src)
Contributor Author:
This basically creates a function that looks like:

def gu_func(i0, i1, ..., iN, o0, o1, ..., oN):
  for i in range(i0.shape[0]):
    o0[i], o1[i], ..., oN[i] = scalar_op_fn(i0[i], i1[i], ..., iN[i])

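For reference (not part of the diff), here is how the '[i], ' joins expand for a node with two inputs and two outputs, using the generated i0/i1/o0/o1 names above:

input_names = ["i0", "i1"]
output_names = ["o0", "o1"]
lhs = "[i], ".join(output_names) + "[i]"  # "o0[i], o1[i]"
rhs = "[i], ".join(input_names) + "[i]"   # "i0[i], i1[i]"
print(f"{lhs} = scalar_op_fn({rhs})")     # o0[i], o1[i] = scalar_op_fn(i0[i], i1[i])
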
Contributor Author (@ricardoV94, Oct 19, 2022):
This also only seems to work for vector inputs. Am I supposed to do a nested loop for every dimension, or is there a shortcut/helper I can use?
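Not part of the PR, but one dimension-agnostic alternative to per-dimension nested loops would be np.ndenumerate, which numba supports in nopython mode. A minimal sketch, using exp/log as stand-ins for scalar_op_fn:

import numba
import numpy as np

@numba.njit
def fused_exp_log(i0, o0, o1):
    # Visit every element once, whatever the number of dimensions.
    for idx, v in np.ndenumerate(i0):
        o0[idx] = np.exp(v)
        o1[idx] = np.log(v)

x = np.random.rand(2, 3) + 1.0
out_exp, out_log = np.empty_like(x), np.empty_like(x)
fused_exp_log(x, out_exp, out_log)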

Contributor Author:
This is outdated now! I shouldn't need the loop at all


gu_fn_inner = compile_function_src(
gu_fn_src, gu_fn_name, {"scalar_op_fn": scalar_op_fn, **globals()}
)

gu_fn = numba_guvectorized_fn(gu_fn_inner)
# gu_fn.py_scalar_func = py_scalar_func

return gu_fn


def create_axis_reducer(
scalar_op: Op,
identity: Union[np.ndarray, Number],
@@ -426,7 +473,10 @@ def axis_apply_fn(x):
def numba_funcify_Elemwise(op, node, **kwargs):

scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
if len(node.outputs) == 1:
elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
else:
elemwise_fn = create_guvectorize_func(scalar_op_fn, node)
elemwise_fn_name = elemwise_fn.__name__

if op.inplace_pattern:
1 change: 1 addition & 0 deletions aesara/link/numba/linker.py
@@ -27,6 +27,7 @@ def fgraph_convert(self, fgraph, **kwargs):
return numba_funcify(fgraph, **kwargs)

def jit_compile(self, fn):
return fn
Contributor Author:
My test errors out when the jitting of the whole function is attempted:

E           numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
E           Untyped global name 'gu_func': Cannot determine Numba type of <class 'numba.np.ufunc.gufunc.GUFunc'>
E           
E           File "../../../../../../tmp/tmphopvpths", line 3:
E           def numba_funcified_fgraph(tensor_variable):
E               <source elided>
E                   # Elemwise{Composite{exp(i0), log(i0)}}(<TensorType(float64, (None,))>)
E               tensor_variable_1, tensor_variable_2 = gu_func(tensor_variable)

Contributor Author:
Probably related to numba/numba#2089
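
A minimal reproduction sketch of the failure above (an assumption, not code from the PR): on numba versions that cannot type GUFunc objects in nopython mode, calling a guvectorize-compiled function from an njit-compiled function raises the same TypingError.

import numba
import numpy as np

@numba.guvectorize(["void(float64[:], float64[:])"], "(n)->(n)")
def gu_exp(x, out):
    for i in range(x.shape[0]):
        out[i] = np.exp(x[i])

@numba.njit
def call_it(x):
    # Untyped global name 'gu_exp': Cannot determine Numba type of GUFunc
    return gu_exp(x)

call_it(np.arange(1.0, 4.0))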

import numba

jitted_fn = numba.njit(fn)
2 changes: 1 addition & 1 deletion tests/link/numba/test_basic.py
@@ -233,7 +233,7 @@ def assert_fn(x, y):
numba_res = aesara_numba_fn(*inputs)

# Get some coverage
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
# eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
Contributor Author:
This fails for reasons I haven't explored. It seems like it may need special logic for the multi-output Elemwises?


if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res):
17 changes: 17 additions & 0 deletions tests/link/numba/test_elemwise.py
@@ -3,6 +3,7 @@
import numpy as np
import pytest

import aesara.scalar as aes
import aesara.tensor as at
import aesara.tensor.inplace as ati
import aesara.tensor.math as aem
@@ -13,6 +14,7 @@
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from tests.link.numba.test_basic import (
compare_numba_and_py,
@@ -111,6 +113,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)


def test_multioutput_elemwise():
scalar_inp = aes.float64()
scalar_out1 = aes.exp(scalar_inp)
scalar_out2 = aes.log(scalar_inp)
scalar_composite = aes.Composite([scalar_inp], [scalar_out1, scalar_out2])

tensor_inp = at.dvector()
tensor_outs = Elemwise(scalar_composite)(tensor_inp)

out_fg = FunctionGraph([tensor_inp], tensor_outs)

print("")
compare_numba_and_py(out_fg, [np.r_[1.0, 2.0, 3.5]])


@pytest.mark.parametrize(
"v, new_order",
[