diff --git a/warp/codegen.py b/warp/codegen.py index 7db6ecc5a..7e0daefb8 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -1659,6 +1659,9 @@ def emit_Name(adj, node): # lookup symbol, if it has already been assigned to a variable then return the existing mapping if node.id in adj.symbols: return adj.symbols[node.id] + # Check if the node has a warp_func attribute + if hasattr(node, "warp_func"): + return node.warp_func obj = adj.resolve_external_reference(node.id) @@ -2323,6 +2326,18 @@ def emit_Assign(adj, node): raise WarpCodegenError( "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead." ) + elif isinstance(lhs, ast.Name): + # symbol name + name = lhs.id + + # evaluate rhs + rhs = adj.eval(node.value) + + # Check if rhs is a function object + if isinstance(rhs, warp.context.Function): + # Assign the function directly to the symbol table + adj.symbols[name] = rhs + return # handle the case where we are assigning multiple output variables if isinstance(lhs, ast.Tuple): diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index 6beb4d542..9f9b16080 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -577,6 +577,26 @@ def test_while_condition_eval(): it.valid = False +def test_function_assignment(test, device): + @wp.func + def multiply_by_two(a: float): + return a * 2.0 + + @wp.kernel + def kernel_function_assignment(input: wp.array(dtype=wp.float32), output: wp.array(dtype=wp.float32)): + tid = wp.tid() + func = multiply_by_two # Assign function to variable + output[tid] = func(input[tid]) # Call function through variable + + input_data = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, device=device) + output_data = wp.empty_like(input_data) + wp.launch( + kernel_function_assignment, dim=input_data.size, inputs=[input_data], outputs=[output_data], device=device + ) + expected_output = np.array([2.0, 4.0, 6.0], dtype=np.float32) + assert_np_equal(output_data.numpy(), expected_output) + + class TestCodeGen(unittest.TestCase): pass @@ -719,7 +739,7 @@ class TestCodeGen(unittest.TestCase): name="test_error_mutating_constant_in_dynamic_loop", devices=devices, ) - +add_function_test(TestCodeGen, func=test_function_assignment, name="test_function_assignment", devices=devices) add_kernel_test(TestCodeGen, name="test_call_syntax", kernel=test_call_syntax, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1, devices=devices)