Implement multi-output Elemwise in Numba via guvectorize #1271
@@ -162,6 +162,53 @@ def create_vectorize_func(
     return elemwise_fn


+def create_guvectorize_func(
+    scalar_op_fn: Callable,
+    node: Apply,
+    identity: Optional[Any] = None,
+    **kwargs,
+) -> Callable:
+    r"""Create a guvectorized Numba function from a `Apply`\s Python function."""
+
+    signature_ = create_numba_signature(node, force_scalar=False)
+    signature = [(*signature_.args, *signature_.return_type.types)]
+
+    target = (
+        getattr(node.tag, "numba__vectorize_target", None)
+        or config.numba__vectorize_target
+    )
+
+    layout = f"{','.join(('()',) * len(node.inputs))}->{','.join(('()',) * len(node.outputs))}"
+    print(f"{signature=}, {layout=}")
+    numba_guvectorized_fn = numba.guvectorize(
+        signature,
+        layout,
+        identity=identity,
+        target=target,
+        fastmath=config.numba__fastmath,
+    )
+
+    input_names = [f"i{i}" for i in range(len(node.inputs))]
+    output_names = [f"o{i}" for i in range(len(node.outputs))]
+    gu_fn_name = "gu_func"
+
+    gu_fn_src = f"""
+def {gu_fn_name}({', '.join(input_names)}, {', '.join(output_names)}):
+    for i in range({input_names[0]}.shape[0]):
+        {'[i], '.join(output_names)}[i] = scalar_op_fn({'[i], '.join(input_names)}[i])
+    """
+    print(gu_fn_src)
This basically creates a function that looks like:

    def gu_func(i0, i1, ..., iN, o0, o1, ..., oN):
        for i in range(i0.shape[0]):
            o0[i], o1[i], ..., oN[i] = scalar_op_fn(i0[i], i1[i], ..., iN[i])

This also only seems to work for vector inputs. Am I supposed to do a nested loop for every dimension, or is there a shortcut/helper I can use?

This is outdated now! I shouldn't need the loop at all.
+
+    gu_fn_inner = compile_function_src(
+        gu_fn_src, gu_fn_name, {"scalar_op_fn": scalar_op_fn, **globals()}
+    )
+
+    gu_fn = numba_guvectorized_fn(gu_fn_inner)
+    # gu_fn.py_scalar_func = py_scalar_func
+
+    return gu_fn
+
+
 def create_axis_reducer(
     scalar_op: Op,
     identity: Union[np.ndarray, Number],
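The review thread in this hunk describes the shape of the generated function. As a rough illustration of that code-generation step, here is a minimal pure-Python sketch. It is not the PR's implementation: `exec` stands in for `compile_function_src`, plain lists stand in for arrays, and `build_layout`, `build_gu_src`, and the two-output `scalar_op_fn` are hypothetical names chosen for this example. It does not call Numba, so it says nothing about how `guvectorize` passes scalar arguments.

```python
def build_layout(n_inputs, n_outputs):
    """Build a gufunc layout string with scalar core dimensions."""
    ins = ",".join(("()",) * n_inputs)
    outs = ",".join(("()",) * n_outputs)
    return f"{ins}->{outs}"


def build_gu_src(n_inputs, n_outputs, fn_name="gu_func"):
    """Generate source for a loop that writes every output element-wise."""
    input_names = [f"i{k}" for k in range(n_inputs)]
    output_names = [f"o{k}" for k in range(n_outputs)]
    args = ", ".join(input_names + output_names)
    lhs = ", ".join(f"{name}[i]" for name in output_names)
    rhs = ", ".join(f"{name}[i]" for name in input_names)
    return (
        f"def {fn_name}({args}):\n"
        f"    for i in range(len({input_names[0]})):\n"
        f"        {lhs} = scalar_op_fn({rhs})\n"
    )


def scalar_op_fn(x, y):
    # Hypothetical scalar op with two outputs.
    return x + y, x - y


src = build_gu_src(2, 2)
namespace = {"scalar_op_fn": scalar_op_fn}
exec(src, namespace)  # stand-in for compile_function_src
gu_func = namespace["gu_func"]

# Outputs are preallocated and filled in place, as in a gufunc kernel.
o0, o1 = [0, 0, 0], [0, 0, 0]
gu_func([1, 2, 3], [10, 20, 30], o0, o1)
```

Because every name in the generated source is derived from a position (`i0`, `i1`, `o0`, ...), the text of `src` depends only on the number of inputs and outputs.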
@@ -426,7 +473,10 @@ def axis_apply_fn(x):
 def numba_funcify_Elemwise(op, node, **kwargs):

     scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
-    elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
+    if len(node.outputs) == 1:
+        elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
+    else:
+        elemwise_fn = create_guvectorize_func(scalar_op_fn, node)
     elemwise_fn_name = elemwise_fn.__name__

     if op.inplace_pattern:
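The branch in this hunk chooses a code path from the number of outputs of the `Apply` node. A small sketch of that dispatch follows, assuming `node.outputs` is a Python list; `FakeNode` and `pick_builder` are hypothetical stand-ins for this example, not Aesara API. The count has to be taken with `len`, since comparing a list to an integer is always `False` and would send every node down the multi-output path.

```python
from dataclasses import dataclass, field


@dataclass
class FakeNode:
    """Hypothetical stand-in for an Apply node with list-valued inputs/outputs."""

    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)


def pick_builder(node):
    """Return which builder a node with this many outputs would use."""
    if len(node.outputs) == 1:
        return "vectorize"
    return "guvectorize"


single = FakeNode(inputs=["x"], outputs=["out"])
multi = FakeNode(inputs=["x", "y"], outputs=["o0", "o1"])

single_choice = pick_builder(single)
multi_choice = pick_builder(multi)

# A list never compares equal to an int, even when it has one element.
list_equals_int = single.outputs == 1
```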
@@ -27,6 +27,7 @@ def fgraph_convert(self, fgraph, **kwargs):
         return numba_funcify(fgraph, **kwargs)

     def jit_compile(self, fn):
+        return fn
My test errors out when the jitting of the whole function is attempted:

Probably related to numba/numba#2089.
         import numba

         jitted_fn = numba.njit(fn)
@@ -233,7 +233,7 @@ def assert_fn(x, y):
     numba_res = aesara_numba_fn(*inputs)

     # Get some coverage
-    eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
+    # eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
This fails for reasons I haven't explored. It seems like it may need special logic for the multi-output Elemwises?

     if len(fn_outputs) > 1:
         for j, p in zip(numba_res, py_res):
I'm not familiar with the auto-naming strategy we use with Numba. Are there any developer docs I can use as a reference?
We should only need to make sure that the names we generate are fixed and don't clobber each other. As long as you're generating all the names yourself, everything should be fine; problems usually start when you're using unknown names provided by something else.

On a closely related note, if we want certain types of caching to work (e.g. the kind that's based on hashes of source code), we'll need to clean up some old code that uses `Variable.name`, and anything else that could differ between equivalent graphs. Since most of the unique name-based code was used to avoid `Variable.name` issues, we can probably drop all of it now. In summary, it might be useful for debugging and readability, but it's not necessary and it can negatively affect caching, so don't worry about it.
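To make the caching point concrete, here is a small sketch of why fixed, position-based names matter. It assumes a cache keyed on a SHA-256 hash of the generated source; that key scheme and the helpers `make_src`, `source_key`, and `positional_names` are assumptions for illustration, not existing Aesara functions.

```python
import hashlib


def make_src(arg_names):
    """Generate source for a function that sums its arguments."""
    args = ", ".join(arg_names)
    body = " + ".join(arg_names)
    return f"def fn({args}):\n    return {body}\n"


def source_key(src):
    """Hypothetical cache key: a hash of the generated source text."""
    return hashlib.sha256(src.encode("utf-8")).hexdigest()


def positional_names(n):
    """Names that depend only on argument position."""
    return [f"i{k}" for k in range(n)]


# Two equivalent graphs whose variables carry different user-given labels.
graph_a_labels = ["x", "y"]
graph_b_labels = ["a", "b"]

key_a_fixed = source_key(make_src(positional_names(len(graph_a_labels))))
key_b_fixed = source_key(make_src(positional_names(len(graph_b_labels))))

key_a_label = source_key(make_src(graph_a_labels))
key_b_label = source_key(make_src(graph_b_labels))
```

With positional names both graphs produce identical source and share one key; with label-derived names the keys differ even though the computation is the same.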