Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Client scope cancellation propagation #34

Merged
merged 12 commits into from
Mar 11, 2019
12 changes: 12 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ subprojects{ subproject ->
}

apply plugin: 'idea'
apply plugin: 'jacoco'
apply plugin: 'kotlin'

group = 'com.github.marcoferrer.krotoplus'
Expand All @@ -54,4 +55,15 @@ subprojects{ subproject ->
testImplementation "org.jetbrains.kotlin:kotlin-test-junit"
}

jacoco {
toolVersion = "0.8.3"
}

jacocoTestReport {
reports {
html.enabled = true
html.destination file("$buildDir/reports/coverage")
csv.enabled = false
}
}
}
1 change: 1 addition & 0 deletions kroto-plus-coroutines/benchmark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ protobuf {

all().each{ task ->
task.inputs.files krotoConfig
task.dependsOn ':protoc-gen-kroto-plus:bootJar'

task.builtins {
remove java
Expand Down
33 changes: 33 additions & 0 deletions kroto-plus-coroutines/build.gradle
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
description = "Kroto+ Grpc Coroutine Support"

apply from: "$rootDir/publishing.gradle"
apply plugin: 'com.google.protobuf'

def experimentalFlags = [
"-Xuse-experimental=kotlin.Experimental",
Expand All @@ -27,5 +28,37 @@ compileTestKotlin {
dependencies {
implementation "org.jetbrains.kotlinx:kotlinx-coroutines-core:${Versions.coroutines}"
implementation "io.grpc:grpc-stub:${Versions.grpc}"

testProtobuf project(':test-api')
testImplementation project(':test-api:grpc')
testImplementation project(':test-api:java')
testImplementation "io.mockk:mockk:${Versions.mockk}"
}

tasks.withType(JavaCompile) {
enabled = false
}

protobuf {
protoc { artifact = "com.google.protobuf:protoc:${Versions.protobuf}"}

//noinspection GroovyAssignabilityCheck
plugins {
coroutines {
path = "${rootProject.projectDir}/protoc-gen-grpc-coroutines/build/libs/protoc-gen-grpc-coroutines-${project.version}-jvm8.jar"
}
}

generateProtoTasks {

all().each{ task ->
task.dependsOn ':protoc-gen-grpc-coroutines:bootJar'
task.builtins {
remove java
}
task.plugins {
coroutines {}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package com.github.marcoferrer.krotoplus.coroutines

import com.github.marcoferrer.krotoplus.coroutines.call.completionHandler
import com.github.marcoferrer.krotoplus.coroutines.call.newProducerScope
import com.github.marcoferrer.krotoplus.coroutines.call.toRpcException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
Expand All @@ -44,5 +44,12 @@ public fun <T> CoroutineScope.launchProducerJob(
context: CoroutineContext = EmptyCoroutineContext,
block: suspend ProducerScope<T>.()->Unit
): Job =
launch(context) { newProducerScope(channel).block() }
.apply { invokeOnCompletion(channel.completionHandler) }
launch(context) {
newProducerScope(channel).block()
}.apply {
invokeOnCompletion {
if(!channel.isClosedForSend){
channel.close(it?.toRpcException())
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import io.grpc.StatusException
import io.grpc.StatusRuntimeException
import io.grpc.stub.ServerCallStreamObserver
import io.grpc.stub.StreamObserver
import io.grpc.ClientCall
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.channels.Channel
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.coroutines.CoroutineContext

Expand All @@ -34,15 +36,21 @@ internal fun <RespT> CoroutineScope.newSendChannelFromObserver(
capacity: Int = 1
): SendChannel<RespT> =
actor<RespT>(
context = Dispatchers.Unconfined,
context = observer.exceptionHandler + Dispatchers.Unconfined,
capacity = capacity,
start = CoroutineStart.LAZY
) {
consumeEach { observer.onNext(it) }
}.apply {
try {
consumeEach { observer.onNext(it) }
channel.close()
}catch (e:Throwable){
channel.close(e)
}
}.apply{
invokeOnClose(observer.completionHandler)
}


internal fun <ReqT, RespT> CoroutineScope.newManagedServerResponseChannel(
responseObserver: ServerCallStreamObserver<RespT>,
isMessagePreloaded: AtomicBoolean,
Expand All @@ -60,40 +68,36 @@ internal fun CoroutineScope.bindToClientCancellation(observer: ServerCallStreamO
observer.setOnCancelHandler { [email protected]() }
}

internal fun CoroutineScope.bindScopeCancellationToCall(call: ClientCall<*, *>){

val job = coroutineContext[Job]
?: error("Unable to bind cancellation to call because scope does not have a job: $this")

job.apply {
invokeOnCompletion {
if(isCancelled){
call.cancel(it?.message,it?.cause ?: it)
}
}
}
}

internal val StreamObserver<*>.exceptionHandler: CoroutineExceptionHandler
get() = CoroutineExceptionHandler { _, e ->
kotlin.runCatching { onError(e.toRpcException()) }
}

internal val StreamObserver<*>.completionHandler: CompletionHandler
get() = {
// If the call was cancelled already
// the stream observer will throw
runCatching {
if(it != null)
if (it != null)
onError(it.toRpcException()) else
onCompleted()
}
}

internal val SendChannel<*>.completionHandler: CompletionHandler
get() = {
if(!isClosedForSend){
close(it?.toRpcException())
}
}

internal val SendChannel<*>.abandonedRpcHandler: CompletionHandler
get() = { completionError ->
if(!isClosedForSend){

val rpcException = completionError
?.toRpcException()
?.let { it as? StatusRuntimeException }
?.takeUnless { it.status.code == Status.UNKNOWN.code }
?: Status.UNKNOWN
.withDescription("Abandoned Rpc")
.asRuntimeException()

close(rpcException)
}
}

internal fun Throwable.toRpcException(): Throwable =
when (this) {
is StatusException,
Expand Down Expand Up @@ -127,23 +131,32 @@ internal fun <T> CoroutineScope.newProducerScope(channel: SendChannel<T>): Produ
}

internal inline fun <T> StreamObserver<T>.handleUnaryRpc(block: ()->T){
runCatching { onNext(block()) }
.onSuccess { onCompleted() }
.onFailure { onError(it.toRpcException()) }
try{
onNext(block())
onCompleted()
}catch (e: Throwable){
onError(e.toRpcException())
}
}

internal inline fun <T> SendChannel<T>.handleStreamingRpc(block: (SendChannel<T>)->Unit){
runCatching { block(this) }
.onSuccess { close() }
.onFailure { close(it.toRpcException()) }
try{
block(this)
close()
}catch (e: Throwable){
close(e.toRpcException())
}
}

internal inline fun <ReqT, RespT> handleBidiStreamingRpc(
requestChannel: ReceiveChannel<ReqT>,
responseChannel: SendChannel<RespT>,
block: (ReceiveChannel<ReqT>, SendChannel<RespT>) -> Unit
) {
runCatching { block(requestChannel,responseChannel) }
.onSuccess { responseChannel.close() }
.onFailure { responseChannel.close(it.toRpcException()) }
try{
block(requestChannel,responseChannel)
responseChannel.close()
}catch (e:Throwable){
responseChannel.close(e.toRpcException())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,80 @@
package com.github.marcoferrer.krotoplus.coroutines.client

import com.github.marcoferrer.krotoplus.coroutines.*
import com.github.marcoferrer.krotoplus.coroutines.call.newRpcScope
import com.github.marcoferrer.krotoplus.coroutines.call.newSendChannelFromObserver
import com.github.marcoferrer.krotoplus.coroutines.call.toStreamObserver
import com.github.marcoferrer.krotoplus.coroutines.call.*
import io.grpc.MethodDescriptor
import io.grpc.stub.AbstractStub
import io.grpc.stub.ClientCalls.*
import io.grpc.stub.ClientResponseObserver
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel


/**
* Executes a unary rpc call using the [io.grpc.Channel] and [io.grpc.CallOptions] attached to the
* receiver [AbstractStub].
*
* This method will suspend the invoking coroutine until completion. Its execution is bound to the current
* [CancellableContinuation] as well as the current [Job]
*
* The server is notified of cancellation under one of the following conditions:
* * The current continuation is, or has become cancelled.
* * The current job is, or has become cancelled. Either exceptionally or normally.
*
* A cancellation of the current scopes job will not always directly correlate to a cancelled continuation.
* If the job of the receiver stub differs from that of the continuation, its cancellation will cause the
* this method to throw a [io.grpc.StatusRuntimeException] with a status code of [io.grpc.Status.CANCELLED].
*
* @throws io.grpc.StatusRuntimeException The error returned by the server or local scope cancellation.
*
*/
public suspend fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallUnary(
request: ReqT,
method: MethodDescriptor<ReqT, RespT>
): RespT = suspendCancellableCoroutine { cont: CancellableContinuation<RespT> ->

with(newRpcScope(cont.context + coroutineContext, method)) {
asyncUnaryCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
request,
SuspendingUnaryObserver(cont)
)
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
asyncUnaryCall<ReqT, RespT>(call, request, SuspendingUnaryObserver(cont))
cont.invokeOnCancellation { call.cancel(it?.message, it) }
bindScopeCancellationToCall(call)
}
}

/**
* Executes a server streaming rpc call using the [io.grpc.Channel] and [io.grpc.CallOptions] attached to the
* receiver [AbstractStub].
*
* This method will return a [ReceiveChannel] to its invoker so that response messages from the target server can
* be processed in a suspending manner.
*
* A new [observer][ClientResponseObserver] is created with back-pressure enabled. The observer will be
* used by grpc to handle incoming messages and errors from the target server. New messages are submitted to the
* resulting channel for consumption by the client.
*
* In the event of the server returning an [error][io.grpc.StatusRuntimeException], the resulting [ReceiveChannel] will
* be closed with it. If the local coroutine scope is cancelled then the resulting [ReceiveChannel] will be closed with
* a [io.grpc.StatusRuntimeException] with a status code of [io.grpc.Status.CANCELLED]
*
* The server is notified of cancellations once the current job is, or has become cancelled,
* either exceptionally or normally.
*
*/
public fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallServerStreaming(
request: ReqT,
method: MethodDescriptor<ReqT, RespT>
): ReceiveChannel<RespT> {

with(newRpcScope(coroutineContext, method)) rpcScope@{
with(newRpcScope(coroutineContext, method)) {
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
val responseObserverChannel = ClientResponseObserverChannel<ReqT, RespT>(coroutineContext)

asyncServerStreamingCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
call,
request,
responseObserverChannel
)
bindScopeCancellationToCall(call)
return responseObserverChannel
}
}
Expand All @@ -62,13 +99,16 @@ public fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallBidiStreaming(
method: MethodDescriptor<ReqT, RespT>
): ClientBidiCallChannel<ReqT, RespT> {

with(newRpcScope(coroutineContext, method)){
val responseChannel = ClientResponseObserverChannel<ReqT, RespT>(coroutineContext)
with(newRpcScope(coroutineContext, method)) {
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
val responseDelegate = Channel<RespT>(capacity = 1)
val responseChannel = ClientResponseObserverChannel<ReqT, RespT>(coroutineContext, responseDelegate)
val requestObserver = asyncBidiStreamingCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
responseChannel
call, responseChannel
)
bindScopeCancellationToCall(call)
val requestChannel = newSendChannelFromObserver(requestObserver)
responseDelegate.invokeOnClose { requestChannel.close(it) }

return ClientBidiCallChannelImpl(requestChannel, responseChannel)
}
Expand All @@ -78,12 +118,13 @@ public fun <ReqT, RespT, T : AbstractStub<T>> T.clientCallClientStreaming(
method: MethodDescriptor<ReqT, RespT>
): ClientStreamingCallChannel<ReqT, RespT> {

with(newRpcScope(coroutineContext, method)) rpcScope@{
val completableResponse = CompletableDeferred<RespT>()
with(newRpcScope(coroutineContext, method)) {
val completableResponse = CompletableDeferred<RespT>(parent = coroutineContext[Job])
val call = channel.newCall(method, callOptions.withCoroutineContext(coroutineContext))
val requestObserver = asyncClientStreamingCall<ReqT, RespT>(
channel.newCall(method, callOptions.withCoroutineContext(coroutineContext)),
completableResponse.toStreamObserver()
call, completableResponse.toStreamObserver()
)
bindScopeCancellationToCall(call)
val requestChannel = newSendChannelFromObserver(requestObserver)
return ClientStreamingCallChannelImpl(
requestChannel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public fun <ReqT, RespT> CoroutineScope.serverCallServerStreaming(

launch {
responseChannel.handleStreamingRpc { block(it) }
}.invokeOnCompletion(responseChannel.abandonedRpcHandler)
}
}
}

Expand Down Expand Up @@ -128,7 +128,7 @@ public fun <ReqT, RespT> CoroutineScope.serverCallBidiStreaming(

launch {
handleBidiStreamingRpc(requestChannel, responseChannel){ req, resp -> block(req,resp) }
}.invokeOnCompletion(responseChannel.abandonedRpcHandler)
}

return requestChannel
}
Expand Down
Loading