Skip to content

Commit

Permalink
Fixed inconsistent schema during message produce, from now on it is p…
Browse files Browse the repository at this point in the history
…ossible to produce a message only if the schema sent with the record is registered to the topic.
  • Loading branch information
eliax1996 committed Jul 3, 2023
1 parent 5693047 commit 3293378
Show file tree
Hide file tree
Showing 22 changed files with 571 additions and 94 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ venv
.run
.python-version
.hypothesis/
.DS_Store
2 changes: 1 addition & 1 deletion karapace/avro_dataclasses/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def optional_parser(parser: Parser | None) -> Parser | None:
return None

def parse(value: object) -> object:
return None if value is None else parser(value) # type: ignore[misc]
return None if value is None else parser(value)

return parse

Expand Down
2 changes: 2 additions & 0 deletions karapace/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ async def get(
json: JsonData = None,
headers: Optional[Headers] = None,
auth: Optional[BasicAuth] = None,
params: Optional[Mapping[str, str]] = None,
) -> Result:
path = self.path_for(path)
if not headers:
Expand All @@ -101,6 +102,7 @@ async def get(
headers=headers,
auth=auth,
ssl=self.ssl_mode,
params=params,
) as res:
# required for forcing the response body conversion to json despite missing valid Accept headers
json_result = await res.json(content_type=None)
Expand Down
15 changes: 13 additions & 2 deletions karapace/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from karapace.schema_references import Reference, Referents
from karapace.typing import ResolvedVersion, SchemaId, Subject
from threading import Lock, RLock
from typing import Iterable
from typing import Iterable, Sequence

import logging

Expand Down Expand Up @@ -111,7 +111,7 @@ def insert_schema_version(
version: ResolvedVersion,
deleted: bool,
schema: TypedSchema,
references: list[Reference] | None,
references: Sequence[Reference] | None,
) -> None:
with self.schema_lock_thread:
self.global_schema_id = max(self.global_schema_id, schema_id)
Expand Down Expand Up @@ -184,6 +184,17 @@ def find_schemas(self, *, include_deleted: bool, latest_only: bool) -> dict[Subj
res_schemas[subject] = selected_schemas
return res_schemas

def subjects_for_schema(self, schema_id: SchemaId) -> list[Subject]:
    """Return the subjects that have a non-deleted version using ``schema_id``.

    The scan over ``self.subjects`` runs while holding ``schema_lock_thread``.
    A subject is listed at most once, even if several of its versions match,
    and subjects appear in the iteration order of ``self.subjects``.
    """
    with self.schema_lock_thread:
        return [
            subject
            for subject, subject_data in self.subjects.items()
            # any() stops at the first matching version, like the explicit break did.
            if any(
                version.deleted is False and version.schema_id == schema_id
                for version in subject_data.schemas.values()
            )
        ]

def find_schema_versions_by_schema_id(self, *, schema_id: SchemaId, include_deleted: bool) -> list[SchemaVersion]:
schema_versions: list[SchemaVersion] = []
with self.schema_lock_thread:
Expand Down
106 changes: 89 additions & 17 deletions karapace/kafka_rest_apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
from karapace.kafka_rest_apis.admin import KafkaRestAdminClient
from karapace.kafka_rest_apis.consumer_manager import ConsumerManager
from karapace.kafka_rest_apis.error_codes import RESTErrorCodes
from karapace.kafka_rest_apis.schema_cache import TopicSchemaCache
from karapace.karapace import KarapaceBase
from karapace.rapu import HTTPRequest, HTTPResponse, JSON_CONTENT_TYPE
from karapace.schema_reader import SchemaType
from karapace.schema_models import TypedSchema, ValidatedTypedSchema
from karapace.schema_type import SchemaType
from karapace.serialization import InvalidMessageSchema, InvalidPayload, SchemaRegistrySerializer, SchemaRetrievalError
from karapace.typing import SchemaId, Subject
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import aiohttp.web
import asyncio
Expand Down Expand Up @@ -416,7 +419,7 @@ def __init__(self, config: Config, kafka_timeout: int, serializer):
self.admin_client = None
self.admin_lock = asyncio.Lock()
self.metadata_cache = None
self.schemas_cache = {}
self.topic_schema_cache = TopicSchemaCache()
self.consumer_manager = ConsumerManager(config=config, deserializer=self.serializer)
self.init_admin_client()
self._last_used = time.monotonic()
Expand Down Expand Up @@ -739,18 +742,83 @@ def is_valid_schema_request(data: dict, prefix: str) -> bool:
return False
return isinstance(schema, str)

async def get_schema_id(self, data: dict, topic: str, prefix: str, schema_type: SchemaType) -> int:
log.debug("Retrieving schema id for %r", data)
if f"{prefix}_schema_id" in data and data[f"{prefix}_schema_id"] is not None:
log.debug("Will use schema id %d for serializing %s on topic %s", data[f"{prefix}_schema_id"], prefix, topic)
return int(data[f"{prefix}_schema_id"])
schema_str = data[f"{prefix}_schema"]
log.debug("Registering / Retrieving ID for schema %s", schema_str)
if schema_str not in self.schemas_cache:
subject_name = self.serializer.get_subject_name(topic, data[f"{prefix}_schema"], prefix, schema_type)
schema_id = await self.serializer.get_id_for_schema(data[f"{prefix}_schema"], subject_name, schema_type)
self.schemas_cache[schema_str] = schema_id
return self.schemas_cache[schema_str]
async def get_schema_id(
    self,
    data: dict,
    topic: str,
    prefix: str,
    schema_type: SchemaType,
) -> SchemaId:
    """
    Resolve and validate the schema id to use for the ``prefix`` part of a request.

    This method acts as a guard: it either returns a usable schema id or raises.

    * If ``{prefix}_schema_id`` is missing or null, ``{prefix}_schema`` is parsed and its id
      is resolved through ``_query_schema_id_from_cache_or_registry``.
    * Otherwise the schema and its subjects are fetched for the given id, and the id is accepted
      only if the subject computed for ``topic`` is among those subjects.

    :param data: request body; ``{prefix}_schema_id`` and ``{prefix}_schema`` are read from it.
    :param topic: topic the record is produced to, used to compute the subject name.
    :param prefix: key prefix used to build the ``{prefix}_schema_id`` / ``{prefix}_schema`` names.
    :param schema_type: schema format used for parsing and for computing the subject name.
    :raises InvalidSchema: if neither an id nor a schema is supplied, if the schema for the id cannot
        be retrieved, or if the subject computed for ``topic`` is not among the subjects of that id.
    :raises ValueError: propagated from ``int()`` when the supplied id is not a valid integer.
    """
    log.debug("[resolve schema id] Retrieving schema id for %r", data)
    # A key that is present but explicitly null is handled like an absent key, so that
    # int() is never called on None (which would raise TypeError).
    raw_schema_id = data.get(f"{prefix}_schema_id")
    schema_id: Union[SchemaId, None] = SchemaId(int(raw_schema_id)) if raw_schema_id is not None else None
    schema_str = data.get(f"{prefix}_schema")

    if schema_id is None and schema_str is None:
        raise InvalidSchema()

    if schema_id is None:
        parsed_schema = ValidatedTypedSchema.parse(schema_type, schema_str)
        subject_name = self.serializer.get_subject_name(topic, parsed_schema, prefix, schema_type)
        schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
    else:

        def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
            # True when the subject computed for this topic is not one of ``subjects``.
            subject = self.serializer.get_subject_name(topic, schema, prefix, schema_type)
            return subject not in subjects

        # The same predicate is handed to the lookup as ``need_new_call`` and then
        # re-applied to the returned result as the final acceptance check.
        parsed_schema, valid_subjects = await self._query_schema_and_subjects(
            schema_id,
            need_new_call=subject_not_included,
        )

        if subject_not_included(parsed_schema, valid_subjects):
            raise InvalidSchema()

    return schema_id

async def _query_schema_and_subjects(
    self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]]
) -> Tuple[TypedSchema, List[Subject]]:
    """
    Fetch the schema and its subjects for ``schema_id`` through the serializer.

    :param schema_id: id of the schema to look up.
    :param need_new_call: predicate forwarded unchanged to ``serializer.get_schema_for_id``.
    :raises InvalidSchema: if the serializer raises ``SchemaRetrievalError``.
    """
    try:
        schema_and_subjects = await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call)
    except SchemaRetrievalError as retrieval_error:
        # A schema that cannot be retrieved is reported to the caller as an invalid schema.
        raise InvalidSchema() from retrieval_error
    return schema_and_subjects

async def _query_schema_id_from_cache_or_registry(
    self,
    parsed_schema: ValidatedTypedSchema,
    schema_str: str,
    subject_name: Subject,
) -> SchemaId:
    """
    Return the id of ``parsed_schema`` for ``subject_name``, preferring the local cache.

    On a cache hit the cached id is returned directly. On a miss the id is obtained with
    ``serializer.upsert_id_for_schema`` and then stored in ``topic_schema_cache``.

    :param parsed_schema: schema whose id is wanted.
    :param schema_str: textual form of the schema, used only in log messages.
    :param subject_name: subject under which the schema is looked up or upserted.
    """
    cached_id = self.topic_schema_cache.get_schema_id(subject_name, parsed_schema)
    if cached_id is not None:
        log.debug(
            "[resolve schema id] schema ID %s found from cache for %s and schema %s",
            cached_id,
            subject_name,
            schema_str,
        )
        return cached_id

    # Cache miss: ask the registry, then remember the answer for later requests.
    log.debug("[resolve schema id] Registering / Retrieving ID for %s and schema %s", subject_name, schema_str)
    registry_id = await self.serializer.upsert_id_for_schema(parsed_schema, subject_name)
    log.debug("[resolve schema id] Found schema id %s from registry for subject %s", registry_id, subject_name)
    self.topic_schema_cache.set_schema(subject_name, registry_id, parsed_schema)
    return registry_id

async def validate_schema_info(self, data: dict, prefix: str, content_type: str, topic: str, schema_type: str):
try:
Expand Down Expand Up @@ -788,10 +856,14 @@ async def validate_schema_info(self, data: dict, prefix: str, content_type: str,
status=HTTPStatus.REQUEST_TIMEOUT,
)
except InvalidSchema:
if f"{prefix}_schema" in data:
err = f'schema = {data[f"{prefix}_schema"]}'
else:
err = f'schema_id = {data[f"{prefix}_schema_id"]}'
KafkaRest.r(
body={
"error_code": RESTErrorCodes.INVALID_DATA.value,
"message": f'Invalid schema. format = {schema_type.value}, schema = {data[f"{prefix}_schema"]}',
"message": f"Invalid schema. format = {schema_type.value}, {err}",
},
content_type=content_type,
status=HTTPStatus.UNPROCESSABLE_ENTITY,
Expand Down Expand Up @@ -882,7 +954,7 @@ async def serialize(
raise FormatError(f"Unknown format: {ser_format}")

async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes:
schema = await self.serializer.get_schema_for_id(schema_id)
schema, _ = await self.serializer.get_schema_for_id(schema_id)
bytes_ = await self.serializer.serialize(schema, obj)
return bytes_

Expand Down
107 changes: 107 additions & 0 deletions karapace/kafka_rest_apis/schema_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Copyright (c) 2023 Aiven Ltd
See LICENSE for details
"""

from abc import ABC, abstractmethod
from cachetools import TTLCache
from karapace.schema_models import TypedSchema
from karapace.typing import SchemaId, Subject
from typing import Dict, Final, MutableMapping, Optional

import hashlib


class SchemaCacheProtocol(ABC):
    """Abstract interface of a cache that maps schema ids to schemas and back."""

    @abstractmethod
    def get_schema_id(self, schema: TypedSchema) -> Optional[SchemaId]:
        """Return the id cached for ``schema``, or ``None`` if there is none."""

    @abstractmethod
    def has_schema_id(self, schema_id: SchemaId) -> bool:
        """Tell whether ``schema_id`` is present in the cache."""

    @abstractmethod
    def set_schema(self, schema_id: SchemaId, schema: TypedSchema) -> None:
        """Store ``schema`` under ``schema_id``."""

    @abstractmethod
    def get_schema(self, schema_id: SchemaId) -> Optional[TypedSchema]:
        """Return the schema cached under ``schema_id``, or ``None`` if there is none."""

    @abstractmethod
    def get_schema_str(self, schema_id: SchemaId) -> Optional[str]:
        """Return the string form of the schema cached under ``schema_id``, or ``None``."""


class TopicSchemaCache:
    """
    Keeps one ``SchemaCache`` per key and routes every call to the cache of that key.

    Reads for a key that has no cache yet are answered by a shared ``EmptySchemaCache``,
    so they never create an entry. A per-key cache is created only by ``set_schema``.
    """

    def __init__(self) -> None:
        self._topic_cache: Dict[Subject, SchemaCache] = {}
        # Shared stand-in returned for keys that have nothing cached yet.
        self._empty_schema_cache: Final = EmptySchemaCache()

    def get_schema_id(self, topic: Subject, schema: TypedSchema) -> Optional[SchemaId]:
        """Return the id cached for ``schema`` under ``topic``, or ``None``."""
        return self._topic_cache.get(topic, self._empty_schema_cache).get_schema_id(schema)

    def has_schema_id(self, topic: Subject, schema_id: SchemaId) -> bool:
        """Tell whether ``schema_id`` is cached under ``topic``."""
        return self._topic_cache.get(topic, self._empty_schema_cache).has_schema_id(schema_id)

    def set_schema(self, topic: str, schema_id: SchemaId, schema: TypedSchema) -> None:
        """Store ``schema`` under ``schema_id`` in the cache of ``topic``, creating that cache if needed."""
        subject = Subject(topic)
        schema_cache = self._topic_cache.get(subject)
        if schema_cache is None:
            # Build the per-key cache only on first use: dict.setdefault would evaluate
            # SchemaCache() on every call, even when the key already has a cache.
            schema_cache = self._topic_cache[subject] = SchemaCache()
        schema_cache.set_schema(schema_id, schema)

    def get_schema(self, topic: Subject, schema_id: SchemaId) -> Optional[TypedSchema]:
        """Return the schema cached under ``schema_id`` for ``topic``, or ``None``."""
        schema_cache = self._topic_cache.get(topic, self._empty_schema_cache)
        return schema_cache.get_schema(schema_id)

    def get_schema_str(self, topic: Subject, schema_id: SchemaId) -> Optional[str]:
        """Return the string form of the schema cached under ``schema_id`` for ``topic``, or ``None``."""
        schema_cache = self._topic_cache.get(topic, self._empty_schema_cache)
        return schema_cache.get_schema_str(schema_id)


class SchemaCache(SchemaCacheProtocol):
    """
    Time-limited cache of schemas, indexed by schema id and by schema fingerprint.

    Both indexes are ``TTLCache`` instances with the same limits, so neither can grow
    without bound. The id index is authoritative: a fingerprint whose id is no longer
    in the id index is treated as a miss and dropped.
    """

    # Limits shared by both indexes.
    _MAX_ENTRIES = 10000
    _TTL_SECONDS = 600

    def __init__(self) -> None:
        # SHA-1 hex digest of str(schema) -> schema id.
        self._schema_hash_str_to_id: MutableMapping[str, SchemaId] = TTLCache(
            maxsize=self._MAX_ENTRIES, ttl=self._TTL_SECONDS
        )
        # schema id -> schema object (despite the attribute name, not a string).
        self._id_to_schema_str: MutableMapping[SchemaId, TypedSchema] = TTLCache(
            maxsize=self._MAX_ENTRIES, ttl=self._TTL_SECONDS
        )

    @staticmethod
    def _fingerprint(schema: TypedSchema) -> str:
        """Return the SHA-1 hex digest of ``str(schema)``, used as the lookup key for a schema."""
        return hashlib.sha1(str(schema).encode("utf8")).hexdigest()

    def get_schema_id(self, schema: TypedSchema) -> Optional[SchemaId]:
        """Return the id cached for ``schema``, or ``None`` if it is unknown or its entry is gone."""
        fingerprint = self._fingerprint(schema)

        maybe_id = self._schema_hash_str_to_id.get(fingerprint)

        if maybe_id is not None and maybe_id not in self._id_to_schema_str:
            # The schema itself is no longer cached: forget the stale fingerprint.
            # pop() with a default cannot raise if the entry expired in the meantime.
            self._schema_hash_str_to_id.pop(fingerprint, None)
            return None

        return maybe_id

    def has_schema_id(self, schema_id: SchemaId) -> bool:
        """Tell whether ``schema_id`` is currently cached."""
        return schema_id in self._id_to_schema_str

    def set_schema(self, schema_id: SchemaId, schema: TypedSchema) -> None:
        """Store ``schema`` under ``schema_id`` and index it by its fingerprint."""
        self._schema_hash_str_to_id[self._fingerprint(schema)] = schema_id
        self._id_to_schema_str[schema_id] = schema

    def get_schema(self, schema_id: SchemaId) -> Optional[TypedSchema]:
        """Return the schema cached under ``schema_id``, or ``None``."""
        return self._id_to_schema_str.get(schema_id)

    def get_schema_str(self, schema_id: SchemaId) -> Optional[str]:
        """Return ``str()`` of the schema cached under ``schema_id``, or ``None``."""
        maybe_schema = self.get_schema(schema_id)
        return None if maybe_schema is None else str(maybe_schema)


class EmptySchemaCache(SchemaCacheProtocol):
    """Cache that never holds anything: every lookup misses and storing is refused."""

    def get_schema_id(self, schema: TypedSchema) -> None:
        """Always miss."""

    def has_schema_id(self, schema_id: SchemaId) -> bool:
        """Always report that the id is absent."""
        return False

    def set_schema(self, schema_id: SchemaId, schema: TypedSchema) -> None:
        """Refuse to store anything.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError("Empty schema cache. Cannot set schemas.")

    def get_schema(self, schema_id: SchemaId) -> None:
        """Always miss."""

    def get_schema_str(self, schema_id: SchemaId) -> None:
        """Always miss."""
3 changes: 1 addition & 2 deletions karapace/schema_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ async def write_new_schema_local(

all_schema_versions = self.database.find_subject_schemas(subject=subject, include_deleted=True)
if not all_schema_versions:
version = 1
version = ResolvedVersion(1)
schema_id = self.database.get_schema_id(new_schema)
LOG.debug(
"Registering new subject: %r, id: %r with version: %r with schema %r, schema_id: %r",
Expand Down Expand Up @@ -407,7 +407,6 @@ async def write_new_schema_local(
# We didn't find an existing schema and the schema is compatible so go and create one
version = self.database.get_next_version(subject=subject)
schema_id = self.database.get_schema_id(new_schema)

LOG.debug(
"Registering subject: %r, id: %r new version: %r with schema %s, schema_id: %r",
subject,
Expand Down
12 changes: 8 additions & 4 deletions karapace/schema_registry_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ async def schemas_get(
self, content_type: str, *, request: HTTPRequest, user: User | None = None, schema_id: str
) -> None:
try:
schema_id_int = int(schema_id)
parsed_schema_id = SchemaId(int(schema_id))
except ValueError:
self.r(
body={
Expand All @@ -473,7 +473,7 @@ async def schemas_get(
)

fetch_max_id = request.query.get("fetchMaxId", "false").lower() == "true"
schema = self.schema_registry.schemas_get(schema_id_int, fetch_max_id=fetch_max_id)
schema = self.schema_registry.schemas_get(parsed_schema_id, fetch_max_id=fetch_max_id)

def _has_subject_with_id() -> bool:
schema_versions = self.schema_registry.database.find_schemas(include_deleted=True, latest_only=False)
Expand All @@ -482,7 +482,7 @@ def _has_subject_with_id() -> bool:
continue
for schema_version in schema_versions:
if (
schema_version.schema_id == schema_id_int
schema_version.schema_id == parsed_schema_id
and not schema_version.deleted
and self._auth is not None
and self._auth.check_authorization(user, Operation.Read, f"Subject:{subject}")
Expand All @@ -504,13 +504,17 @@ def _has_subject_with_id() -> bool:
content_type=content_type,
status=HTTPStatus.NOT_FOUND,
)
response_body = {"schema": schema.schema_str}

subjects = self.schema_registry.database.subjects_for_schema(parsed_schema_id)

response_body = {"schema": schema.schema_str, "subjects": subjects}
if schema.schema_type is not SchemaType.AVRO:
response_body["schemaType"] = schema.schema_type
if schema.references:
response_body["references"] = [r.to_dict() for r in schema.references]
if fetch_max_id:
response_body["maxId"] = schema.max_id

self.r(response_body, content_type)

async def schemas_get_versions(
Expand Down
Loading

0 comments on commit 3293378

Please sign in to comment.