Skip to content

Commit

Permalink
Added all the other tests in the parser, added also the other rest en…
Browse files Browse the repository at this point in the history
…dpoint
  • Loading branch information
eliax1996 committed Apr 9, 2024
1 parent 478feb5 commit 57372e7
Show file tree
Hide file tree
Showing 5 changed files with 708 additions and 14 deletions.
2 changes: 2 additions & 0 deletions karapace/protobuf/option_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@


class OptionElement:
name: str

class Kind(Enum):
STRING = 1
BOOLEAN = 2
Expand Down
192 changes: 188 additions & 4 deletions karapace/protobuf/proto_normalizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,211 @@
Copyright (c) 2024 Aiven Ltd
See LICENSE for details
"""
from karapace.protobuf.enum_constant_element import EnumConstantElement
from karapace.protobuf.enum_element import EnumElement
from karapace.protobuf.extend_element import ExtendElement
from karapace.protobuf.field_element import FieldElement
from karapace.protobuf.group_element import GroupElement
from karapace.protobuf.message_element import MessageElement
from karapace.protobuf.one_of_element import OneOfElement
from karapace.protobuf.option_element import OptionElement
from karapace.protobuf.proto_file_element import ProtoFileElement
from karapace.protobuf.rpc_element import RpcElement
from karapace.protobuf.service_element import ServiceElement
from karapace.protobuf.type_element import TypeElement
from karapace.typing import StrEnum
from typing import List


class ProtobufNormalisationOptions(StrEnum):
sort_options = "sort_options"


def sort_by_name(element: OptionElement) -> str:
return element.name


def type_field_element_with_sorted_options(type_field: FieldElement) -> FieldElement:
sorted_options = None if type_field.options is None else list(sorted(type_field.options, key=sort_by_name))
return FieldElement(
location=type_field.location,
label=type_field.label,
element_type=type_field.element_type,
name=type_field.name,
default_value=type_field.default_value,
json_name=type_field.json_name,
tag=type_field.tag,
documentation=type_field.documentation,
options=sorted_options,
)


def enum_constant_element_with_sorted_options(enum_constant: EnumConstantElement) -> EnumConstantElement:
sorted_options = None if enum_constant.options is None else list(sorted(enum_constant.options, key=sort_by_name))
return EnumConstantElement(
location=enum_constant.location,
name=enum_constant.name,
tag=enum_constant.tag,
documentation=enum_constant.documentation,
options=sorted_options,
)


def enum_element_with_sorted_options(enum_element: EnumElement) -> EnumElement:
sorted_options = None if enum_element.options is None else list(sorted(enum_element.options, key=sort_by_name))
constants_with_sorted_options = (
None
if enum_element.constants is None
else [enum_constant_element_with_sorted_options(constant) for constant in enum_element.constants]
)
return EnumElement(
location=enum_element.location,
name=enum_element.name,
documentation=enum_element.documentation,
options=sorted_options,
constants=constants_with_sorted_options,
)


def groups_with_sorted_options(group: GroupElement) -> GroupElement:
sorted_fields = (
None if group.fields is None else [type_field_element_with_sorted_options(field) for field in group.fields]
)
return GroupElement(
label=group.label,
location=group.location,
name=group.name,
tag=group.tag,
documentation=group.documentation,
fields=sorted_fields,
)


def one_ofs_with_sorted_options(one_ofs: OneOfElement) -> OneOfElement:
sorted_options = None if one_ofs.options is None else list(sorted(one_ofs.options, key=sort_by_name))
sorted_fields = [type_field_element_with_sorted_options(field) for field in one_ofs.fields]
sorted_groups = [groups_with_sorted_options(group) for group in one_ofs.groups]

return OneOfElement(
name=one_ofs.name,
documentation=one_ofs.documentation,
fields=sorted_fields,
groups=sorted_groups,
options=sorted_options,
)


def message_element_with_sorted_options(message_element: MessageElement) -> MessageElement:
sorted_options = None if message_element.options is None else list(sorted(message_element.options, key=sort_by_name))
sorted_neasted_types = [type_element_with_sorted_options(nested_type) for nested_type in message_element.nested_types]
sorted_fields = [type_field_element_with_sorted_options(field) for field in message_element.fields]
sorted_one_ofs = [one_ofs_with_sorted_options(one_of) for one_of in message_element.one_ofs]

return MessageElement(
location=message_element.location,
name=message_element.name,
documentation=message_element.documentation,
nested_types=sorted_neasted_types,
options=sorted_options,
reserveds=message_element.reserveds,
fields=sorted_fields,
one_ofs=sorted_one_ofs,
extensions=message_element.extensions,
groups=message_element.groups,
)


def type_element_with_sorted_options(type_element: TypeElement) -> TypeElement:
sorted_neasted_types: List[TypeElement] = []

for nested_type in type_element.nested_types:
if isinstance(nested_type, EnumElement):
sorted_neasted_types.append(enum_element_with_sorted_options(nested_type))
elif isinstance(nested_type, MessageElement):
sorted_neasted_types.append(message_element_with_sorted_options(nested_type))
else:
raise ValueError("Unknown type element") # tried with assert_never but it did not work

# doing it here since the subtypes do not declare the nested_types property
type_element.nested_types = sorted_neasted_types

if isinstance(type_element, EnumElement):
return enum_element_with_sorted_options(type_element)

if isinstance(type_element, MessageElement):
return message_element_with_sorted_options(type_element)

raise ValueError("Unknown type element") # tried with assert_never but it did not work


def extends_element_with_sorted_options(extend_element: ExtendElement) -> ExtendElement:
sorted_fields = (
None
if extend_element.fields is None
else [type_field_element_with_sorted_options(field) for field in extend_element.fields]
)
return ExtendElement(
location=extend_element.location,
name=extend_element.name,
documentation=extend_element.documentation,
fields=sorted_fields,
)


def rpc_element_with_sorted_options(rpc: RpcElement) -> RpcElement:
sorted_options = None if rpc.options is None else list(sorted(rpc.options, key=sort_by_name))
return RpcElement(
location=rpc.location,
name=rpc.name,
documentation=rpc.documentation,
request_type=rpc.request_type,
response_type=rpc.response_type,
request_streaming=rpc.request_streaming,
response_streaming=rpc.response_streaming,
options=sorted_options,
)


def service_element_with_sorted_options(service_element: ServiceElement) -> ServiceElement:
sorted_options = None if service_element.options is None else list(sorted(service_element.options, key=sort_by_name))
sorted_rpc = (
None if service_element.rpcs is None else [rpc_element_with_sorted_options(rpc) for rpc in service_element.rpcs]
)

return ServiceElement(
location=service_element.location,
name=service_element.name,
documentation=service_element.documentation,
rpcs=sorted_rpc,
options=sorted_options,
)


def normalize_options_ordered(proto_file_element: ProtoFileElement) -> ProtoFileElement:
sorted_types = [type_element_with_sorted_options(type_element) for type_element in proto_file_element.types]
sorted_options = (
None if proto_file_element.options is None else list(sorted(proto_file_element.options, key=lambda x: x.name))
None if proto_file_element.options is None else list(sorted(proto_file_element.options, key=sort_by_name))
)
sorted_services = (
None
if proto_file_element.services is None
else [service_element_with_sorted_options(service) for service in proto_file_element.services]
)
sorted_extend_declarations = (
None
if proto_file_element.extend_declarations is None
else [extends_element_with_sorted_options(extend) for extend in proto_file_element.extend_declarations]
)

return ProtoFileElement(
location=proto_file_element.location,
package_name=proto_file_element.package_name,
syntax=proto_file_element.syntax,
imports=proto_file_element.imports,
public_imports=proto_file_element.public_imports,
types=proto_file_element.types,
services=proto_file_element.services,
extend_declarations=proto_file_element.extend_declarations,
types=sorted_types,
services=sorted_services,
extend_declarations=sorted_extend_declarations,
options=sorted_options,
)

Expand Down
8 changes: 7 additions & 1 deletion karapace/schema_registry_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,11 +1097,16 @@ async def subjects_schema_post(
schema_type = self._validate_schema_type(content_type=content_type, data=body)
references = self._validate_references(content_type, schema_type, body)
references, new_schema_dependencies = self.schema_registry.resolve_references(references)
normalize = request.query.get("normalize", "false").lower() == "true"
try:
# When checking if schema is already registered, allow unvalidated schema in as
# there might be stored schemas that are non-compliant from the past.
new_schema = ParsedTypedSchema.parse(
schema_type=schema_type, schema_str=schema_str, references=references, dependencies=new_schema_dependencies
schema_type=schema_type,
schema_str=schema_str,
references=references,
dependencies=new_schema_dependencies,
normalize=normalize,
)
except InvalidSchema:
self.log.warning("Invalid schema: %r", schema_str)
Expand Down Expand Up @@ -1133,6 +1138,7 @@ async def subjects_schema_post(
schema_version.schema.schema_str,
references=other_references,
dependencies=other_dependencies,
normalize=normalize,
)
except InvalidSchema as e:
failed_schema_id = schema_version.schema_id
Expand Down
53 changes: 53 additions & 0 deletions tests/integration/test_schema_protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,59 @@ async def test_protobuf_normalization_of_options(registry_async_client: Client)
option java_outer_classname = "FredProto";
option java_generic_services = true;
message Foo {
string code = 1;
}
"""

body = {"schemaType": "PROTOBUF", "schema": schema_with_option_unordered_2}
res = await registry_async_client.post(f"subjects/{subject}", json=body)
assert res.status_code == 404

res = await registry_async_client.post(f"subjects/{subject}?normalize=true", json=body)

assert res.status_code == 200
assert "id" in res.json()
assert original_schema_id == res.json()["id"]


async def test_protobuf_normalization_of_options_specify_version(registry_async_client: Client) -> None:
subject = create_subject_name_factory("test_protobuf_normalization")()

schema_with_option_unordered_1 = """\
syntax = "proto3";
package tc4;
option java_package = "com.example";
option java_outer_classname = "FredProto";
option java_multiple_files = true;
option java_generic_services = true;
option java_generate_equals_and_hash = true;
option java_string_check_utf8 = true;
message Foo {
string code = 1;
}
"""

body = {"schemaType": "PROTOBUF", "schema": schema_with_option_unordered_1}
res = await registry_async_client.post(f"subjects/{subject}/versions?normalize=true", json=body)

assert res.status_code == 200
assert "id" in res.json()
original_schema_id = res.json()["id"]

schema_with_option_unordered_2 = """\
syntax = "proto3";
package tc4;
option java_package = "com.example";
option java_generate_equals_and_hash = true;
option java_string_check_utf8 = true;
option java_multiple_files = true;
option java_outer_classname = "FredProto";
option java_generic_services = true;
message Foo {
string code = 1;
}
Expand Down
Loading

0 comments on commit 57372e7

Please sign in to comment.