Skip to content

Commit

Permalink
Merge branch 'lwawrzyniak/fix-adj-print' into 'main'
Browse files Browse the repository at this point in the history
Fix adjoint print function for various data types

See merge request omniverse/warp!779
  • Loading branch information
nvlukasz committed Oct 15, 2024
2 parents 5b1ae66 + 545c338 commit 8e2ef04
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 17 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

- Relax the integer types expected when indexing arrays.
- Promote the `wp.Int`, `wp.Float`, and `wp.Scalar` generic annotation types to the public API.
- Make the output of `wp.print()` in backward kernels consistent for all supported data types.

### Fixed

Expand All @@ -24,6 +25,8 @@
- Fix caching of kernels with static expressions.
- Fix `ModelBuilder.add_builder(builder)` to correctly update `articulation_start` and thereby `articulation_count` when `builder` contains more than one articulation.
- Re-introduce `wp.rand*()`, `wp.sample*()`, and `wp.poisson()` in the Python scope to revert a breaking change.
- Fix printing vector and matrix adjoints in backward kernels.
- Fix kernel compile error when printing structs.

## [1.4.0] - 2024-10-01

Expand Down
75 changes: 58 additions & 17 deletions warp/native/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -1575,32 +1575,73 @@ inline CUDA_CALLABLE void print(transform_t<Type> t)
printf("(%g %g %g) (%g %g %g %g)\n", float(t.p[0]), float(t.p[1]), float(t.p[2]), float(t.q.x), float(t.q.y), float(t.q.z), float(t.q.w));
}

// Scalar adj_print() overloads: each prints the value followed by its adjoint on one line,
// using the printf conversion that matches the argument type (half goes through half_to_float()).
// NOTE(review): the same scalar signatures are defined again further down with adjoint-only
// output; this set reads like the pre-change side of a diff -- confirm which one should remain.
inline CUDA_CALLABLE void adj_print(int i, int adj_i) { printf("%d adj: %d\n", i, adj_i); }
inline CUDA_CALLABLE void adj_print(float f, float adj_f) { printf("%g adj: %g\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(short f, short adj_f) { printf("%hd adj: %hd\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(long f, long adj_f) { printf("%ld adj: %ld\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(long long f, long long adj_f) { printf("%lld adj: %lld\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned f, unsigned adj_f) { printf("%u adj: %u\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned short f, unsigned short adj_f) { printf("%hu adj: %hu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned long f, unsigned long adj_f) { printf("%lu adj: %lu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(unsigned long long f, unsigned long long adj_f) { printf("%llu adj: %llu\n", f, adj_f); }
inline CUDA_CALLABLE void adj_print(half h, half adj_h) { printf("%g adj: %g\n", half_to_float(h), half_to_float(adj_h)); }
inline CUDA_CALLABLE void adj_print(double f, double adj_f) { printf("%g adj: %g\n", f, adj_f); }
// Catch-all adj_print() for any type T: neither argument is read, a fixed
// placeholder line is printed instead of an adjoint value.
template<typename T>
inline CUDA_CALLABLE void adj_print(const T& x, const T& adj_x)
{
printf("adj: <type without print implementation>\n");
}

// note: adj_print() only prints the adjoint value, since the value itself gets printed in replay print()
// In every overload below the first parameter is accepted but not read.

// floating-point types, printed with %g (half is converted with half_to_float() first)
inline CUDA_CALLABLE void adj_print(half x, half adj_x) { printf("adj: %g\n", half_to_float(adj_x)); }
inline CUDA_CALLABLE void adj_print(float x, float adj_x) { printf("adj: %g\n", adj_x); }
inline CUDA_CALLABLE void adj_print(double x, double adj_x) { printf("adj: %g\n", adj_x); }

// signed integer types; signed char and short share the %d conversion with int
inline CUDA_CALLABLE void adj_print(signed char x, signed char adj_x) { printf("adj: %d\n", adj_x); }
inline CUDA_CALLABLE void adj_print(short x, short adj_x) { printf("adj: %d\n", adj_x); }
inline CUDA_CALLABLE void adj_print(int x, int adj_x) { printf("adj: %d\n", adj_x); }
inline CUDA_CALLABLE void adj_print(long x, long adj_x) { printf("adj: %ld\n", adj_x); }
inline CUDA_CALLABLE void adj_print(long long x, long long adj_x) { printf("adj: %lld\n", adj_x); }

// unsigned integer types; unsigned char and unsigned short share the %u conversion with unsigned
inline CUDA_CALLABLE void adj_print(unsigned char x, unsigned char adj_x) { printf("adj: %u\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned short x, unsigned short adj_x) { printf("adj: %u\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned x, unsigned adj_x) { printf("adj: %u\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned long x, unsigned long adj_x) { printf("adj: %lu\n", adj_x); }
inline CUDA_CALLABLE void adj_print(unsigned long long x, unsigned long long adj_x) { printf("adj: %llu\n", adj_x); }

// booleans are printed as the words True / False
inline CUDA_CALLABLE void adj_print(bool x, bool adj_x) { printf("adj: %s\n", (adj_x ? "True" : "False")); }

// adj_print() for fixed-length vectors.
// NOTE(review): two definitions follow the single template header below; this reads like the
// before/after sides of a diff -- confirm that only one of them is meant to remain.
template<unsigned Length, typename Type>
// One-line form: prints components 0 and 1 of v and adj_v only, whatever Length is, and
// hands them to %g without any conversion.
inline CUDA_CALLABLE void adj_print(vec_t<Length, Type> v, vec_t<Length, Type>& adj_v) { printf("%g %g adj: %g %g \n", v[0], v[1], adj_v[0], adj_v[1]); }
// Loop form: prints "adj:" followed by every component of adj_v (each converted to float and
// preceded by a space), then a newline. v is not read.
inline CUDA_CALLABLE void adj_print(const vec_t<Length, Type>& v, const vec_t<Length, Type>& adj_v)
{
printf("adj:");
for (unsigned i = 0; i < Length; i++)
printf(" %g", float(adj_v[i]));
printf("\n");
}

// adj_print() for fixed-size matrices.
// NOTE(review): two definitions follow the single template header below; this reads like the
// before/after sides of a diff -- confirm that only one of them is meant to remain.
template<unsigned Rows, unsigned Cols, typename Type>
// Empty form: prints nothing.
inline CUDA_CALLABLE void adj_print(mat_t<Rows, Cols, Type> m, mat_t<Rows, Cols, Type>& adj_m) { }
// Row-by-row form: prints one output line per row of adj_m; m is not read.
inline CUDA_CALLABLE void adj_print(const mat_t<Rows, Cols, Type>& m, const mat_t<Rows, Cols, Type>& adj_m)
{
for (unsigned i = 0; i < Rows; i++)
{
// only the first row carries the "adj:" label; later rows get padding instead
if (i == 0)
printf("adj:");
else
printf(" ");
// each element is converted to float and preceded by a single space
for (unsigned j = 0; j < Cols; j++)
printf(" %g", float(adj_m.data[i][j]));
printf("\n");
}
}

// adj_print() for quaternions.
// NOTE(review): two definitions follow the single template header below; this reads like the
// before/after sides of a diff -- confirm that only one of them is meant to remain.
template<typename Type>
// One-line form: prints the x, y, z, w members of q and then of adj_q, passed to %g unconverted.
inline CUDA_CALLABLE void adj_print(quat_t<Type> q, quat_t<Type>& adj_q) { printf("%g %g %g %g adj: %g %g %g %g\n", q.x, q.y, q.z, q.w, adj_q.x, adj_q.y, adj_q.z, adj_q.w); }
// Adjoint-only form: prints the x, y, z, w members of adj_q, each converted to float; q is not read.
inline CUDA_CALLABLE void adj_print(const quat_t<Type>& q, const quat_t<Type>& adj_q)
{
printf("adj: %g %g %g %g\n", float(adj_q.x), float(adj_q.y), float(adj_q.z), float(adj_q.w));
}

// adj_print() for transforms, with an empty str overload sitting between the two transform forms.
// NOTE(review): only the first transform definition is directly preceded by the template header,
// and str gets another adj_print() definition right after this group; this reads like the
// before/after sides of a diff -- confirm which lines are meant to remain.
template<typename Type>
// Empty form: prints nothing.
inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_t) {}

// Empty str form: prints nothing.
inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
// Adjoint-only form: prints adj_t.p[0..2] in one parenthesized group and adj_t.q (x, y, z, w)
// in a second one, every value converted to float; t is not read.
inline CUDA_CALLABLE void adj_print(const transform_t<Type>& t, const transform_t<Type>& adj_t)
{
printf("adj: (%g %g %g) (%g %g %g %g)\n",
float(adj_t.p[0]), float(adj_t.p[1]), float(adj_t.p[2]),
float(adj_t.q.x), float(adj_t.q.y), float(adj_t.q.z), float(adj_t.q.w));
}

// adj_print() for strings: prints the string argument t itself after the "adj:" label;
// adj_t is not read.
// NOTE(review): presumably deliberate because a string carries no meaningful adjoint -- confirm.
inline CUDA_CALLABLE void adj_print(str t, str& adj_t)
{
printf("adj: %s\n", t);
}

template <typename T>
inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
Expand Down
135 changes: 135 additions & 0 deletions warp/tests/test_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import sys
import unittest
from typing import Any

import warp as wp
from warp.tests.unittest_utils import *
Expand Down Expand Up @@ -126,6 +127,139 @@ def test_print_boolean(test, device):
test.assertRegex(s, rf"True{os.linesep}False{os.linesep}")


# Kernel that only prints its single argument. The parameter is annotated Any, and one
# overload per entry of generic_print_types is registered with wp.overload() further down.
@wp.kernel
def generic_print_kernel(x: Any):
print(x)


# Two-float struct passed to generic_print_kernel by the "structs, not printable yet" case
# of test_print_adjoint.
@wp.struct
class SimpleStruct:
x: float
y: float


# Types generic_print_kernel is overloaded for: every scalar type, then for each scalar type
# its 2/3/4-component vectors followed by its 2x2/3x3/4x4 matrices, and finally bool,
# SimpleStruct and a float array.
generic_print_types = list(wp.types.scalar_types)
for scalar_type in wp.types.scalar_types:
    generic_print_types.extend(wp.types.vector(size, scalar_type) for size in (2, 3, 4))
    generic_print_types.extend(wp.types.matrix((size, size), scalar_type) for size in (2, 3, 4))
generic_print_types.extend([wp.bool, SimpleStruct, wp.array(dtype=float)])

# Register one explicit overload of the generic kernel per collected type.
for T in generic_print_types:
    wp.overload(generic_print_kernel, [T])


def _capture_adjoint_print(device, value, adj_value):
    """Launch ``generic_print_kernel`` in adjoint mode on one value and return the captured stdout.

    Args:
        device: Device to launch on and to synchronize.
        value: The single kernel input.
        adj_value: The adjoint passed for that input.

    Returns:
        Whatever ``StdOutCapture.end()`` returns for the launch.
    """
    capture = StdOutCapture()
    capture.begin()
    wp.launch(generic_print_kernel, dim=1, inputs=[value], adj_inputs=[adj_value], adjoint=True, device=device)
    wp.synchronize_device(device)
    return capture.end()


def _assert_printed(test, output, pattern):
    """Assert that ``output`` matches the regex ``pattern``, except on Windows."""
    # We skip the win32 comparison for now since the capture sometimes is an empty string
    if sys.platform != "win32":
        test.assertRegex(output, pattern)


def test_print_adjoint(test, device):
    """Check what ``print()`` emits in a backward launch for each supported kind of argument.

    For scalars, vectors, matrices and booleans the forward value is expected first, followed
    by the adjoint introduced by ``adj:``. Structs and arrays are expected to produce the
    "type without print implementation" placeholder for both.

    Args:
        test: The test case object providing ``assertRegex``.
        device: Device the kernels are launched on.
    """
    no_print_impl = (
        rf"<type without print implementation>{os.linesep}adj: <type without print implementation>{os.linesep}"
    )

    for scalar_type in wp.types.scalar_types:
        # scalar
        s = _capture_adjoint_print(device, scalar_type(17), scalar_type(42))
        _assert_printed(test, s, rf"17{os.linesep}adj: 42{os.linesep}")

        for dim in (2, 3, 4):
            # vector: values 0..n-1, adjoint is the same data reversed
            vec_type = wp.types.vector(dim, scalar_type)
            vec_data = np.arange(vec_type._length_, dtype=wp.dtype_to_numpy(scalar_type))
            v = vec_type(vec_data)
            adj_v = vec_type(vec_data[::-1])

            s = _capture_adjoint_print(device, v, adj_v)
            expected_forward = " ".join(str(int(x)) for x in v) + " "
            expected_adjoint = " ".join(str(int(x)) for x in adj_v)
            _assert_printed(test, s, rf"{expected_forward}{os.linesep}adj: {expected_adjoint}{os.linesep}")

            # matrix: values 0..n*n-1, adjoint is the same data reversed
            mat_type = wp.types.matrix((dim, dim), scalar_type)
            mat_data = np.arange(mat_type._length_, dtype=wp.dtype_to_numpy(scalar_type))
            m = mat_type(mat_data)
            adj_m = mat_type(mat_data[::-1])

            s = _capture_adjoint_print(device, m, adj_m)
            expected_forward = ""
            expected_adjoint = ""
            for row in range(dim):
                # Only the first adjoint row carries the "adj:" label. Later rows are indented
                # by the native printer; match that indent as a run of spaces instead of
                # pinning an exact width, so the check does not depend on the padding length.
                adj_prefix = "adj: " if row == 0 else " +"
                expected_forward += " ".join(str(int(x)) for x in m[row]) + f" {os.linesep}"
                expected_adjoint += adj_prefix + " ".join(str(int(x)) for x in adj_m[row]) + f"{os.linesep}"
            _assert_printed(test, s, rf"{expected_forward}{expected_adjoint}")

    # Booleans
    s = _capture_adjoint_print(device, True, False)
    _assert_printed(test, s, rf"True{os.linesep}adj: False{os.linesep}")

    # structs, not printable yet
    s = _capture_adjoint_print(device, SimpleStruct(), SimpleStruct())
    _assert_printed(test, s, no_print_impl)

    # arrays, not printable
    a = wp.ones(10, dtype=float, device=device)
    adj_a = wp.zeros(10, dtype=float, device=device)
    s = _capture_adjoint_print(device, a, adj_a)
    _assert_printed(test, s, no_print_impl)


# Empty TestCase subclass; the test functions of this module are attached to it
# by the add_function_test() calls that follow.
class TestPrint(unittest.TestCase):
pass

Expand All @@ -134,6 +268,7 @@ class TestPrint(unittest.TestCase):
# Attach each module-level test function to TestPrint under the given name.
# `devices` is defined outside the lines visible here.
add_function_test(TestPrint, "test_print", test_print, devices=devices, check_output=False)
add_function_test(TestPrint, "test_print_numeric", test_print_numeric, devices=devices, check_output=False)
add_function_test(TestPrint, "test_print_boolean", test_print_boolean, devices=devices, check_output=False)
add_function_test(TestPrint, "test_print_adjoint", test_print_adjoint, devices=devices, check_output=False)


if __name__ == "__main__":
Expand Down

0 comments on commit 8e2ef04

Please sign in to comment.