feat: enable users to disable validation only on specific topics
eliax1996 committed Nov 10, 2023
1 parent 0a57f3b commit d6c1221
Showing 12 changed files with 359 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -464,7 +464,7 @@ Keys to take special care are the ones needed to configure Kafka and advertised_
- Name strategy to use when storing schemas from the kafka rest proxy service. You can opt between ``topic_name``, ``record_name`` and ``topic_record_name``
* - ``name_strategy_validation``
- ``true``
- If enabled, validate that given schema is registered under used name strategy when producing messages from Kafka Rest
- If enabled, validate that the given schema is registered under the expected subjects required by the specified name strategy when producing messages from Kafka Rest. Otherwise no validation is performed
* - ``master_election_strategy``
- ``lowest``
- Decides on what basis the Karapace cluster master is chosen (only relevant in a multi node setup)
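
For illustration, a minimal sketch of how the ``name_strategy_validation`` flag could appear in a Karapace configuration file; the surrounding keys are placeholders for this example, not part of this commit:

    {
        "bootstrap_uri": "localhost:9092",
        "name_strategy": "topic_name",
        "name_strategy_validation": true
    }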
12 changes: 11 additions & 1 deletion karapace/in_memory_database.py
@@ -9,7 +9,7 @@
from dataclasses import dataclass, field
from karapace.schema_models import SchemaVersion, TypedSchema
from karapace.schema_references import Reference, Referents
from karapace.typing import ResolvedVersion, SchemaId, Subject
from karapace.typing import ResolvedVersion, SchemaId, Subject, TopicName
from threading import Lock, RLock
from typing import Iterable, Sequence

@@ -32,6 +32,7 @@ def __init__(self) -> None:
self.schemas: dict[SchemaId, TypedSchema] = {}
self.schema_lock_thread = RLock()
self.referenced_by: dict[tuple[Subject, ResolvedVersion], Referents] = {}
self.topic_without_validation: set[TopicName] = set()

# Content based deduplication of schemas. This is used to reduce memory
        # usage when the same schema is produced multiple times to the same or
@@ -229,6 +230,15 @@ def find_subject_schemas(self, *, subject: Subject, include_deleted: bool) -> di
if schema_version.deleted is False
}

def is_topic_requiring_validation(self, *, topic_name: TopicName) -> bool:
return topic_name not in self.topic_without_validation

def override_topic_validation(self, *, topic_name: TopicName, skip_validation: bool) -> None:
if skip_validation:
self.topic_without_validation.add(topic_name)
else:
self.topic_without_validation.discard(topic_name)

def delete_subject(self, *, subject: Subject, version: ResolvedVersion) -> None:
with self.schema_lock_thread:
for schema_version in self.subjects[subject].schemas.values():
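
To make the new toggle concrete, here is a small hypothetical usage sketch. It assumes the containing class is InMemoryDatabase, as the module name suggests, and the topic name is invented:

    from karapace.in_memory_database import InMemoryDatabase
    from karapace.typing import TopicName

    db = InMemoryDatabase()
    topic = TopicName("orders")
    assert db.is_topic_requiring_validation(topic_name=topic)      # validation is the default
    db.override_topic_validation(topic_name=topic, skip_validation=True)
    assert not db.is_topic_requiring_validation(topic_name=topic)  # topic is now opted out
    db.override_topic_validation(topic_name=topic, skip_validation=False)
    assert db.is_topic_requiring_validation(topic_name=topic)      # opt-out revoked

Storing only the opted-out topics keeps the set small, and absence means "validate", which preserves the conservative default.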
105 changes: 83 additions & 22 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from aiokafka import AIOKafkaProducer
from aiokafka.errors import KafkaConnectionError
from binascii import Error as B64DecodeError
@@ -32,20 +34,21 @@
get_subject_name,
InvalidMessageSchema,
InvalidPayload,
SchemaRegistryClient,
SchemaRegistrySerializer,
SchemaRetrievalError,
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType, TopicName
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Final, MutableMapping, NewType

import asyncio
import base64
import datetime
import logging
import time

SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value]
SUBJECT_VALID_POSTFIX = [SubjectType.key, SubjectType.value_]
PUBLISH_KEYS = {"records", "value_schema", "value_schema_id", "key_schema", "key_schema_id"}
RECORD_CODES = [42201, 42202]
KNOWN_FORMATS = {"json", "avro", "protobuf", "binary"}
@@ -67,10 +70,10 @@ def __init__(self, config: Config) -> None:
super().__init__(config=config)
self._add_kafka_rest_routes()
self.serializer = SchemaRegistrySerializer(config=config)
self.proxies: Dict[str, "UserRestProxy"] = {}
self.proxies: dict[str, UserRestProxy] = {}
self._proxy_lock = asyncio.Lock()
log.info("REST proxy starting with (delegated authorization=%s)", self.config.get("rest_authorization", False))
self._idle_proxy_janitor_task: Optional[asyncio.Task] = None
self._idle_proxy_janitor_task: asyncio.Task | None = None

async def close(self) -> None:
if self._idle_proxy_janitor_task is not None:
@@ -419,13 +422,56 @@ async def topic_publish(self, topic: str, content_type: str, *, request: HTTPReq
await proxy.topic_publish(topic, content_type, request=request)


LastTimeCheck = NewType("LastTimeCheck", float)

DEFAULT_CACHE_INTERVAL_NS: Final = 120 * 1_000_000_000 # 120 seconds


class ValidationCheckWrapper:
def __init__(
self,
registry_client: SchemaRegistryClient,
topic_name: TopicName,
cache_interval_ns: float = DEFAULT_CACHE_INTERVAL_NS,
):
self._last_check = 0
        # be conservative by default: require validation unless the registry says otherwise
self._require_validation = True
self._topic_name = topic_name
self._registry_client = registry_client
self._cache_interval_ns = cache_interval_ns

async def _query_registry(self) -> bool:
require_validation = await self._registry_client.topic_require_validation(self._topic_name)
return require_validation

async def require_validation(self) -> bool:
if (time.monotonic_ns() - self._last_check) > self._cache_interval_ns:
self._require_validation = await self._query_registry()
self._last_check = time.monotonic_ns()

return self._require_validation

@classmethod
async def construct_new(
cls,
registry_client: SchemaRegistryClient,
topic_name: TopicName,
cache_interval_ns: float = DEFAULT_CACHE_INTERVAL_NS,
) -> ValidationCheckWrapper:
validation_checker = cls(registry_client, topic_name, cache_interval_ns)
validation_checker._require_validation = await validation_checker._query_registry()
validation_checker._last_check = time.monotonic_ns()
return validation_checker
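
The wrapper caches the registry's answer and re-queries only once the cache interval (120 seconds by default, measured with time.monotonic_ns) has elapsed, so the hot produce path does not pay a registry round-trip per message. A hypothetical caller, assuming registry_client is an existing SchemaRegistryClient:

    checker = await ValidationCheckWrapper.construct_new(registry_client, TopicName("orders"))
    if await checker.require_validation():
        ...  # resolve and validate the subject before producing

The construct_new classmethod exists so the first registry lookup can be awaited eagerly; __init__ cannot be async, and without it the wrapper would report the conservative default until the first require_validation() call.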


class UserRestProxy:
def __init__(
self,
config: Config,
kafka_timeout: int,
serializer: SchemaRegistrySerializer,
auth_expiry: Optional[datetime.datetime] = None,
auth_expiry: datetime.datetime | None = None,
):
self.config = config
self.kafka_timeout = kafka_timeout
@@ -444,8 +490,18 @@ def __init__(
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AIOKafkaProducer] = None
self._async_producer: AIOKafkaProducer | None = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])
        self.topic_validation: MutableMapping[TopicName, ValidationCheckWrapper] = {}

async def is_validation_required(self, topic_name: TopicName) -> bool:
if topic_name not in self.topic_validation:
self.topic_validation[topic_name] = await ValidationCheckWrapper.construct_new(
self.serializer.registry_client,
topic_name,
)

return await self.topic_validation[topic_name].require_validation()

def __str__(self) -> str:
return f"UserRestProxy(username={self.config['sasl_plain_username']})"
@@ -605,7 +661,7 @@ async def get_topic_config(self, topic: str) -> dict:
async with self.admin_lock:
return self.admin_client.get_topic_config(topic)

async def cluster_metadata(self, topics: Optional[List[str]] = None) -> dict:
async def cluster_metadata(self, topics: list[str] | None = None) -> dict:
async with self.admin_lock:
if self._metadata_birth is None or time.monotonic() - self._metadata_birth > self.metadata_max_age:
self._cluster_metadata = None
@@ -678,7 +734,7 @@ async def aclose(self) -> None:
self.admin_client = None
self.consumer_manager = None

async def publish(self, topic: str, partition_id: Optional[str], content_type: str, request: HTTPRequest) -> None:
async def publish(self, topic: str, partition_id: str | None, content_type: str, request: HTTPRequest) -> None:
formats: dict = request.content_type
data: dict = request.json
_ = await self.get_topic_info(topic, content_type)
@@ -776,7 +832,7 @@ async def get_schema_id(
:raises InvalidSchema:
"""
log.debug("[resolve schema id] Retrieving schema id for %r", data)
schema_id: Union[SchemaId, None] = (
schema_id: SchemaId | None = (
SchemaId(int(data[f"{subject_type}_schema_id"])) if f"{subject_type}_schema_id" in data else None
)
schema_str = data.get(f"{subject_type}_schema")
@@ -795,8 +851,9 @@
)
schema_id = await self._query_schema_id_from_cache_or_registry(parsed_schema, schema_str, subject_name)
else:
is_validation_required = await self.is_validation_required(topic_name=TopicName(topic))

def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
def subject_not_included(schema: TypedSchema, subjects: list[Subject]) -> bool:
subject = get_subject_name(topic, schema, subject_type, self.naming_strategy)
return subject not in subjects

@@ -805,14 +862,18 @@ def subject_not_included(schema: TypedSchema, subjects: List[Subject]) -> bool:
need_new_call=subject_not_included,
)

if self.config["name_strategy_validation"] and subject_not_included(parsed_schema, valid_subjects):
if (
self.config["name_strategy_validation"]
and is_validation_required
and subject_not_included(parsed_schema, valid_subjects)
):
raise InvalidSchema()

return schema_id
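
Putting the pieces together: a publish is rejected only when all three guards hold, so the new per-topic switch acts as an additional veto on top of the existing global flag. A condensed restatement of the gate above, not new behaviour:

    reject = (
        self.config["name_strategy_validation"]                   # global flag still enabled
        and is_validation_required                                # topic has not been opted out
        and subject_not_included(parsed_schema, valid_subjects)   # schema not under the expected subject
    )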

async def _query_schema_and_subjects(
self, schema_id: SchemaId, *, need_new_call: Optional[Callable[[TypedSchema, List[Subject]], bool]]
) -> Tuple[TypedSchema, List[Subject]]:
self, schema_id: SchemaId, *, need_new_call: Callable[[TypedSchema, list[Subject]], bool] | None
) -> tuple[TypedSchema, list[Subject]]:
try:
return await self.serializer.get_schema_for_id(schema_id, need_new_call=need_new_call)
except SchemaRetrievalError as schema_error:
@@ -903,10 +964,10 @@ async def _prepare_records(
content_type: str,
data: dict,
ser_format: str,
key_schema_id: Optional[int],
value_schema_id: Optional[int],
default_partition: Optional[int] = None,
) -> List[Tuple]:
key_schema_id: int | None,
value_schema_id: int | None,
default_partition: int | None = None,
) -> list[tuple]:
prepared_records = []
for record in data["records"]:
key = record.get("key")
@@ -957,8 +1018,8 @@ async def serialize(
self,
content_type: str,
obj=None,
ser_format: Optional[str] = None,
schema_id: Optional[int] = None,
ser_format: str | None = None,
schema_id: int | None = None,
) -> bytes:
if not obj:
return b""
@@ -982,7 +1043,7 @@
return await self.schema_serialize(obj, schema_id)
raise FormatError(f"Unknown format: {ser_format}")

async def schema_serialize(self, obj: dict, schema_id: Optional[int]) -> bytes:
async def schema_serialize(self, obj: dict, schema_id: int | None) -> bytes:
schema, _ = await self.serializer.get_schema_for_id(schema_id)
bytes_ = await self.serializer.serialize(schema, obj)
return bytes_
@@ -1045,7 +1106,7 @@ async def validate_publish_request_format(self, data: dict, formats: dict, conte
sub_code=RESTErrorCodes.INVALID_DATA.value,
)

async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
async def produce_messages(self, *, topic: str, prepared_records: list) -> list:
producer = await self._maybe_create_async_producer()

produce_futures = []
13 changes: 10 additions & 3 deletions karapace/schema_reader.py
@@ -8,7 +8,6 @@

from avro.schema import Schema as AvroSchema
from contextlib import closing, ExitStack
from enum import Enum
from jsonschema.validators import Draft7Validator
from kafka import KafkaConsumer, TopicPartition
from kafka.admin import KafkaAdminClient, NewTopic
@@ -32,7 +31,7 @@
from karapace.schema_models import parse_protobuf_schema_definition, SchemaType, TypedSchema, ValidatedTypedSchema
from karapace.schema_references import LatestVersionReference, Reference, reference_from_mapping, Referents
from karapace.statsd import StatsClient
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, StrEnum, Subject, TopicName
from karapace.utils import json_decode, JSONDecodeError, KarapaceKafkaClient
from threading import Event, Thread
from typing import Final, Mapping, Sequence
@@ -59,10 +58,11 @@
METRIC_SUBJECT_DATA_SCHEMA_VERSIONS_GAUGE: Final = "karapace_schema_reader_subject_data_schema_versions"


class MessageType(Enum):
class MessageType(StrEnum):
config = "CONFIG"
schema = "SCHEMA"
delete_subject = "DELETE_SUBJECT"
schema_validation = "SCHEMA_VALIDATION"
no_operation = "NOOP"


@@ -437,6 +437,11 @@ def _handle_msg_delete_subject(self, key: dict, value: dict | None) -> None: #
LOG.info("Deleting subject: %r, value: %r", subject, value)
self.database.delete_subject(subject=subject, version=version)

def _handle_msg_schema_validation(self, key: dict, value: dict | None) -> None: # pylint: disable=unused-argument
assert isinstance(value, dict)
topic, skip_validation = TopicName(value["topic"]), bool(value["skip_validation"])
self.database.override_topic_validation(topic_name=topic, skip_validation=skip_validation)

def _handle_msg_schema_hard_delete(self, key: dict) -> None:
subject, version = key["subject"], key["version"]

@@ -540,6 +545,8 @@ def handle_msg(self, key: dict, value: dict | None) -> None:
self._handle_msg_schema(key, value)
elif message_type == MessageType.delete_subject:
self._handle_msg_delete_subject(key, value)
elif message_type == MessageType.schema_validation:
self._handle_msg_schema_validation(key, value)
elif message_type == MessageType.no_operation:
pass
except ValueError:
17 changes: 15 additions & 2 deletions karapace/schema_registry.py
@@ -27,9 +27,9 @@
from karapace.messaging import KarapaceProducer
from karapace.offset_watcher import OffsetWatcher
from karapace.schema_models import ParsedTypedSchema, SchemaType, SchemaVersion, TypedSchema, ValidatedTypedSchema
from karapace.schema_reader import KafkaSchemaReader
from karapace.schema_reader import KafkaSchemaReader, MessageType
from karapace.schema_references import LatestVersionReference, Reference
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, Version
from karapace.typing import JsonObject, ResolvedVersion, SchemaId, Subject, TopicName, Version
from typing import Mapping, Sequence

import asyncio
@@ -466,6 +466,19 @@ def send_schema_message(
value = None
self.producer.send_message(key=key, value=value)

def is_topic_requiring_validation(self, *, topic_name: TopicName) -> bool:
return self.database.is_topic_requiring_validation(topic_name=topic_name)

def update_require_validation_for_topic(
self,
*,
topic_name: TopicName,
skip_validation: bool,
) -> None:
key = {"topic": topic_name, "keytype": str(MessageType.schema_validation), "magic": 0}
value = {"skip_validation": skip_validation, "topic": topic_name}
self.producer.send_message(key=key, value=value)
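
Concretely, opting a topic out of validation publishes a record along these lines to the schemas topic (the topic name and values are illustrative):

    key   = {"topic": "orders", "keytype": "SCHEMA_VALIDATION", "magic": 0}
    value = {"skip_validation": True, "topic": "orders"}

On replay, _handle_msg_schema_validation in schema_reader.py consumes the record and calls override_topic_validation, so every node's in-memory database converges on the same per-topic setting.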

def send_config_message(self, compatibility_level: CompatibilityModes, subject: Subject | None = None) -> None:
key = {"subject": subject, "magic": 0, "keytype": "CONFIG"}
value = {"compatibilityLevel": compatibility_level.value}