Skip to content

Commit

Permalink
Merge pull request #664 from aiven/eliax1996/fix-inconsistent-schema-…
Browse files Browse the repository at this point in the history
…publish

fix: inconsistent schema during message produce
  • Loading branch information
giuseppelillo committed Jul 3, 2023
2 parents 6a4b9fe + ca65ddd commit ca96e88
Show file tree
Hide file tree
Showing 22 changed files with 487 additions and 105 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ venv
.run
.python-version
.hypothesis/
.DS_Store
2 changes: 1 addition & 1 deletion karapace/avro_dataclasses/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def optional_parser(parser: Parser | None) -> Parser | None:
return None

def parse(value: object) -> object:
return None if value is None else parser(value) # type: ignore[misc]
return None if value is None else parser(value)

return parse

Expand Down
2 changes: 2 additions & 0 deletions karapace/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def get(
json: JsonData = None,
headers: Optional[Headers] = None,
auth: Optional[BasicAuth] = None,
params: Optional[Mapping[str, str]] = None,
) -> Result:
path = self.path_for(path)
if not headers:
Expand All @@ -101,6 +102,7 @@ async def get(
headers=headers,
auth=auth,
ssl=self.ssl_mode,
params=params,
) as res:
# required for forcing the response body conversion to json despite missing valid Accept headers
json_result = await res.json(content_type=None)
Expand Down
15 changes: 13 additions & 2 deletions karapace/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from karapace.schema_references import Reference, Referents
from karapace.typing import ResolvedVersion, SchemaId, Subject
from threading import Lock, RLock
from typing import Iterable
from typing import Iterable, Sequence

import logging

Expand Down Expand Up @@ -111,7 +111,7 @@ def insert_schema_version(
version: ResolvedVersion,
deleted: bool,
schema: TypedSchema,
references: list[Reference] | None,
references: Sequence[Reference] | None,
) -> None:
with self.schema_lock_thread:
self.global_schema_id = max(self.global_schema_id, schema_id)
Expand Down Expand Up @@ -184,6 +184,17 @@ def find_schemas(self, *, include_deleted: bool, latest_only: bool) -> dict[Subj
res_schemas[subject] = selected_schemas
return res_schemas

def subjects_for_schema(self, schema_id: SchemaId) -> list[Subject]:
subjects = []
with self.schema_lock_thread:
for subject, subject_data in self.subjects.items():
for version in subject_data.schemas.values():
if version.deleted is False and version.schema_id == schema_id:
subjects.append(subject)
break

return subjects

def find_schema_versions_by_schema_id(self, *, schema_id: SchemaId, include_deleted: bool) -> list[SchemaVersion]:
schema_versions: list[SchemaVersion] = []
with self.schema_lock_thread:
Expand Down
106 changes: 89 additions & 17 deletions karapace/kafka_rest_apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
from karapace.kafka_rest_apis.admin import KafkaRestAdminClient
from karapace.kafka_rest_apis.consumer_manager import ConsumerManager
from karapace.kafka_rest_apis.error_codes import RESTErrorCodes
from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache
from karapace.karapace import KarapaceBase
from karapace.rapu import HTTPRequest, HTTPResponse, JSON_CONTENT_TYPE
from karapace.schema_reader import SchemaType
from karapace.schema_models import TypedSchema, ValidatedTypedSchema
from karapace.schema_type import SchemaType
from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError
from karapace.typing import SchemaId, Subject
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import aiohttp.web
import asyncio
Expand Down Expand Up @@ -416,7 +419,7 @@ def __init__(self, config: Config, kafka_timeout: int, serializer):
self.admin_client = None
self.admin_lock = asyncio.Lock()
self.metadata_cache = None
self.schemas_cache = {}
self.topic_schema_cache = TopicSchemaCache()
self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer)
self.init_admin_client()
self._last_used = time.monotonic()
Expand Down Expand Up @@ -739,18 +742,83 @@ def is_valid_schema_request(data: dict, prefix: str) -> bool:
return False
return isinstance(schema, str)

async def get_schema_id(self, data: dict, topic: str, prefix: str, schema_type: SchemaType) -> int:
log.debug("Retrieving schema id for %r", data)
if f"{prefix}_schema_id" in data and data[f"{prefix}_schema_id"] is not None:
log.debug("Will use schema id %d for serializing %s on topic %s", data[f"{prefix}_schema_id"], prefix, topic)
return int(data[f"{prefix}_schema_id"])
schema_str = data[f"{prefix}_schema"]
log.debug("Registering / Retrieving ID for schema %s", schema_str)
if schema_str not in self.schemas_cache:
subject_name = self.serializer.get_subject_name(topic, data[f"{prefix}_schema"], prefix, schema_type)
schema_id = await self.serializer.get_id_for_schema(data[f"{prefix}_schema"], subject_name, schema_type)
self.schemas_cache[schema_str] = schema_id
return self.schemas_cache[schema_str]
async def get_schema_id(
self,
data: dict,
topic: str,
prefix: str,
schema_type: SchemaType,
) -> SchemaId:
"""
This method search and validate the SchemaId for a request, it acts as a guard (In case of something wrong
throws an error).
:raises InvalidSchema:
"""
log.debug("[resolve schema id] Retrieving schema id for %r", data)
schema_id: Union[SchemaId, None] = (
SchemaId(int(data[f"{prefix}_schema_id"])) if f"{prefix}_schema_id" in data else None
)
schema_str = data.get(f"{prefix}_schema")

if schema_id is None and schema_str is None:
raise InvalidSchema()

if schema_id is None:
parsed_schema = ValidatedTypedSchema.parse(schema_type, schema_str)
subject_name = self.serializer.get_subject_name(topic, parsed_schema, prefix, schema_type)
schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
else:

def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
subject = self.serializer.get_subject_name(topic, schema, prefix, schema_type)
return subject not in subjects

parsed_schema, valid_subjects = await self._query_schema_and_subjects(
schema_id,
need_new_call=subject_not_included,
)

if subject_not_included(parsed_schema, valid_subjects):
raise InvalidSchema()

return schema_id

async def _query_schema_and_subjects(
self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]]
) -> Tuple[TypedSchema, List[Subject]]:
try:
return await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call)
except SchemaRetrievalError as schema_error:
# if the schema doesn't exist we treated as if the error was due to an invalid schema
raise InvalidSchema() from schema_error

async def _query_schema_id_from_cache_or_registry(
self,
parsed_schema: ValidatedTypedSchema,
schema_str: str,
subject_name: Subject,
) -> SchemaId:
"""
Checks if the schema registered with a certain id match with the schema provided (you can provide
a valid id but place in the body a totally unrelated schema).
Also, here if we don't have a match we query the registry since the cache could be evicted in the meanwhile
or the schema could be registered without passing though the http proxy.
"""
schema_id = self.topic_schema_cache.get_schema_id(subject_name, parsed_schema)
if schema_id is None:
log.debug("[resolve schema id] Registering / Retrieving ID for %s and schema %s", subject_name, schema_str)
schema_id = await self.serializer.upsert_id_for_schema(parsed_schema, subject_name)
log.debug("[resolve schema id] Found schema id %s from registry for subject %s", schema_id, subject_name)
self.topic_schema_cache.set_schema(subject_name, schema_id, parsed_schema)
else:
log.debug(
"[resolve schema id] schema ID %s found from cache for %s and schema %s",
schema_id,
subject_name,
schema_str,
)
return schema_id

async def validate_schema_info(self, data: dict, prefix: str, content_type: str, topic: str, schema_type: str):
try:
Expand Down Expand Up @@ -788,10 +856,14 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str,
status=HTTPStatus.REQUEST_TIMEOUT,
)
except InvalidSchema:
if f"{prefix}_schema" in data:
err = f'schema = {data[f"{prefix}_schema"]}'
else:
err = f'schema_id = {data[f"{prefix}_schema_id"]}'
KafkaRest.r(
body={
"error_code": RESTErrorCodes.INVALID_DATA.value,
"message": f'Invalid schema. format = {schema_type.value}, schema = {data[f"{prefix}_schema"]}',
"message": f"Invalid schema. format = {schema_type.value}, {err}",
},
content_type=content_type,
status=HTTPStatus.UNPROCESSABLE_ENTITY,
Expand Down Expand Up @@ -882,7 +954,7 @@ async def serialize(
raise FormatError(f"Unknown format: {ser_format}")

async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes:
schema = await self.serializer.get_schema_for_id(schema_id)
schema, _ = await self.serializer.get_schema_for_id(schema_id)
bytes_ = await self.serializer.serialize(schema, obj)
return bytes_

Expand Down
107 changes: 107 additions & 0 deletions karapace/kafka_rest_apis/schema_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""

from abc import ABC, abstractmethod
from cachetools import TTLCache
from karapace.schema_models import TypedSchema
from karapace.typing import SchemaId, Subject
from typing import Dict, Final, MutableMapping, Optional

import hashlib


class SchemaCacheProtocol(ABC):
@abstractmethod
def get_schema_id(self, schema: TypedSchema) -> Optional[SchemaId]:
pass

@abstractmethod
def has_schema_id(self, schema_id: SchemaId) -> bool:
pass

@abstractmethod
def set_schema(self, schema_id: SchemaId, schema: TypedSchema) -> None:
pass

@abstractmethod
def get_schema(self, schema_id: SchemaId) -> Optional[TypedSchema]:
pass

@abstractmethod
def get_schema_str(self, schema_id: SchemaId) -> Optional[str]:
pass


class TopicSchemaCache:
def __init__(self) -> None:
self._topic_cache: Dict[Subject, SchemaCache] = {}
self._empty_schema_cache: Final = EmptySchemaCache()

def get_schema_id(self, topic: Subject, schema: TypedSchema) -> Optional[SchemaId]:
return self._topic_cache.get(topic, self._empty_schema_cache).get_schema_id(schema)

def has_schema_id(self, topic: Subject, schema_id: SchemaId) -> bool:
return self._topic_cache.get(topic, self._empty_schema_cache).has_schema_id(schema_id)

def set_schema(self, topic: str, schema_id: SchemaId, schema: TypedSchema) -> None:
schema_cache_with_defaults = self._topic_cache.setdefault(Subject(topic), SchemaCache())
schema_cache_with_defaults.set_schema(schema_id, schema)

def get_schema(self, topic: Subject, schema_id: SchemaId) -> Optional[TypedSchema]:
schema_cache = self._topic_cache.get(topic, self._empty_schema_cache)
return schema_cache.get_schema(schema_id)

def get_schema_str(self, topic: Subject, schema_id: SchemaId) -> Optional[str]:
schema_cache = self._topic_cache.get(topic, self._empty_schema_cache)
return schema_cache.get_schema_str(schema_id)


class SchemaCache(SchemaCacheProtocol):
def __init__(self) -> None:
self._schema_hash_str_to_id: Dict[str, SchemaId] = {}
self._id_to_schema_str: MutableMapping[SchemaId, TypedSchema] = TTLCache(maxsize=10000, ttl=600)

def get_schema_id(self, schema: TypedSchema) -> Optional[SchemaId]:
fingerprint = hashlib.sha1(str(schema).encode("utf8")).hexdigest()

maybe_id = self._schema_hash_str_to_id.get(fingerprint)

if maybe_id is not None and maybe_id not in self._id_to_schema_str:
del self._schema_hash_str_to_id[fingerprint]
return None

return maybe_id

def has_schema_id(self, schema_id: SchemaId) -> bool:
return schema_id in self._id_to_schema_str

def set_schema(self, schema_id: SchemaId, schema: TypedSchema) -> None:
fingerprint = hashlib.sha1(str(schema).encode("utf8")).hexdigest()
self._schema_hash_str_to_id[fingerprint] = schema_id
self._id_to_schema_str[schema_id] = schema

def get_schema(self, schema_id: SchemaId) -> Optional[TypedSchema]:
return self._id_to_schema_str.get(schema_id)

def get_schema_str(self, schema_id: SchemaId) -> Optional[str]:
maybe_schema = self.get_schema(schema_id)
return None if maybe_schema is None else str(maybe_schema)


class EmptySchemaCache(SchemaCacheProtocol):
def get_schema_id(self, schema: TypedSchema) -> None:
return None

def has_schema_id(self, schema_id: SchemaId) -> bool:
return False

def set_schema(self, schema_id: SchemaId, schema: TypedSchema) -> None:
raise NotImplementedError("Empty schema cache. Cannot set schemas.")

def get_schema(self, schema_id: SchemaId) -> None:
return None

def get_schema_str(self, schema_id: SchemaId) -> None:
return None
3 changes: 1 addition & 2 deletions karapace/schema_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ async def write_new_schema_local(

all_schema_versions = self.database.find_subject_schemas(subject=subject, include_deleted=True)
if not all_schema_versions:
version = 1
version = ResolvedVersion(1)
schema_id = self.database.get_schema_id(new_schema)
LOG.debug(
"Registering new subject: %r, id: %r with version: %r with schema %r, schema_id: %r",
Expand Down Expand Up @@ -407,7 +407,6 @@ async def write_new_schema_local(
# We didn't find an existing schema and the schema is compatible so go and create one
version = self.database.get_next_version(subject=subject)
schema_id = self.database.get_schema_id(new_schema)

LOG.debug(
"Registering subject: %r, id: %r new version: %r with schema %s, schema_id: %r",
subject,
Expand Down
12 changes: 8 additions & 4 deletions karapace/schema_registry_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ async def schemas_get(
self, content_type: str, *, request: HTTPRequest, user: User | None = None, schema_id: str
) -> None:
try:
schema_id_int = int(schema_id)
parsed_schema_id = SchemaId(int(schema_id))
except ValueError:
self.r(
body={
Expand All @@ -473,7 +473,7 @@ async def schemas_get(
)

fetch_max_id = request.query.get("fetchMaxId", "false").lower() == "true"
schema = self.schema_registry.schemas_get(schema_id_int, fetch_max_id=fetch_max_id)
schema = self.schema_registry.schemas_get(parsed_schema_id, fetch_max_id=fetch_max_id)

def _has_subject_with_id() -> bool:
schema_versions = self.schema_registry.database.find_schemas(include_deleted=True, latest_only=False)
Expand All @@ -482,7 +482,7 @@ def _has_subject_with_id() -> bool:
continue
for schema_version in schema_versions:
if (
schema_version.schema_id == schema_id_int
schema_version.schema_id == parsed_schema_id
and not schema_version.deleted
and self._auth is not None
and self._auth.check_authorization(user, Operation.Read, f"Subject:{subject}")
Expand All @@ -504,13 +504,17 @@ def _has_subject_with_id() -> bool:
content_type=content_type,
status=HTTPStatus.NOT_FOUND,
)
response_body = {"schema": schema.schema_str}

subjects = self.schema_registry.database.subjects_for_schema(parsed_schema_id)

response_body = {"schema": schema.schema_str, "subjects": subjects}
if schema.schema_type is not SchemaType.AVRO:
response_body["schemaType"] = schema.schema_type
if schema.references:
response_body["references"] = [r.to_dict() for r in schema.references]
if fetch_max_id:
response_body["maxId"] = schema.max_id

self.r(response_body, content_type)

async def schemas_get_versions(
Expand Down
Loading

0 comments on commit ca96e88

Please sign in to comment.