Skip to content

Commit

Permalink
Fixed compatibility with REST Prokxy
Browse files Browse the repository at this point in the history
  • Loading branch information
blootsvoets committed Sep 20, 2023
1 parent 1180882 commit 9622000
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
package org.radarbase.producer.io

import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.util.cio.*
import io.ktor.utils.io.*
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpClientPlugin
import io.ktor.client.request.HttpRequestPipeline
import io.ktor.http.ContentType
import io.ktor.http.Headers
import io.ktor.http.HeadersBuilder
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpMethod
import io.ktor.http.content.OutgoingContent
import io.ktor.http.contentLength
import io.ktor.util.AttributeKey
import io.ktor.util.KtorDsl
import io.ktor.util.cio.use
import io.ktor.util.deflated
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.ByteWriteChannel
import kotlinx.coroutines.coroutineScope

/**
Expand All @@ -30,9 +38,9 @@ class GzipContentEncoding private constructor() {

return when (content) {
is OutgoingContent.ProtocolUpgrade, is OutgoingContent.NoContent -> content
is OutgoingContent.ReadChannelContent -> GzipReadChannel(content.readFrom())
is OutgoingContent.ByteArrayContent -> GzipReadChannel(ByteReadChannel(content.bytes()))
is OutgoingContent.WriteChannelContent -> GzipWriteChannel(content)
is OutgoingContent.ReadChannelContent -> GzipReadChannel(content.readFrom(), content.contentType)
is OutgoingContent.ByteArrayContent -> GzipReadChannel(ByteReadChannel(content.bytes()), content.contentType)
is OutgoingContent.WriteChannelContent -> GzipWriteChannel(content, content.contentType)
}
}

Expand Down Expand Up @@ -74,13 +82,15 @@ class GzipContentEncoding private constructor() {

private class GzipReadChannel(
private val original: ByteReadChannel,
override val contentType: ContentType?,
) : OutgoingContent.ReadChannelContent() {
override fun readFrom(): ByteReadChannel =
original.deflated(gzip = true)
}

private class GzipWriteChannel(
private val content: WriteChannelContent,
override val contentType: ContentType?,
) : OutgoingContent.WriteChannelContent() {
override suspend fun writeTo(channel: ByteWriteChannel) {
coroutineScope {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,19 @@
*/
package org.radarbase.producer.rest

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.BufferOverflow.DROP_OLDEST
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.distinctUntilChanged
import kotlinx.coroutines.flow.shareIn
import kotlinx.coroutines.flow.transformLatest
import kotlinx.coroutines.plus
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.time.Duration

Expand Down Expand Up @@ -47,7 +58,10 @@ class ConnectionState(

val scope = scope + Job()

private val mutableState = MutableStateFlow(State.UNKNOWN)
private val mutableState = MutableSharedFlow<State>(
extraBufferCapacity = 1,
onBufferOverflow = DROP_OLDEST,
)

@OptIn(ExperimentalCoroutinesApi::class)
val state: Flow<State> = mutableState
Expand All @@ -58,27 +72,28 @@ class ConnectionState(
emit(State.UNKNOWN)
}
}
.distinctUntilChanged()
.shareIn(this.scope + Dispatchers.Unconfined, SharingStarted.Eagerly, replay = 1)

init {
mutableState.value = State.UNKNOWN
mutableState.tryEmit(State.UNKNOWN)
}

/** For a sender to indicate that a connection attempt succeeded. */
fun didConnect() {
mutableState.value = State.CONNECTED
suspend fun didConnect() {
mutableState.emit(State.CONNECTED)
}

/** For a sender to indicate that a connection attempt failed. */
fun didDisconnect() {
mutableState.value = State.DISCONNECTED
suspend fun didDisconnect() {
mutableState.emit(State.DISCONNECTED)
}

fun wasUnauthorized() {
mutableState.value = State.UNAUTHORIZED
suspend fun wasUnauthorized() {
mutableState.emit(State.UNAUTHORIZED)
}

fun reset() {
mutableState.value = State.UNKNOWN
suspend fun reset() {
mutableState.emit(State.UNKNOWN)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
*/
package org.radarbase.producer.rest

import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsText
import io.ktor.client.statement.request
import io.ktor.http.HttpStatusCode
import io.ktor.http.Url
import java.io.IOException

/**
Expand All @@ -25,11 +29,18 @@ import java.io.IOException
*/
class RestException(
val status: HttpStatusCode,
url: Url? = null,
body: String? = null,
cause: Throwable? = null,
) : IOException(
buildString(150) {
append("REST call failed (HTTP code ")
append("REST call ")
if (url != null) {
append("to <")
append(url)
append("> ")
}
append("failed (HTTP code ")
append(status)
if (body == null) {
append(')')
Expand All @@ -45,4 +56,8 @@ class RestException(
}
},
cause,
)
) {
companion object {
suspend fun HttpResponse.toRestException() = RestException(status, request.url, bodyAsText())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.serialization
import io.ktor.util.reflect.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.first
import kotlinx.serialization.json.Json
import org.apache.avro.SchemaValidationException
import org.radarbase.data.RecordData
import org.radarbase.producer.AuthenticationException
Expand All @@ -36,6 +38,7 @@ import org.radarbase.producer.io.GzipContentEncoding
import org.radarbase.producer.io.UnsupportedMediaTypeException
import org.radarbase.producer.io.timeout
import org.radarbase.producer.io.unsafeSsl
import org.radarbase.producer.rest.RestException.Companion.toRestException
import org.radarbase.producer.schema.SchemaRetriever
import org.radarbase.topic.AvroTopic
import org.radarbase.util.RadarProducerDsl
Expand Down Expand Up @@ -69,7 +72,7 @@ class RestKafkaSender(config: Config) : KafkaSender {
override val connectionState: Flow<ConnectionState.State>
get() = _connectionState.state

private val baseUrl: String = requireNotNull(config.baseUrl)
private val baseUrl: String = requireNotNull(config.baseUrl).trimEnd('/')
private val headers: Headers = config.headers.build()
private val connectionTimeout: Duration = config.connectionTimeout
private val contentEncoding = config.contentEncoding
Expand Down Expand Up @@ -97,6 +100,12 @@ class RestKafkaSender(config: Config) : KafkaSender {
KAFKA_REST_JSON_ENCODING,
AvroContentConverter(schemaRetriever, binary = false),
)
serialization(
KAFKA_REST_ACCEPT,
Json {
ignoreUnknownKeys = true
},
)
}
when (contentEncoding) {
GZIP_CONTENT_ENCODING -> install(GzipContentEncoding)
Expand All @@ -106,7 +115,7 @@ class RestKafkaSender(config: Config) : KafkaSender {
unsafeSsl()
}
defaultRequest {
url(baseUrl)
url("$baseUrl/")
contentType(contentType)
accept(ContentType.Application.Json)
headers {
Expand All @@ -118,7 +127,7 @@ class RestKafkaSender(config: Config) : KafkaSender {
inner class RestKafkaTopicSender<K : Any, V : Any>(
override val topic: AvroTopic<K, V>,
) : KafkaTopicSender<K, V> {
override suspend fun send(records: RecordData<K, V>) = scope.async {
override suspend fun send(records: RecordData<K, V>) = withContext(scope.coroutineContext) {
try {
val response: HttpResponse = restClient.post {
url("topics/${topic.name}")
Expand All @@ -132,18 +141,18 @@ class RestKafkaSender(config: Config) : KafkaSender {
throw AuthenticationException("Request unauthorized")
} else if (response.status == HttpStatusCode.UnsupportedMediaType) {
throw UnsupportedMediaTypeException(
response.request.contentType(),
response.request.contentType() ?: response.request.content.contentType,
response.request.headers[HttpHeaders.ContentEncoding],
)
} else {
_connectionState.didDisconnect()
throw RestException(response.status, response.bodyAsText())
throw response.toRestException()
}
} catch (ex: IOException) {
_connectionState.didDisconnect()
throw ex
}
}.await()
}
}

@Throws(SchemaValidationException::class)
Expand Down Expand Up @@ -255,6 +264,7 @@ class RestKafkaSender(config: Config) : KafkaSender {
val DEFAULT_TIMEOUT: Duration = 20.seconds
val KAFKA_REST_BINARY_ENCODING = ContentType("application", "vnd.radarbase.avro.v1+binary")
val KAFKA_REST_JSON_ENCODING = ContentType("application", "vnd.kafka.avro.v2+json")
val KAFKA_REST_ACCEPT = ContentType("application", "vnd.kafka.v2+json")
const val GZIP_CONTENT_ENCODING = "gzip"

init {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.Json
import org.apache.avro.Schema
import org.radarbase.producer.rest.RestException
import org.radarbase.producer.rest.RestException.Companion.toRestException
import java.io.IOException
import kotlin.coroutines.CoroutineContext

Expand Down Expand Up @@ -50,7 +50,7 @@ class SchemaRestClient(
requestBuilder()
}
if (!response.status.isSuccess()) {
throw RestException(response.status, response.bodyAsText())
throw response.toRestException()
}
response.body(typeInfo)
}
Expand All @@ -62,7 +62,7 @@ class SchemaRestClient(
requestBuilder()
}
if (!response.status.isSuccess()) {
throw RestException(response.status, response.bodyAsText())
throw response.toRestException()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ open class SchemaRetriever(config: Config) {
val subject = subject(topic, ofValue)
val metadata = restClient.addSchema(subject, schema)

launch {
cachedMetadata(subject, metadata.schema).set(metadata)
}
if (metadata.version != null) {
launch {
cachedMetadata(subject, metadata.schema).set(metadata)
}
launch {
cachedVersion(subject, metadata.version).set(metadata)
}
Expand Down

0 comments on commit 9622000

Please sign in to comment.