Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes a bug when assigning a function to an intermediate variable #327

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,9 @@ def emit_Name(adj, node):
# lookup symbol, if it has already been assigned to a variable then return the existing mapping
if node.id in adj.symbols:
return adj.symbols[node.id]
# Check if the node has a warp_func attribute
if hasattr(node, "warp_func"):
return node.warp_func

obj = adj.resolve_external_reference(node.id)

Expand Down Expand Up @@ -2323,6 +2326,18 @@ def emit_Assign(adj, node):
raise WarpCodegenError(
"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
)
elif isinstance(lhs, ast.Name):
# symbol name
name = lhs.id

# evaluate rhs
rhs = adj.eval(node.value)

# Check if rhs is a function object
if isinstance(rhs, warp.context.Function):
# Assign the function directly to the symbol table
adj.symbols[name] = rhs
return

# handle the case where we are assigning multiple output variables
if isinstance(lhs, ast.Tuple):
Expand Down
22 changes: 21 additions & 1 deletion warp/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,26 @@ def test_while_condition_eval():
it.valid = False


def test_function_assignment(test, device):
@wp.func
def multiply_by_two(a: float):
return a * 2.0

@wp.kernel
def kernel_function_assignment(input: wp.array(dtype=wp.float32), output: wp.array(dtype=wp.float32)):
tid = wp.tid()
func = multiply_by_two # Assign function to variable
output[tid] = func(input[tid]) # Call function through variable

input_data = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, device=device)
output_data = wp.empty_like(input_data)
wp.launch(
kernel_function_assignment, dim=input_data.size, inputs=[input_data], outputs=[output_data], device=device
)
expected_output = np.array([2.0, 4.0, 6.0], dtype=np.float32)
assert_np_equal(output_data.numpy(), expected_output)


class TestCodeGen(unittest.TestCase):
pass

Expand Down Expand Up @@ -719,7 +739,7 @@ class TestCodeGen(unittest.TestCase):
name="test_error_mutating_constant_in_dynamic_loop",
devices=devices,
)

add_function_test(TestCodeGen, func=test_function_assignment, name="test_function_assignment", devices=devices)
add_kernel_test(TestCodeGen, name="test_call_syntax", kernel=test_call_syntax, dim=1, devices=devices)
add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices)
add_kernel_test(TestCodeGen, name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1, devices=devices)
Expand Down