Skip to content

Commit

Permalink
Client scope cancellation propagation (#34)
Browse files Browse the repository at this point in the history
* add scope to call cancellation binding ext

* bind scope cancellation in client call builders

* add test-api module

* update unary client call tests

* add scope cancellation tests

* add tests for streaming client call builders

* fix jacoco config

* close request channel on response error

* address exception leaking in send channel builder

* fix channel test bug

* fix channel error propagation test

* add client bidi call tests
  • Loading branch information
marcoferrer authored Mar 11, 2019
1 parent 33f353f commit 2f783ba
Show file tree
Hide file tree
Showing 24 changed files with 1,694 additions and 110 deletions.
12 changes: 12 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ subprojects{ subproject ->
}

apply plugin: 'idea'
apply plugin: 'jacoco'
apply plugin: 'kotlin'

group = 'com.github.marcoferrer.krotoplus'
Expand All @@ -54,4 +55,15 @@ subprojects{ subproject ->
testImplementation "org.jetbrains.kotlin:kotlin-test-junit"
}

jacoco {
toolVersion = "0.8.3"
}

jacocoTestReport {
reports {
html.enabled = true
html.destination file("$buildDir/reports/coverage")
csv.enabled = false
}
}
}
1 change: 1 addition & 0 deletions kroto-plus-coroutines/benchmark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ protobuf {

all().each{ task ->
task.inputs.files krotoConfig
task.dependsOn ':protoc-gen-kroto-plus:bootJar'

task.builtins {
remove java
Expand Down
33 changes: 33 additions & 0 deletions kroto-plus-coroutines/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
description = "Kroto+ Grpc Coroutine Support"

apply from: "$rootDir/publishing.gradle"
apply plugin: 'com.google.protobuf'

def experimentalFlags = [
"-Xuse-experimental=kotlin.Experimental",
Expand All @@ -27,5 +28,37 @@ compileTestKotlin {
dependencies {
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}"
implementation "io.grpc:grpc-stub:${Versions.grpc}"

testProtobuf project(':test-api')
testImplementation project(':test-api:grpc')
testImplementation project(':test-api:java')
testImplementation "io.mockk:mockk:${Versions.mockk}"
}

tasks.withType(JavaCompile) {
enabled = false
}

protobuf {
protoc { artifact = "com.google.protobuf:protoc:${Versions.protobuf}"}

//noinspection GroovyAssignabilityCheck
plugins {
coroutines {
path = "${rootProject.projectDir}/protoc-gen-grpc-coroutines/build/libs/protoc-gen-grpc-coroutines-${project.version}-jvm8.jar"
}
}

generateProtoTasks {

all().each{ task ->
task.dependsOn ':protoc-gen-grpc-coroutines:bootJar'
task.builtins {
remove java
}
task.plugins {
coroutines {}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package com.github.marcoferrer.krotoplus.coroutines

import com.github.marcoferrer.krotoplus.coroutines.call.completionHandler
import com.github.marcoferrer.krotoplus.coroutines.call.newProducerScope
import com.github.marcoferrer.krotoplus.coroutines.call.toRpcException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
Expand All @@ -44,5 +44,12 @@ public fun <T> CoroutineScope.launchProducerJob(
context: CoroutineContext = EmptyCoroutineContext,
block: suspend ProducerScope<T>.()->Unit
): Job =
launch(context) { newProducerScope(channel).block() }
.apply { invokeOnCompletion(channel.completionHandler) }
launch(context) {
newProducerScope(channel).block()
}.apply {
invokeOnCompletion {
if(!channel.isClosedForSend){
channel.close(it?.toRpcException())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import io.grpc.StatusException
import io.grpc.StatusRuntimeException
import io.grpc.stub.ServerCallStreamObserver
import io.grpc.stub.StreamObserver
import io.grpc.ClientCall
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.coroutines.CoroutineContext

Expand All @@ -34,15 +36,21 @@ internal fun <RespT> CoroutineScope.newSendChannelFromObserver(
capacity: Int = 1
): SendChannel<RespT> =
actor<RespT>(
context = Dispatchers.Unconfined,
context = observer.exceptionHandler + Dispatchers.Unconfined,
capacity = capacity,
start = CoroutineStart.LAZY
) {
consumeEach { observer.onNext(it) }
}.apply {
try {
consumeEach { observer.onNext(it) }
channel.close()
}catch (e:Throwable){
channel.close(e)
}
}.apply{
invokeOnClose(observer.completionHandler)
}


internal fun <ReqT, RespT> CoroutineScope.newManagedServerResponseChannel(
responseObserver: ServerCallStreamObserver<RespT>,
isMessagePreloaded: AtomicBoolean,
Expand All @@ -60,40 +68,36 @@ internal fun CoroutineScope.bindToClientCancellation(observer: ServerCallStreamO
observer.setOnCancelHandler { this@bindToClientCancellation.cancel() }
}

internal fun CoroutineScope.bindScopeCancellationToCall(call: ClientCall<*, *>){

val job = coroutineContext[Job]
?: error("Unable to bind cancellation to call because scope does not have a job: $this")

job.apply {
invokeOnCompletion {
if(isCancelled){
call.cancel(it?.message,it?.cause ?: it)
}
}
}
}

internal val StreamObserver<*>.exceptionHandler: CoroutineExceptionHandler
get() = CoroutineExceptionHandler { _, e ->
kotlin.runCatching { onError(e.toRpcException()) }
}

internal val StreamObserver<*>.completionHandler: CompletionHandler
get() = {
// If the call was cancelled already
// the stream observer will throw
runCatching {
if(it != null)
if (it != null)
onError(it.toRpcException()) else
onCompleted()
}
}

internal val SendChannel<*>.completionHandler: CompletionHandler
get() = {
if(!isClosedForSend){
close(it?.toRpcException())
}
}

internal val SendChannel<*>.abandonedRpcHandler: CompletionHandler
get() = { completionError ->
if(!isClosedForSend){

val rpcException = completionError
?.toRpcException()
?.let { it as? StatusRuntimeException }
?.takeUnless { it.status.code == Status.UNKNOWN.code }
?: Status.UNKNOWN
.withDescription("Abandoned Rpc")
.asRuntimeException()

close(rpcException)
}
}

internal fun Throwable.toRpcException(): Throwable =
when (this) {
is StatusException,
Expand Down Expand Up @@ -127,23 +131,32 @@ internal fun <T> CoroutineScope.newProducerScope(channel: SendChannel<T>): Produ
}

internal inline fun <T> StreamObserver<T>.handleUnaryRpc(block: ()->T){
runCatching { onNext(block()) }
.onSuccess { onCompleted() }
.onFailure { onError(it.toRpcException()) }
try{
onNext(block())
onCompleted()
}catch (e: Throwable){
onError(e.toRpcException())
}
}

internal inline fun <T> SendChannel<T>.handleStreamingRpc(block: (SendChannel<T>)->Unit){
runCatching { block(this) }
.onSuccess { close() }
.onFailure { close(it.toRpcException()) }
try{
block(this)
close()
}catch (e: Throwable){
close(e.toRpcException())
}
}

internal inline fun <ReqT, RespT> handleBidiStreamingRpc(
requestChannel: ReceiveChannel<ReqT>,
responseChannel: SendChannel<RespT>,
block: (ReceiveChannel<ReqT>, SendChannel<RespT>) -> Unit
) {
runCatching { block(requestChannel,responseChannel) }
.onSuccess { responseChannel.close() }
.onFailure { responseChannel.close(it.toRpcException()) }
try{
block(requestChannel,responseChannel)
responseChannel.close()
}catch (e:Throwable){
responseChannel.close(e.toRpcException())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,80 @@
package com.github.marcoferrer.krotoplus.coroutines.client

import com.github.marcoferrer.krotoplus.coroutines.*
import com.github.marcoferrer.krotoplus.coroutines.call.newRpcScope
import com.github.marcoferrer.krotoplus.coroutines.call.newSendChannelFromObserver
import com.github.marcoferrer.krotoplus.coroutines.call.toStreamObserver
import com.github.marcoferrer.krotoplus.coroutines.call.*
import io.grpc.MethodDescriptor
import io.grpc.stub.AbstractStub
import io.grpc.stub.ClientCalls.*
import io.grpc.stub.ClientResponseObserver
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel


/**
* Executes a unary rpc call using the [io.grpc.Channel] and [io.grpc.CallOptions] attached to the
* receiver [AbstractStub].
*
* This method will suspend the invoking coroutine until completion. Its execution is bound to the current
* [CancellableContinuation] as well as the current [Job]
*
* The server is notified of cancellation under one of the following conditions:
* * The current continuation is, or has become cancelled.
* * The current job is, or has become cancelled. Either exceptionally or normally.
*
* A cancellation of the current scopes job will not always directly correlate to a cancelled continuation.
* If the job of the receiver stub differs from that of the continuation, its cancellation will cause the
* this method to throw a [io.grpc.StatusRuntimeException] with a status code of [io.grpc.Status.CANCELLED].
*
* @throws io.grpc.StatusRuntimeException The error returned by the server or local scope cancellation.
*
*/
public suspend fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallUnary(
request: ReqT,
method: MethodDescriptor<ReqT, RespT>
): RespT = suspendCancellableCoroutine { cont: CancellableContinuation<RespT> ->

with(newRpcScope(cont.context + coroutineContext, method)) {
asyncUnaryCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
request,
SuspendingUnaryObserver(cont)
)
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
asyncUnaryCall<ReqT, RespT>(call, request, SuspendingUnaryObserver(cont))
cont.invokeOnCancellation { call.cancel(it?.message, it) }
bindScopeCancellationToCall(call)
}
}

/**
* Executes a server streaming rpc call using the [io.grpc.Channel] and [io.grpc.CallOptions] attached to the
* receiver [AbstractStub].
*
* This method will return a [ReceiveChannel] to its invoker so that response messages from the target server can
* be processed in a suspending manner.
*
* A new [observer][ClientResponseObserver] is created with back-pressure enabled. The observer will be
* used by grpc to handle incoming messages and errors from the target server. New messages are submitted to the
* resulting channel for consumption by the client.
*
* In the event of the server returning an [error][io.grpc.StatusRuntimeException], the resulting [ReceiveChannel] will
* be closed with it. If the local coroutine scope is cancelled then the resulting [ReceiveChannel] will be closed with
* a [io.grpc.StatusRuntimeException] with a status code of [io.grpc.Status.CANCELLED]
*
* The server is notified of cancellations once the current job is, or has become cancelled,
* either exceptionally or normally.
*
*/
public fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallServerStreaming(
request: ReqT,
method: MethodDescriptor<ReqT, RespT>
): ReceiveChannel<RespT> {

with(newRpcScope(coroutineContext, method)) rpcScope@{
with(newRpcScope(coroutineContext, method)) {
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
val responseObserverChannel = ClientResponseObserverChannel<ReqT, RespT>(coroutineContext)

asyncServerStreamingCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
call,
request,
responseObserverChannel
)
bindScopeCancellationToCall(call)
return responseObserverChannel
}
}
Expand All @@ -62,13 +99,16 @@ public fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallBidiStreaming(
method: MethodDescriptor<ReqT, RespT>
): ClientBidiCallChannel<ReqT, RespT> {

with(newRpcScope(coroutineContext, method)){
val responseChannel = ClientResponseObserverChannel<ReqT, RespT>(coroutineContext)
with(newRpcScope(coroutineContext, method)) {
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
val responseDelegate = Channel<RespT>(capacity = 1)
val responseChannel = ClientResponseObserverChannel<ReqT, RespT>(coroutineContext, responseDelegate)
val requestObserver = asyncBidiStreamingCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
responseChannel
call, responseChannel
)
bindScopeCancellationToCall(call)
val requestChannel = newSendChannelFromObserver(requestObserver)
responseDelegate.invokeOnClose { requestChannel.close(it) }

return ClientBidiCallChannelImpl(requestChannel, responseChannel)
}
Expand All @@ -78,12 +118,13 @@ public fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallClientStreaming(
method: MethodDescriptor<ReqT, RespT>
): ClientStreamingCallChannel<ReqT, RespT> {

with(newRpcScope(coroutineContext, method)) rpcScope@{
val completableResponse = CompletableDeferred<RespT>()
with(newRpcScope(coroutineContext, method)) {
val completableResponse = CompletableDeferred<RespT>(parent = coroutineContext[Job])
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
val requestObserver = asyncClientStreamingCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
completableResponse.toStreamObserver()
call, completableResponse.toStreamObserver()
)
bindScopeCancellationToCall(call)
val requestChannel = newSendChannelFromObserver(requestObserver)
return ClientStreamingCallChannelImpl(
requestChannel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public fun <ReqT, RespT> CoroutineScope.serverCallServerStreaming(

launch {
responseChannel.handleStreamingRpc { block(it) }
}.invokeOnCompletion(responseChannel.abandonedRpcHandler)
}
}
}

Expand Down Expand Up @@ -128,7 +128,7 @@ public fun <ReqT, RespT> CoroutineScope.serverCallBidiStreaming(

launch {
handleBidiStreamingRpc(requestChannel, responseChannel){ req, resp -> block(req,resp) }
}.invokeOnCompletion(responseChannel.abandonedRpcHandler)
}

return requestChannel
}
Expand Down
Loading

0 comments on commit 2f783ba

Please sign in to comment.