Skip to content

Commit

Permalink
Implement extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
goffrie committed Feb 12, 2024
1 parent 3e4510b commit 4b7a934
Show file tree
Hide file tree
Showing 10 changed files with 793 additions and 47 deletions.
161 changes: 133 additions & 28 deletions pb-jelly-gen/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,7 @@ def escape_name(s: str) -> str:
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]
ProtoTypes = Union[FileDescriptorProto, EnumDescriptorProto, DescriptorProto]
WalkRet = Tuple[
List[Tuple[List[Text], EnumDescriptorProto, SourceCodeLocation]],
List[Tuple[List[Text], DescriptorProto, SourceCodeLocation]],
]
ModTree = DefaultDict[Text, DefaultDict[Text, Any]]
ModTree = DefaultDict[Text, "ModTree"]


T = TypeVar("T")
Expand Down Expand Up @@ -223,7 +219,7 @@ def __init__(
self,
ctx: "Context",
proto_file: FileDescriptorProto,
msg_type: DescriptorProto,
msg_type: Optional[DescriptorProto],
field: FieldDescriptorProto,
) -> None:
self.ctx = ctx
Expand All @@ -234,6 +230,7 @@ def __init__(
self.oneof = (
field.HasField("oneof_index")
and not field.proto3_optional
and msg_type is not None
and msg_type.oneof_decl[field.oneof_index]
)

Expand Down Expand Up @@ -784,7 +781,7 @@ def write_comments(self, sci_loc: Optional[SourceCodeInfo.Location]) -> None:
self.write_line_broken_text_with_prefix(sci_loc.trailing_comments, "///")

def rust_type(
self, msg_type: DescriptorProto, field: FieldDescriptorProto
self, msg_type: Optional[DescriptorProto], field: FieldDescriptorProto
) -> RustType:
return RustType(self.ctx, self.proto_file, msg_type, field)

Expand Down Expand Up @@ -971,6 +968,11 @@ def gen_msg(
name = "_".join(path + [msg_type.name])
escaped_name = escape_name(name)

preserve_unrecognized = msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]
has_extensions = len(msg_type.extension_range) > 0

oneof_fields: DefaultDict[Text, List[FieldDescriptorProto]] = defaultdict(list)
proto3_optional_synthetic_oneofs: Set[int] = {
field.oneof_index for field in msg_type.field if field.proto3_optional
Expand Down Expand Up @@ -1024,9 +1026,12 @@ def gen_msg(
% (escape_name(oneof.name), oneof_msg_name(name, oneof))
)

if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]:
if preserve_unrecognized:
self.write("pub _unrecognized: Vec<u8>,")

if has_extensions:
self.write("pub _extensions: ::pb_jelly::Unrecognized,")

# Generate any oneof enum structs
for oneof in oneof_decls:
self.write("#[derive(%s)]" % ", ".join(sorted(derives)))
Expand Down Expand Up @@ -1124,10 +1129,10 @@ def gen_msg(
self.write(
"%s: %s," % (escape_name(oneof.name), typ.default(name))
)
if msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]:
if preserve_unrecognized:
self.write("_unrecognized: Vec::new(),")
if has_extensions:
self.write("_extensions: ::pb_jelly::Unrecognized::default(),")

with block(self, "lazy_static!"):
self.write(
Expand Down Expand Up @@ -1205,10 +1210,7 @@ def gen_msg(
self.write('name: "%s",' % oneof.name)

with block(self, "fn compute_size(&self) -> usize"):
if (
len(msg_type.field) > 0
or msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]
):
if len(msg_type.field) > 0 or preserve_unrecognized or has_extensions:
self.write("let mut size = 0;")
for field in msg_type.field:
typ = self.rust_type(msg_type, field)
Expand Down Expand Up @@ -1263,10 +1265,10 @@ def gen_msg(
% field.name
)
self.write("size += %s_size;" % field.name)
if msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]:
if preserve_unrecognized:
self.write("size += self._unrecognized.len();")
if has_extensions:
self.write("size += self._extensions.compute_size();")
self.write("size")
else:
self.write("0")
Expand Down Expand Up @@ -1341,17 +1343,16 @@ def gen_msg(
)
self.write("::pb_jelly::varint::write(l as u64, w)?;")
self.write("::pb_jelly::Message::serialize(val, w)?;")
if msg_type.options.Extensions[extensions_pb2.preserve_unrecognized]:
if preserve_unrecognized:
self.write("w.write_all(&self._unrecognized)?;")
if has_extensions:
self.write("self._extensions.serialize(w)?;")
self.write("Ok(())")

with block(
self,
"fn deserialize<B: ::pb_jelly::PbBufferReader>(&mut self, mut buf: &mut B) -> ::std::io::Result<()>",
):
preserve_unrecognized = msg_type.options.Extensions[
extensions_pb2.preserve_unrecognized
]
if preserve_unrecognized:
self.write(
"let mut unrecognized = ::pb_jelly::Unrecognized::default();"
Expand Down Expand Up @@ -1461,6 +1462,15 @@ def gen_msg(
"self.%s = %s;"
% (escape_name(field.name), field_val)
)
if has_extensions:
pattern = " | ".join(
"{}..={}".format(r.start, r.end - 1)
for r in msg_type.extension_range
)
with block(self, pattern + " =>"):
self.write(
"self._extensions.gather(field_number, typ, &mut buf)?;"
)
with block(self, "_ =>"):
if preserve_unrecognized:
self.write(
Expand Down Expand Up @@ -1488,7 +1498,12 @@ def gen_msg(
)

if preserve_unrecognized:
self.write("unrecognized.serialize(&mut self._unrecognized)?;")
self.write(
"self._unrecognized.reserve(unrecognized.compute_size());"
)
self.write(
"unrecognized.serialize(&mut std::io::Cursor::new(&mut self._unrecognized))?;"
)
self.write("Ok(())")

with block(self, "impl ::pb_jelly::Reflection for " + name):
Expand Down Expand Up @@ -1594,9 +1609,58 @@ def gen_msg(
with block(self, "_ =>"):
self.write('panic!("unknown field name given")')

if has_extensions:
with block(self, "impl ::pb_jelly::extensions::Extensible for " + name):
with block(
self,
"fn _extensions(&self) -> &::pb_jelly::Unrecognized",
):
self.write("&self._extensions")

def gen_extension(
self,
path: List[Text],
extension_field: FieldDescriptorProto,
scl: SourceCodeLocation,
) -> None:
crate, mod_parts = self.ctx.crate_from_proto_filename(self.proto_file.name)

self.write_comments(self.source_code_info_by_scl.get(tuple(scl)))
name = ("_".join(path + [extension_field.name])).upper()
rust_type = self.rust_type(None, extension_field)
extendee = self.ctx.find(extension_field.extendee)
kind = (
"RepeatedExtension"
if extension_field.label == FieldDescriptorProto.LABEL_REPEATED
else "SingularExtension"
)

self.write(
"""pub const {name}: ::pb_jelly::extensions::{kind}<{extendee}, {field_type}> =
::pb_jelly::extensions::{kind}::new(
{field_number},
::pb_jelly::wire_format::Type::{wire_format},
"{raw_name}",
);""".format(
name=name,
extendee=extendee.rust_name(crate, mod_parts),
field_type=rust_type.rust_type(),
kind=kind,
field_number=extension_field.number,
wire_format=rust_type.wire_format(),
raw_name=extension_field.name,
)
)

def walk(proto: FileDescriptorProto) -> WalkRet:
enums, messages = [], []

def walk(
proto: FileDescriptorProto,
) -> Tuple[
List[Tuple[List[Text], EnumDescriptorProto, SourceCodeLocation]],
List[Tuple[List[Text], DescriptorProto, SourceCodeLocation]],
List[Tuple[List[Text], FieldDescriptorProto, SourceCodeLocation]],
]:
enums, messages, extensions = [], [], []

def _walk(
proto: ProtoTypes, parents: List[Text], scl_prefix: SourceCodeLocation
Expand All @@ -1613,6 +1677,15 @@ def _walk(
for i, nested_message in enumerate(proto.nested_type):
ntfn = DescriptorProto.NESTED_TYPE_FIELD_NUMBER
_walk(nested_message, parents + [proto.name], scl_prefix + [ntfn, i])

for i, nested_extension in enumerate(proto.extension):
extensions.append(
(
parents + [proto.name],
nested_extension,
scl_prefix + [DescriptorProto.EXTENSION_FIELD_NUMBER, i],
)
)
elif isinstance(proto, FileDescriptorProto):
for i, enum_type in enumerate(proto.enum_type):
etfn = FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER
Expand All @@ -1622,8 +1695,17 @@ def _walk(
mtfn = FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER
_walk(message_type, parents, scl_prefix + [mtfn, i])

for i, nested_extension in enumerate(proto.extension):
extensions.append(
(
parents,
nested_extension,
scl_prefix + [FileDescriptorProto.EXTENSION_FIELD_NUMBER, i],
)
)

_walk(proto, [], [])
return enums, messages
return enums, messages, extensions


M = TypeVar("M", DescriptorProto, EnumDescriptorProto)
Expand Down Expand Up @@ -1755,6 +1837,11 @@ def calc_impls(
if msg_type.typ.options.Extensions[extensions_pb2.preserve_unrecognized]:
impls_copy = False # Preserve unparsed has a Vec which is not Copy

if len(msg_type.typ.extension_range) > 0:
# `Unrecognized` is neither Copy nor Eq
impls_eq = False
impls_copy = False

for field in msg_type.typ.field:
typ = field.type
rust_type = RustType(self, msg_type.proto_file, msg_type.typ, field)
Expand Down Expand Up @@ -1810,6 +1897,7 @@ def calc_impls(
if msg_type.typ.options.Extensions[
extensions_pb2.preserve_unrecognized
]:
# TODO: this check isn't really necessary, but it is useful
assert field_type.typ.options.Extensions[
extensions_pb2.preserve_unrecognized
], (
Expand Down Expand Up @@ -1844,7 +1932,7 @@ def calc_impls(
)

def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None:
enums, messages = walk(proto_file)
enums, messages, extensions = walk(proto_file)

for name in to_generate:
crate, _ = self.crate_from_proto_filename(name)
Expand All @@ -1865,6 +1953,8 @@ def feed(self, proto_file: FileDescriptorProto, to_generate: List[Text]) -> None
# so it suffices to examine one file at a time for the purposes of `box_recursive_fields`
box_recursive_fields(message_types)

crate, _ = self.crate_from_proto_filename(proto_file.name)

for path, typ, _ in messages:
msg_pt = ProtoType(self, proto_file, path, typ)

Expand All @@ -1879,6 +1969,17 @@ def edges(type_name: Text) -> List[Text]:

self.scc.process(msg_pt.proto_name(), edges, self.calc_impls)

if crate in self.deps_map:
for path, field, _ in extensions:
for type_name in [field.type_name, field.extendee]:
if type_name:
field_type = self.find(type_name)
dep_crate, _ = self.crate_from_proto_filename(
field_type.proto_file.name
)
if dep_crate != crate:
self.deps_map[crate].add(dep_crate)

def find_enum(self, typename: Text) -> ProtoType[EnumDescriptorProto]:
pt = self.find(typename)
assert isinstance(pt.typ, EnumDescriptorProto)
Expand Down Expand Up @@ -2124,7 +2225,7 @@ def add_mod(writer: CodeWriter) -> None:
if writer.derive_serde:
derive_serde = True

enums, messages = walk(proto_file)
enums, messages, extensions = walk(proto_file)

for path, enum_typ, scl in enums:
writer.gen_enum(path, enum_typ, scl)
Expand All @@ -2134,6 +2235,10 @@ def add_mod(writer: CodeWriter) -> None:
writer.gen_msg(path, msg_typ, scl)
writer.write("")

for path, extension_field, scl in extensions:
writer.gen_extension(path, extension_field, scl)
writer.write("")

add_mod(writer=writer)

# Note that output filenames must use "/" even on windows. It is part of the
Expand Down
Loading

0 comments on commit 4b7a934

Please sign in to comment.