Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Firebase SDK with new APIs to consume streaming callable function response #6602

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package com.google.firebase.functions.ktx

import androidx.test.InstrumentationRegistry
import androidx.test.runner.AndroidJUnit4
import com.google.android.gms.tasks.Tasks
import com.google.common.truth.Truth.assertThat
import com.google.firebase.FirebaseApp
import com.google.firebase.functions.FirebaseFunctions
import com.google.firebase.functions.FirebaseFunctionsException
import com.google.firebase.functions.SSETaskListener
import com.google.firebase.ktx.Firebase
import com.google.firebase.ktx.initialize
import java.util.concurrent.ExecutionException
import java.util.concurrent.TimeUnit
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith

@RunWith(AndroidJUnit4::class)
class StreamTests {

private lateinit var app: FirebaseApp
private lateinit var listener: SSETaskListener

private lateinit var functions: FirebaseFunctions
var onNext = mutableListOf<Any>()
var onError: Any? = null
var onComplete: Any? = null

@Before
fun setup() {
app = Firebase.initialize(InstrumentationRegistry.getContext())!!
functions = FirebaseFunctions.getInstance()
functions.useEmulator("10.0.2.2", 5001)
listener =
object : SSETaskListener {
override fun onNext(event: Any) {
onNext.add(event)
}

override fun onError(event: Any) {
onError = event
}

override fun onComplete(event: Any) {
onComplete = event
}
}
}

@After
fun clear() {
onNext.clear()
onError = null
onComplete = null
}

@Test
fun testGenStream() {
val input = hashMapOf("data" to "Why is the sky blue")

val function = functions.getHttpsCallable("genStream")
val httpsCallableResult = Tasks.await(function.stream(input, listener))

val onNextStringList = onNext.map { it.toString() }
assertThat(onNextStringList)
.containsExactly(
"{chunk=hello}",
"{chunk=world}",
"{chunk=this}",
"{chunk=is}",
"{chunk=cool}"
)
assertThat(onError).isNull()
assertThat(onComplete).isEqualTo("hello world this is cool")
assertThat(httpsCallableResult.data).isEqualTo("hello world this is cool")
}

@Test
fun testGenStreamError() {
val input = hashMapOf("data" to "Why is the sky blue")
val function = functions.getHttpsCallable("genStreamError").withTimeout(7, TimeUnit.SECONDS)

try {
Tasks.await(function.stream(input, listener))
} catch (exception: Exception) {
onError = exception
}

val onNextStringList = onNext.map { it.toString() }
assertThat(onNextStringList)
.containsExactly(
"{chunk=hello}",
"{chunk=world}",
"{chunk=this}",
"{chunk=is}",
"{chunk=cool}"
)
assertThat(onError).isInstanceOf(ExecutionException::class.java)
val cause = (onError as ExecutionException).cause
assertThat(cause).isInstanceOf(FirebaseFunctionsException::class.java)
assertThat((cause as FirebaseFunctionsException).message).contains("Socket closed")
assertThat(onComplete).isNull()
}

@Test
fun testGenStreamNoReturn() {
val input = hashMapOf("data" to "Why is the sky blue")

val function = functions.getHttpsCallable("genStreamNoReturn")
try {
Tasks.await(function.stream(input, listener), 7, TimeUnit.SECONDS)
} catch (_: Exception) {}

val onNextStringList = onNext.map { it.toString() }
assertThat(onNextStringList)
.containsExactly(
"{chunk=hello}",
"{chunk=world}",
"{chunk=this}",
"{chunk=is}",
"{chunk=cool}"
)
assertThat(onError).isNull()
assertThat(onComplete).isNull()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ import com.google.firebase.functions.FirebaseFunctionsException.Code.Companion.f
import com.google.firebase.functions.FirebaseFunctionsException.Companion.fromResponse
import dagger.assisted.Assisted
import dagger.assisted.AssistedInject
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStream
import java.io.InputStreamReader
import java.io.InterruptedIOException
import java.net.MalformedURLException
import java.net.URL
Expand Down Expand Up @@ -311,6 +314,229 @@ internal constructor(
return tcs.task
}

internal fun stream(
name: String,
data: Any?,
options: HttpsCallOptions,
listener: SSETaskListener
): Task<HttpsCallableResult> {
return providerInstalled.task
.continueWithTask(executor) { contextProvider.getContext(options.limitedUseAppCheckTokens) }
.continueWithTask(executor) { task: Task<HttpsCallableContext?> ->
if (!task.isSuccessful) {
return@continueWithTask Tasks.forException<HttpsCallableResult>(task.exception!!)
}
val context = task.result
val url = getURL(name)
stream(url, data, options, context, listener)
}
}

internal fun stream(
url: URL,
data: Any?,
options: HttpsCallOptions,
listener: SSETaskListener
): Task<HttpsCallableResult> {
return providerInstalled.task
.continueWithTask(executor) { contextProvider.getContext(options.limitedUseAppCheckTokens) }
.continueWithTask(executor) { task: Task<HttpsCallableContext?> ->
if (!task.isSuccessful) {
return@continueWithTask Tasks.forException<HttpsCallableResult>(task.exception!!)
}
val context = task.result
stream(url, data, options, context, listener)
}
}

private fun stream(
url: URL,
data: Any?,
options: HttpsCallOptions,
context: HttpsCallableContext?,
listener: SSETaskListener
): Task<HttpsCallableResult> {
Preconditions.checkNotNull(url, "url cannot be null")
val tcs = TaskCompletionSource<HttpsCallableResult>()
val callClient = options.apply(client)
callClient.postStream(url, tcs, listener) { applyCommonConfiguration(data, context) }

return tcs.task
}

private inline fun OkHttpClient.postStream(
url: URL,
tcs: TaskCompletionSource<HttpsCallableResult>,
listener: SSETaskListener,
crossinline config: Request.Builder.() -> Unit = {}
) {
val requestBuilder = Request.Builder().url(url)
requestBuilder.config()
val request = requestBuilder.build()

val call = newCall(request)
call.enqueue(
object : Callback {
override fun onFailure(ignored: Call, e: IOException) {
val exception: Exception =
if (e is InterruptedIOException) {
FirebaseFunctionsException(
FirebaseFunctionsException.Code.DEADLINE_EXCEEDED.name,
FirebaseFunctionsException.Code.DEADLINE_EXCEEDED,
null,
e
)
} else {
FirebaseFunctionsException(
FirebaseFunctionsException.Code.INTERNAL.name,
FirebaseFunctionsException.Code.INTERNAL,
null,
e
)
}
listener.onError(exception)
tcs.setException(exception)
}

@Throws(IOException::class)
override fun onResponse(ignored: Call, response: Response) {
try {
validateResponse(response)
val bodyStream = response.body()?.byteStream()
if (bodyStream != null) {
processSSEStream(bodyStream, serializer, listener, tcs)
} else {
val error =
FirebaseFunctionsException(
"Response body is null",
FirebaseFunctionsException.Code.INTERNAL,
null
)
listener.onError(error)
tcs.setException(error)
}
} catch (e: FirebaseFunctionsException) {
listener.onError(e)
tcs.setException(e)
}
}
}
)
}

private fun validateResponse(response: Response) {
if (response.isSuccessful) return

val htmlContentType = "text/html; charset=utf-8"
val trimMargin: String
if (response.code() == 404 && response.header("Content-Type") == htmlContentType) {
trimMargin = """URL not found. Raw response: ${response.body()?.string()}""".trimMargin()
throw FirebaseFunctionsException(
trimMargin,
FirebaseFunctionsException.Code.fromHttpStatus(response.code()),
null
)
}

val text = response.body()?.string() ?: ""
val error: Any?
try {
val json = JSONObject(text)
error = serializer.decode(json.opt("error"))
} catch (e: Throwable) {
throw FirebaseFunctionsException(
"${e.message} Unexpected Response:\n$text ",
FirebaseFunctionsException.Code.INTERNAL,
e
)
}
throw FirebaseFunctionsException(
error.toString(),
FirebaseFunctionsException.Code.INTERNAL,
error
)
}

private fun Request.Builder.applyCommonConfiguration(data: Any?, context: HttpsCallableContext?) {
val body: MutableMap<String?, Any?> = HashMap()
val encoded = serializer.encode(data)
body["data"] = encoded
if (context!!.authToken != null) {
header("Authorization", "Bearer " + context.authToken)
}
if (context.instanceIdToken != null) {
header("Firebase-Instance-ID-Token", context.instanceIdToken)
}
if (context.appCheckToken != null) {
header("X-Firebase-AppCheck", context.appCheckToken)
}
header("Accept", "text/event-stream")
val bodyJSON = JSONObject(body)
val contentType = MediaType.parse("application/json")
val requestBody = RequestBody.create(contentType, bodyJSON.toString())
post(requestBody)
}

private fun processSSEStream(
inputStream: InputStream,
serializer: Serializer,
listener: SSETaskListener,
tcs: TaskCompletionSource<HttpsCallableResult>
) {
BufferedReader(InputStreamReader(inputStream)).use { reader ->
try {
reader.lineSequence().forEach { line ->
val dataChunk =
when {
line.startsWith("data:") -> line.removePrefix("data:")
line.startsWith("result:") -> line.removePrefix("result:")
else -> return@forEach
}
try {
val json = JSONObject(dataChunk)
when {
json.has("message") ->
serializer.decode(json.opt("message"))?.let { listener.onNext(it) }
json.has("error") -> {
serializer.decode(json.opt("error"))?.let {
throw FirebaseFunctionsException(
it.toString(),
FirebaseFunctionsException.Code.INTERNAL,
it
)
}
}
json.has("result") -> {
serializer.decode(json.opt("result"))?.let {
listener.onComplete(it)
tcs.setResult(HttpsCallableResult(it))
}
return
}
}
} catch (e: Throwable) {
throw FirebaseFunctionsException(
"${e.message} Invalid JSON: $dataChunk",
FirebaseFunctionsException.Code.INTERNAL,
e
)
}
}
throw FirebaseFunctionsException(
"Stream ended unexpectedly without completion.",
FirebaseFunctionsException.Code.INTERNAL,
null
)
} catch (e: Exception) {
throw FirebaseFunctionsException(
e.message ?: "Error reading stream",
FirebaseFunctionsException.Code.INTERNAL,
e
)
}
}
}

public companion object {
/** A task that will be resolved once ProviderInstaller has installed what it needs to. */
private val providerInstalled = TaskCompletionSource<Void>()
Expand Down
Loading