Skip to content

Commit

Permalink
Merge pull request #7 from nvlukasz/lwawrzyniak/update-type-registration
Browse files Browse the repository at this point in the history
Update custom type handling
  • Loading branch information
nvlukasz authored Jun 27, 2024
2 parents 088ef71 + 4466ae2 commit d4dc9fb
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 13 deletions.
87 changes: 76 additions & 11 deletions warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,12 @@ def type_to_ctype(t, value_type=False):
return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
elif t.__name__ in ("bool", "int", "float"):
return t.__name__
# !!!
elif warp.types.type_is_external(t):
type_info = warp.types._custom_types[t]
# print(f"~!~!~ Type info: {type_info}")
type_name = type_info.native_name
return f"wp::{type_name}"
else:
return f"wp::{t.__name__}"

Expand Down Expand Up @@ -2532,16 +2538,67 @@ def scalar_value(x):
# !!! custom type constant
# FIXME: device mismatch?
elif warp.types.type_is_external(type(value)):
# NOTE: we require that the custom type has a constructor that takes all the fields in declaration order
type_ctype = type(value)._type_
type_name = type(value).__name__
# # NOTE: we require that the custom type has a constructor that takes all the fields in declaration order
# type_ctype = type(value)._type_
# type_name = type(value).__name__
# ctor_name = f"wp::{type_name}"
# field_values = []
# # TODO: this should work recursively if members are also custom types
# for field_name, field_type in type_ctype._fields_:
# field_values.append(getattr(value, field_name))
# ctor_args = ", ".join([str(v) for v in field_values])
# return f"{ctor_name}{{{ctor_args}}}"

# !!! WIP !!!

_numeric_types = {
ctypes.c_int8,
ctypes.c_int16,
ctypes.c_int32,
ctypes.c_int64,

ctypes.c_uint8,
ctypes.c_uint16,
ctypes.c_uint32,
ctypes.c_uint64,

ctypes.c_float,
ctypes.c_double,
}

type_ctype = type(value)
type_info = warp.types._custom_types[type_ctype]
type_name = type_info.native_name
ctor_name = f"wp::{type_name}"
field_values = []
# TODO: this should work recursively if members are also custom types
for field_name, field_type in type_ctype._fields_:
field_values.append(getattr(value, field_name))
ctor_args = ", ".join([str(v) for v in field_values])
return f"{ctor_name}{{{ctor_args}}}"

if type_info.has_binary_ctor:

# !!! Use binary string encoding !!!
value_bytes = bytes(value)
value_str = "".join([f"\\x{b:02x}" for b in value_bytes])
value_len = len(value_bytes)

return f"{ctor_name}{{\"{value_str}\", {value_len}}}"

else:
field_values = []
# TODO: this should work recursively if members are also custom types
for field_name, field_type in type_ctype._fields_:
field_value = getattr(value, field_name)
if warp.types.type_is_external(field_type):
# recurse
field_values.append(constant_str(field_value))
elif field_type in _numeric_types:
field_values.append(field_value)
elif field_type is ctypes.c_bool:
field_values.append("true" if field_value else "false")
else:
# TODO: handle pointers

field_values.append("{/* !!! */}")

ctor_args = ", ".join([str(v) for v in field_values])
return f"{ctor_name}{{{ctor_args}}}"

else:
# otherwise just convert constant to string
Expand Down Expand Up @@ -2662,7 +2719,11 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
if var.constant is None:
lines += [f"{var.ctype()} {var.emit()};\n"]
else:
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
# HACK: don't declare custom type values as const so that they are mutable
if warp.types.type_is_external(var.type):
lines += [f"{var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
else:
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

# forward pass
lines += ["//---------\n"]
Expand Down Expand Up @@ -2697,7 +2758,11 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
if var.constant is None:
lines += [f"{var.ctype()} {var.emit()};\n"]
else:
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
# HACK: don't declare custom type values as const so that they are mutable
if warp.types.type_is_external(var.type):
lines += [f"{var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
else:
lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]

# dual vars
lines += ["//---------\n"]
Expand Down
8 changes: 8 additions & 0 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4296,6 +4296,14 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):

elif isinstance(value, arg_type):
try:
# check if custom type
# print(f"~!~!~! Packing arg type {arg_type}")
type_info = warp.types._custom_types.get(arg_type)
# print(f"~!~!~ Type info: {type_info}")
if type_info is not None:
# print(f"~!~!~! Pack {value}")
return value

# try to pack as a scalar type
if arg_type is warp.types.float16:
return arg_type._type_(warp.types.float_to_half_bits(value.value))
Expand Down
3 changes: 3 additions & 0 deletions warp/native/warp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2591,6 +2591,9 @@ size_t cuda_compile_program(const char* cuda_src, int arch, const char* include_
if (fast_math)
opts.push_back("--use_fast_math");

// suppress unused variable warnings
opts.push_back("--diag-suppress=177");

std::vector<const char*> headers;
std::vector<const char*> header_names;
for (auto it = wp::jitsafe_headers_map.begin(); it != wp::jitsafe_headers_map.end(); ++it)
Expand Down
32 changes: 30 additions & 2 deletions warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4735,7 +4735,7 @@ def type_is_generic_scalar(t):


def type_is_external(t):
return hasattr(t, "_type_") and inspect.isclass(t._type_) and issubclass(t._type_, ctypes.Structure)
return t in _custom_types


def type_matches_template(arg_type, template_type):
Expand Down Expand Up @@ -4895,7 +4895,7 @@ def get_type_code(arg_type):
raise TypeError("Invalid vector/matrix dimensionality")
# !!!
elif type_is_external(arg_type):
return arg_type.__name__
return _custom_types[arg_type].native_name
else:
# simple type
type_code = simple_type_codes.get(arg_type)
Expand Down Expand Up @@ -4950,3 +4950,31 @@ def get_signature(arg_types, func_name=None, arg_names=None):

def is_generic_signature(sig):
return "?" in sig


class TypeInfo:
def __init__(self, T, native_name, has_binary_ctor):
self.T = T
self.native_name = native_name
self.has_binary_ctor = has_binary_ctor

# !!!
_custom_types = {}


def add_type(T, native_name=None, has_binary_ctor=False):

# TODO:
# - allow for specifying a custom ctor function?
# - allow for specifying whether this can be instantiated as a constexpr

if not inspect.isclass(T) or not issubclass(T, ctypes.Structure):
raise TypeError(f"Type must be a subclass of ctypes.Structure, got {T}")

if native_name is None:
native_name = T.__name__

print(f"~!~!~! ADDING TYPE {native_name}: {T}")
type_info = TypeInfo(T, native_name, has_binary_ctor=has_binary_ctor)

_custom_types[T] = type_info

0 comments on commit d4dc9fb

Please sign in to comment.