From 76da93200312acfb9bfc625161289d9db92f636f Mon Sep 17 00:00:00 2001 From: Elia Migliore Date: Thu, 20 Jul 2023 20:26:36 +0200 Subject: [PATCH] Resolve entities recursively when receiving the schema The previous implementation did not correctly resolve entities within a schema that contained references. This led the REST Proxy to respond with a 422 error even when the message was correct. To address this issue, the schema resolver has been modified to ensure that schemas with references are recursively resolved and that the references are passed to the parser correctly. --- .gitignore | 2 + karapace/kafka_rest_apis/schema_cache.py | 2 +- karapace/protobuf/schema.py | 2 +- karapace/schema_references.py | 13 +- karapace/serialization.py | 107 ++++++++-- tests/integration/test_client.py | 4 +- tests/integration/test_client_protobuf.py | 4 +- .../test_rest_consumer_protobuf.py | 188 ++++++++++++++++++ tests/unit/test_protobuf_serialization.py | 32 +-- tests/unit/test_serialization.py | 22 +- 10 files changed, 326 insertions(+), 50 deletions(-) diff --git a/.gitignore b/.gitignore index 612ad46b2..07fcf594f 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,5 @@ venv .python-version .hypothesis/ .DS_Store +# ignoring protobuf generated files. 
+runtime/ diff --git a/karapace/kafka_rest_apis/schema_cache.py b/karapace/kafka_rest_apis/schema_cache.py index 00953a69d..bde742e37 100644 --- a/karapace/kafka_rest_apis/schema_cache.py +++ b/karapace/kafka_rest_apis/schema_cache.py @@ -61,7 +61,7 @@ def get_schema_str(self, topic: Subject, schema_id: SchemaId) -> Optional[str]: class SchemaCache(SchemaCacheProtocol): def __init__(self) -> None: self._schema_hash_str_to_id: Dict[str, SchemaId] = {} - self._id_to_schema_str: MutableMapping[SchemaId, TypedSchema] = TTLCache(maxsize=10000, ttl=600) + self._id_to_schema_str: MutableMapping[SchemaId, TypedSchema] = TTLCache(maxsize=100, ttl=600) def get_schema_id(self, schema: TypedSchema) -> Optional[SchemaId]: fingerprint = hashlib.sha1(str(schema).encode("utf8")).hexdigest() diff --git a/karapace/protobuf/schema.py b/karapace/protobuf/schema.py index a71b235d0..8df162e75 100644 --- a/karapace/protobuf/schema.py +++ b/karapace/protobuf/schema.py @@ -131,7 +131,7 @@ def verify_schema_dependencies(self) -> DependencyVerifierResult: def collect_dependencies(self, verifier: ProtobufDependencyVerifier) -> None: if self.dependencies: for key in self.dependencies: - self.dependencies[key].schema.schema.collect_dependencies(verifier) + self.dependencies[key].get_schema().schema.collect_dependencies(verifier) for i in self.proto_file_element.imports: verifier.add_import(i) diff --git a/karapace/schema_references.py b/karapace/schema_references.py index c7bbf70ff..497bd61b1 100644 --- a/karapace/schema_references.py +++ b/karapace/schema_references.py @@ -8,12 +8,11 @@ from __future__ import annotations from karapace.dataclasses import default_dataclass -from karapace.typing import JsonData, ResolvedVersion, SchemaId, Subject -from typing import List, Mapping, NewType, TypeVar +from karapace.typing import JsonData, JsonObject, ResolvedVersion, SchemaId, Subject +from typing import cast, List, Mapping, NewType, TypeVar Referents = NewType("Referents", List[SchemaId]) - T = 
TypeVar("T") @@ -64,6 +63,14 @@ def to_dict(self) -> JsonData: "version": self.version, } + @staticmethod + def from_dict(data: JsonObject) -> Reference: + return Reference( + name=str(data["name"]), + subject=Subject(str(data["subject"])), + version=ResolvedVersion(cast(int, data["version"])), + ) + def reference_from_mapping( data: Mapping[str, object], diff --git a/karapace/serialization.py b/karapace/serialization.py index 6f10b437b..8765a4858 100644 --- a/karapace/serialization.py +++ b/karapace/serialization.py @@ -5,17 +5,19 @@ from aiohttp import BasicAuth from avro.io import BinaryDecoder, BinaryEncoder, DatumReader, DatumWriter from cachetools import TTLCache +from functools import lru_cache from google.protobuf.message import DecodeError from jsonschema import ValidationError from karapace.client import Client +from karapace.dependency import Dependency from karapace.errors import InvalidReferences from karapace.protobuf.exception import ProtobufTypeException from karapace.protobuf.io import ProtobufDatumReader, ProtobufDatumWriter from karapace.schema_models import InvalidSchema, ParsedTypedSchema, SchemaType, TypedSchema, ValidatedTypedSchema -from karapace.schema_references import Reference, reference_from_mapping -from karapace.typing import SchemaId, Subject +from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping +from karapace.typing import ResolvedVersion, SchemaId, Subject from karapace.utils import json_decode, json_encode -from typing import Any, Callable, Dict, List, MutableMapping, Optional, Tuple +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Set, Tuple from urllib.parse import quote import asyncio @@ -101,19 +103,79 @@ async def post_new_schema( raise SchemaRetrievalError(result.json()) return SchemaId(result.json()["id"]) - async def get_latest_schema(self, subject: str) -> Tuple[SchemaId, ParsedTypedSchema]: - result = await 
self.client.get(f"subjects/{quote(subject)}/versions/latest") + async def _get_schema_r( + self, + subject: Subject, + explored_schemas: Set[Tuple[Subject, Optional[ResolvedVersion]]], + version: Optional[ResolvedVersion] = None, + ) -> Tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]: + if (subject, version) in explored_schemas: + raise InvalidSchema( + f"The schema has at least a cycle in dependencies, " + f"one path of the cycle is given by the following nodes: {explored_schemas}" + ) + + explored_schemas = explored_schemas | {(subject, version)} + + version_str = str(version) if version is not None else "latest" + result = await self.client.get(f"subjects/{quote(subject)}/versions/{version_str}") + if not result.ok: raise SchemaRetrievalError(result.json()) + json_result = result.json() - if "id" not in json_result or "schema" not in json_result: + if "id" not in json_result or "schema" not in json_result or "version" not in json_result: raise SchemaRetrievalError(f"Invalid result format: {json_result}") + + if "references" in json_result: + references = [Reference.from_dict(data) for data in json_result["references"]] + dependencies = {} + for reference in references: + _, schema, version = await self._get_schema_r(reference.subject, explored_schemas, reference.version) + dependencies[reference.name] = Dependency( + name=reference.name, subject=reference.subject, version=version, target_schema=schema + ) + else: + references = None + dependencies = None + try: schema_type = SchemaType(json_result.get("schemaType", "AVRO")) - return SchemaId(json_result["id"]), ParsedTypedSchema.parse(schema_type, json_result["schema"]) + return ( + SchemaId(json_result["id"]), + ValidatedTypedSchema.parse( + schema_type, + json_result["schema"], + references=references, + dependencies=dependencies, + ), + ResolvedVersion(json_result["version"]), + ) except InvalidSchema as e: raise SchemaRetrievalError(f"Failed to parse schema string from response: {json_result}") from e 
+ # NOTE: not wrapped in functools.lru_cache: it would cache the one-shot coroutine object, not its result. + async def get_schema( + self, + subject: Subject, + version: Optional[ResolvedVersion] = None, + ) -> Tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]: + """ + Retrieves the schema and its dependencies for the specified subject. + + Args: + subject (Subject): The subject for which to retrieve the schema. + version (Optional[ResolvedVersion]): The specific version of the schema to retrieve. + If None, the latest available schema will be returned. + + Returns: + Tuple[SchemaId, ValidatedTypedSchema, ResolvedVersion]: A tuple containing: + - SchemaId: The ID of the retrieved schema. + - ValidatedTypedSchema: The retrieved schema, validated and typed. + - ResolvedVersion: The version of the schema that was retrieved. + """ + return await self._get_schema_r(subject, set(), version) + async def get_schema_for_id(self, schema_id: SchemaId) -> Tuple[TypedSchema, List[Subject]]: result = await self.client.get(f"schemas/ids/{schema_id}", params={"includeSubjects": "True"}) if not result.ok: @@ -138,15 +200,24 @@ async def get_schema_for_id(self, schema_id: SchemaId) -> Tuple[TypedSchema, Lis raise InvalidReferences from exc parsed_references.append(reference) if parsed_references: - return ( - ParsedTypedSchema.parse( - schema_type, - json_result["schema"], - references=parsed_references, - ), - subjects, - ) - return ParsedTypedSchema.parse(schema_type, json_result["schema"]), subjects + dependencies = {} + + for reference in parsed_references: + if isinstance(reference, LatestVersionReference): + _, schema, version = await self.get_schema(reference.subject) + else: + _, schema, version = await self.get_schema(reference.subject, reference.version) + + dependencies[reference.name] = Dependency(reference.name, reference.subject, version, schema) + else: + dependencies = None + + return ( + ParsedTypedSchema.parse( + schema_type, json_result["schema"], references=parsed_references, dependencies=dependencies + ), + subjects, + ) except 
InvalidSchema as e: raise SchemaRetrievalError(f"Failed to parse schema string from response: {json_result}") from e @@ -204,9 +275,9 @@ def get_subject_name( return Subject(f"{self.subject_name_strategy(topic_name, namespace)}-{subject_type}") - async def get_schema_for_subject(self, subject: str) -> TypedSchema: + async def get_schema_for_subject(self, subject: Subject) -> TypedSchema: assert self.registry_client, "must not call this method after the object is closed." - schema_id, schema = await self.registry_client.get_latest_schema(subject) + schema_id, schema, _ = await self.registry_client.get_schema(subject) async with self.state_lock: schema_ser = str(schema) self.schemas_to_ids[schema_ser] = schema_id diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 9f6a26fb7..ba0fb5ff3 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -17,7 +17,7 @@ async def test_remote_client(registry_async_client: Client) -> None: assert sc_id >= 0 stored_schema, _ = await reg_cli.get_schema_for_id(sc_id) assert stored_schema == schema_avro, f"stored schema {stored_schema.to_dict()} is not {schema_avro.to_dict()}" - stored_id, stored_schema = await reg_cli.get_latest_schema(subject) + stored_id, stored_schema, _ = await reg_cli.get_schema(subject) assert stored_id == sc_id assert stored_schema == schema_avro @@ -31,6 +31,6 @@ async def test_remote_client_tls(registry_async_client_tls: Client) -> None: assert sc_id >= 0 stored_schema, _ = await reg_cli.get_schema_for_id(sc_id) assert stored_schema == schema_avro, f"stored schema {stored_schema.to_dict()} is not {schema_avro.to_dict()}" - stored_id, stored_schema = await reg_cli.get_latest_schema(subject) + stored_id, stored_schema, _ = await reg_cli.get_schema(subject) assert stored_id == sc_id assert stored_schema == schema_avro diff --git a/tests/integration/test_client_protobuf.py b/tests/integration/test_client_protobuf.py index 7de39a7cd..231730db9 100644 --- 
a/tests/integration/test_client_protobuf.py +++ b/tests/integration/test_client_protobuf.py @@ -18,7 +18,7 @@ async def test_remote_client_protobuf(registry_async_client): assert sc_id >= 0 stored_schema, _ = await reg_cli.get_schema_for_id(sc_id) assert stored_schema == schema_protobuf, f"stored schema {stored_schema} is not {schema_protobuf}" - stored_id, stored_schema = await reg_cli.get_latest_schema(subject) + stored_id, stored_schema, _ = await reg_cli.get_schema(subject) assert stored_id == sc_id assert stored_schema == schema_protobuf @@ -33,6 +33,6 @@ async def test_remote_client_protobuf2(registry_async_client): assert sc_id >= 0 stored_schema, _ = await reg_cli.get_schema_for_id(sc_id) assert stored_schema == schema_protobuf, f"stored schema {stored_schema} is not {schema_protobuf}" - stored_id, stored_schema = await reg_cli.get_latest_schema(subject) + stored_id, stored_schema, _ = await reg_cli.get_schema(subject) assert stored_id == sc_id assert stored_schema == schema_protobuf_after diff --git a/tests/integration/test_rest_consumer_protobuf.py b/tests/integration/test_rest_consumer_protobuf.py index 4e0cb5284..dfa7278b5 100644 --- a/tests/integration/test_rest_consumer_protobuf.py +++ b/tests/integration/test_rest_consumer_protobuf.py @@ -2,13 +2,20 @@ Copyright (c) 2023 Aiven Ltd See LICENSE for details """ + +from karapace.client import Client +from karapace.kafka_rest_apis import KafkaRestAdminClient +from karapace.protobuf.kotlin_wrapper import trim_margin +from tests.integration.test_rest import NEW_TOPIC_TIMEOUT from tests.utils import ( new_consumer, + new_random_name, new_topic, repeat_until_successful_request, REST_HEADERS, schema_data, schema_data_second, + wait_for_topics, ) import pytest @@ -74,3 +81,184 @@ async def test_publish_consume_protobuf_second(rest_async_client, admin_client, data_values = [x["value"] for x in data] for expected, actual in zip(publish_payload, data_values): assert expected == actual, f"Expecting {actual} to be 
{expected}" + + +async def test_publish_protobuf_with_references( + rest_async_client: Client, + admin_client: KafkaRestAdminClient, + registry_async_client: Client, +): + topic_name = new_topic(admin_client) + subject_reference = "reference" + subject_topic = f"{topic_name}-value" + + await wait_for_topics(rest_async_client, topic_names=[topic_name], timeout=NEW_TOPIC_TIMEOUT, sleep=1) + + reference_schema = trim_margin( + """ + |syntax = "proto3"; + |message Reference { + | string name = 1; + |} + |""" + ) + + topic_schema = trim_margin( + """ + |syntax = "proto3"; + |import "Reference.proto"; + |message Example { + | Reference example = 1; + |} + |""" + ) + + res = await registry_async_client.post( + f"subjects/{subject_reference}/versions", json={"schemaType": "PROTOBUF", "schema": reference_schema} + ) + assert "id" in res.json() + + res = await registry_async_client.post( + f"subjects/{subject_topic}/versions", + json={ + "schemaType": "PROTOBUF", + "schema": topic_schema, + "references": [ + { + "name": "Reference.proto", + "subject": subject_reference, + "version": 1, + } + ], + }, + ) + topic_schema_id = res.json()["id"] + + example_message = {"value_schema_id": topic_schema_id, "records": [{"value": {"example": {"name": "myname"}}}]} + + res = await rest_async_client.post( + f"/topics/{topic_name}", + json=example_message, + headers=REST_HEADERS["avro"], + ) + assert res.status_code == 200 + + +async def test_publish_and_consume_protobuf_with_recursive_references( + rest_async_client: Client, + admin_client: KafkaRestAdminClient, + registry_async_client: Client, +): + topic_name = new_topic(admin_client) + subject_meta_reference = "meta-reference" + subject_inner_reference = "inner-reference" + subject_topic = f"{topic_name}-value" + + await wait_for_topics(rest_async_client, topic_names=[topic_name], timeout=NEW_TOPIC_TIMEOUT, sleep=1) + + meta_reference = trim_margin( + """ + |syntax = "proto3"; + |message MetaReference { + | string name = 1; + |} + 
|""" + ) + + inner_reference = trim_margin( + """ + |syntax = "proto3"; + |import "MetaReference.proto"; + |message InnerReference { + | MetaReference reference = 1; + } + |""" + ) + + topic_schema = trim_margin( + """ + |syntax = "proto3"; + |import "InnerReference.proto"; + |message Example { + | InnerReference example = 1; + |} + |""" + ) + + res = await registry_async_client.post( + f"subjects/{subject_meta_reference}/versions", json={"schemaType": "PROTOBUF", "schema": meta_reference} + ) + assert "id" in res.json() + res = await registry_async_client.post( + f"subjects/{subject_inner_reference}/versions", + json={ + "schemaType": "PROTOBUF", + "schema": inner_reference, + "references": [ + { + "name": "MetaReference.proto", + "subject": subject_meta_reference, + "version": 1, + } + ], + }, + ) + assert "id" in res.json() + + res = await registry_async_client.post( + f"subjects/{subject_topic}/versions", + json={ + "schemaType": "PROTOBUF", + "schema": topic_schema, + "references": [ + { + "name": "InnerReference.proto", + "subject": subject_inner_reference, + "version": 1, + } + ], + }, + ) + topic_schema_id = res.json()["id"] + + produced_message = {"example": {"reference": {"name": "myname"}}} + example_message = { + "value_schema_id": topic_schema_id, + "records": [{"value": produced_message}], + } + + res = await rest_async_client.post( + f"/topics/{topic_name}", + json=example_message, + headers=REST_HEADERS["avro"], + ) + assert res.status_code == 200 + + group = new_random_name("protobuf_recursive_reference_message") + instance_id = await new_consumer(rest_async_client, group) + + subscribe_path = f"/consumers/{group}/instances/{instance_id}/subscription" + + consume_path = f"/consumers/{group}/instances/{instance_id}/records?timeout=1000" + + res = await rest_async_client.post(subscribe_path, json={"topics": [topic_name]}, headers=REST_HEADERS["binary"]) + assert res.ok + + resp = await rest_async_client.get(consume_path, headers=REST_HEADERS["avro"]) 
+ data = resp.json() + + assert isinstance(data, list) + assert len(data) == 1 + + msg = data[0] + + assert "key" in msg + assert "offset" in msg + assert "topic" in msg + assert "value" in msg + assert "timestamp" in msg + + assert msg["key"] is None, "no key defined in production" + assert msg["offset"] == 0 and msg["partition"] == 0, "first message of the only partition available" + assert msg["topic"] == topic_name + assert msg["value"] == produced_message diff --git a/tests/unit/test_protobuf_serialization.py b/tests/unit/test_protobuf_serialization.py index c66cf50f7..3a5dc08f0 100644 --- a/tests/unit/test_protobuf_serialization.py +++ b/tests/unit/test_protobuf_serialization.py @@ -14,7 +14,7 @@ SchemaRegistrySerializer, START_BYTE, ) -from karapace.typing import Subject +from karapace.typing import ResolvedVersion, Subject from tests.utils import schema_protobuf, test_fail_objects_protobuf, test_objects_protobuf from unittest.mock import call, Mock @@ -43,8 +43,10 @@ async def test_happy_flow(default_config_path): ) mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)))) - mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future + get_latest_schema_future.set_result( + (1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)), ResolvedVersion(1)) + ) + mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) assert len(serializer.ids_to_schemas) == 0 @@ -56,7 +58,7 @@ async def test_happy_flow(default_config_path): assert len(serializer.ids_to_schemas) == 1 assert 1 in serializer.ids_to_schemas - assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top"), call.get_schema_for_id(1)] 
+ assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] async def test_happy_flow_references(default_config_path): @@ -111,8 +113,8 @@ async def test_happy_flow_references(default_config_path): schema_for_id_one_future.set_result((ref_schema, [Subject("stub")])) mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ref_schema)) - mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future + get_latest_schema_future.set_result((1, ref_schema, ResolvedVersion(1))) + mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) assert len(serializer.ids_to_schemas) == 0 @@ -124,7 +126,7 @@ async def test_happy_flow_references(default_config_path): assert len(serializer.ids_to_schemas) == 1 assert 1 in serializer.ids_to_schemas - assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top"), call.get_schema_for_id(1)] + assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] async def test_happy_flow_references_two(default_config_path): @@ -198,8 +200,8 @@ async def test_happy_flow_references_two(default_config_path): schema_for_id_one_future.set_result((ref_schema_two, [Subject("mock")])) mock_protobuf_registry_client.get_schema_for_id.return_value = schema_for_id_one_future get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ref_schema_two)) - mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future + get_latest_schema_future.set_result((1, ref_schema_two, ResolvedVersion(1))) + mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, 
mock_protobuf_registry_client) assert len(serializer.ids_to_schemas) == 0 @@ -211,28 +213,30 @@ async def test_happy_flow_references_two(default_config_path): assert len(serializer.ids_to_schemas) == 1 assert 1 in serializer.ids_to_schemas - assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top"), call.get_schema_for_id(1)] + assert mock_protobuf_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] async def test_serialization_fails(default_config_path): mock_protobuf_registry_client = Mock() get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)))) - mock_protobuf_registry_client.get_latest_schema.return_value = get_latest_schema_future + get_latest_schema_future.set_result( + (1, ParsedTypedSchema.parse(SchemaType.PROTOBUF, trim_margin(schema_protobuf)), ResolvedVersion(1)) + ) + mock_protobuf_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_protobuf_registry_client) with pytest.raises(InvalidMessageSchema): schema = await serializer.get_schema_for_subject("top") await serializer.serialize(schema, test_fail_objects_protobuf[0]) - assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top")] + assert mock_protobuf_registry_client.method_calls == [call.get_schema("top")] mock_protobuf_registry_client.reset_mock() with pytest.raises(InvalidMessageSchema): schema = await serializer.get_schema_for_subject("top") await serializer.serialize(schema, test_fail_objects_protobuf[1]) - assert mock_protobuf_registry_client.method_calls == [call.get_latest_schema("top")] + assert mock_protobuf_registry_client.method_calls == [call.get_schema("top")] async def test_deserialization_fails(default_config_path): diff --git a/tests/unit/test_serialization.py b/tests/unit/test_serialization.py index c76d4f4c2..28389a98f 
100644 --- a/tests/unit/test_serialization.py +++ b/tests/unit/test_serialization.py @@ -14,7 +14,7 @@ START_BYTE, write_value, ) -from karapace.typing import Subject +from karapace.typing import ResolvedVersion, Subject from tests.utils import schema_avro_json, test_objects_avro from unittest.mock import call, Mock @@ -42,21 +42,23 @@ async def make_ser_deser(config_path: str, mock_client) -> SchemaRegistrySeriali async def test_happy_flow(default_config_path): mock_registry_client = Mock() get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json))) - mock_registry_client.get_latest_schema.return_value = get_latest_schema_future + get_latest_schema_future.set_result( + (1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), ResolvedVersion(1)) + ) + mock_registry_client.get_schema.return_value = get_latest_schema_future schema_for_id_one_future = asyncio.Future() schema_for_id_one_future.set_result((ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), [Subject("stub")])) mock_registry_client.get_schema_for_id.return_value = schema_for_id_one_future serializer = await make_ser_deser(default_config_path, mock_registry_client) assert len(serializer.ids_to_schemas) == 0 - schema = await serializer.get_schema_for_subject("top") + schema = await serializer.get_schema_for_subject(Subject("top")) for o in test_objects_avro: assert o == await serializer.deserialize(await serializer.serialize(schema, o)) assert len(serializer.ids_to_schemas) == 1 assert 1 in serializer.ids_to_schemas - assert mock_registry_client.method_calls == [call.get_latest_schema("top"), call.get_schema_for_id(1)] + assert mock_registry_client.method_calls == [call.get_schema("top"), call.get_schema_for_id(1)] def test_flatten_unions_record() -> None: @@ -249,15 +251,17 @@ def test_avro_json_write_accepts_json_encoded_data_without_tagged_unions() -> No async def 
test_serialization_fails(default_config_path): mock_registry_client = Mock() get_latest_schema_future = asyncio.Future() - get_latest_schema_future.set_result((1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json))) - mock_registry_client.get_latest_schema.return_value = get_latest_schema_future + get_latest_schema_future.set_result( + (1, ValidatedTypedSchema.parse(SchemaType.AVRO, schema_avro_json), ResolvedVersion(1)) + ) + mock_registry_client.get_schema.return_value = get_latest_schema_future serializer = await make_ser_deser(default_config_path, mock_registry_client) with pytest.raises(InvalidMessageSchema): - schema = await serializer.get_schema_for_subject("topic") + schema = await serializer.get_schema_for_subject(Subject("topic")) await serializer.serialize(schema, {"foo": "bar"}) - assert mock_registry_client.method_calls == [call.get_latest_schema("topic")] + assert mock_registry_client.method_calls == [call.get_schema("topic")] async def test_deserialization_fails(default_config_path):