From 8f027e5868dcd3172d74563b57932f5564d1498c Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 10 Oct 2024 17:27:40 -0400 Subject: [PATCH 1/3] Fixed a bug when assigning a function to a variable in kernels --- warp/codegen.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/warp/codegen.py b/warp/codegen.py index 7db6ecc5a..a94905134 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -1659,6 +1659,9 @@ def emit_Name(adj, node): # lookup symbol, if it has already been assigned to a variable then return the existing mapping if node.id in adj.symbols: return adj.symbols[node.id] + # Check if the node has a warp_func attribute + if hasattr(node, 'warp_func'): + return node.warp_func obj = adj.resolve_external_reference(node.id) @@ -2323,6 +2326,18 @@ def emit_Assign(adj, node): raise WarpCodegenError( "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead." ) + elif isinstance(lhs, ast.Name): + # symbol name + name = lhs.id + + # evaluate rhs + rhs = adj.eval(node.value) + + # Check if rhs is a function object + if isinstance(rhs, warp.context.Function): + # Assign the function directly to the symbol table + adj.symbols[name] = rhs + return # handle the case where we are assigning multiple output variables if isinstance(lhs, ast.Tuple): From a9117c989a36404ec9cd59ef469d96cbba587ed1 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 10 Oct 2024 18:08:18 -0400 Subject: [PATCH 2/3] Added test for function assignment to a variable inside kernel --- warp/tests/test_codegen.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index 6beb4d542..e052ef658 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -576,6 +576,23 @@ def test_while_condition_eval(): while it.valid: it.valid = False +def test_function_assignment(test, device): + + @wp.func + def multiply_by_two(a: float): + return a * 2.0 + + @wp.kernel + def kernel_function_assignment(input: wp.array(dtype=wp.float32), output: wp.array(dtype=wp.float32)): + tid = wp.tid() + func = multiply_by_two # Assign function to variable + output[tid] = func(input[tid]) # Call function through variable + + input_data = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, device=device) + output_data = wp.empty_like(input_data) + wp.launch(kernel_function_assignment, dim=input_data.size, inputs=[input_data], outputs=[output_data], device=device) + expected_output = np.array([2.0, 4.0, 6.0], dtype=np.float32) + assert_np_equal(output_data.numpy(), expected_output) class TestCodeGen(unittest.TestCase): pass @@ -719,6 +736,12 @@ class TestCodeGen(unittest.TestCase): name="test_error_mutating_constant_in_dynamic_loop", devices=devices, ) +add_function_test( + TestCodeGen, + func=test_function_assignment, + name="test_function_assignment", + devices=devices +) add_kernel_test(TestCodeGen, name="test_call_syntax", kernel=test_call_syntax, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices) From acbb387bb01ece342d3404c976523cb19df0ccc9 Mon Sep 17 00:00:00 2001 From: Mehdi Ataei Date: Thu, 10 Oct 2024 18:13:57 -0400 Subject: [PATCH 3/3] Fixed formatting --- warp/codegen.py | 2 +- warp/tests/test_codegen.py | 15 ++++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/warp/codegen.py b/warp/codegen.py index a94905134..7e0daefb8 100644 --- a/warp/codegen.py +++ b/warp/codegen.py @@ -1660,7 +1660,7 @@ def emit_Name(adj, node): if node.id in adj.symbols: return adj.symbols[node.id] # Check if the node has a warp_func attribute - if hasattr(node, 'warp_func'): + if hasattr(node, "warp_func"): return node.warp_func obj = adj.resolve_external_reference(node.id) diff --git a/warp/tests/test_codegen.py b/warp/tests/test_codegen.py index e052ef658..9f9b16080 100644 --- a/warp/tests/test_codegen.py +++ b/warp/tests/test_codegen.py @@ -576,8 +576,8 @@ def test_while_condition_eval(): while it.valid: it.valid = False -def test_function_assignment(test, device): +def test_function_assignment(test, device): @wp.func def multiply_by_two(a: float): return a * 2.0 @@ -590,10 +590,13 @@ def kernel_function_assignment(input: wp.array(dtype=wp.float32), output: wp.arr input_data = wp.array([1.0, 2.0, 3.0], dtype=wp.float32, device=device) output_data = wp.empty_like(input_data) - wp.launch(kernel_function_assignment, dim=input_data.size, inputs=[input_data], outputs=[output_data], device=device) + wp.launch( + kernel_function_assignment, dim=input_data.size, inputs=[input_data], outputs=[output_data], device=device + ) expected_output = np.array([2.0, 4.0, 6.0], dtype=np.float32) assert_np_equal(output_data.numpy(), expected_output) + class TestCodeGen(unittest.TestCase): pass @@ -736,13 +739,7 @@ class TestCodeGen(unittest.TestCase): name="test_error_mutating_constant_in_dynamic_loop", devices=devices, ) -add_function_test( - TestCodeGen, - func=test_function_assignment, - name="test_function_assignment", - devices=devices -) - +add_function_test(TestCodeGen, func=test_function_assignment, name="test_function_assignment", devices=devices) add_kernel_test(TestCodeGen, name="test_call_syntax", kernel=test_call_syntax, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_shadow_builtin", kernel=test_shadow_builtin, dim=1, devices=devices) add_kernel_test(TestCodeGen, name="test_while_condition_eval", kernel=test_while_condition_eval, dim=1, devices=devices)