Skip to content

Commit

Permalink
Introduce CoroutineContextThreadLocal API to integrate with thread-lo…
Browse files Browse the repository at this point in the history
…cal sensitive code

Fixes #119
  • Loading branch information
elizarov committed Jul 25, 2018
1 parent 9d31ffc commit 5eb01e9
Show file tree
Hide file tree
Showing 8 changed files with 225 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,11 @@ public final class kotlinx/coroutines/experimental/CoroutineContextKt {
public static final fun newCoroutineContext (Lkotlin/coroutines/experimental/CoroutineContext;)Lkotlin/coroutines/experimental/CoroutineContext;
public static final fun newCoroutineContext (Lkotlin/coroutines/experimental/CoroutineContext;Lkotlinx/coroutines/experimental/Job;)Lkotlin/coroutines/experimental/CoroutineContext;
public static synthetic fun newCoroutineContext$default (Lkotlin/coroutines/experimental/CoroutineContext;Lkotlinx/coroutines/experimental/Job;ILjava/lang/Object;)Lkotlin/coroutines/experimental/CoroutineContext;
public static final fun restoreThreadContext (Ljava/lang/String;)V
public static final fun updateThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;)Ljava/lang/String;
}

public abstract interface class kotlinx/coroutines/experimental/CoroutineContextThreadLocal {
public abstract fun restoreThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;Ljava/lang/Object;)V
public abstract fun updateThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;)Ljava/lang/Object;
}

public abstract class kotlinx/coroutines/experimental/CoroutineDispatcher : kotlin/coroutines/experimental/AbstractCoroutineContextElement, kotlin/coroutines/experimental/ContinuationInterceptor {
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ configure(subprojects.findAll { !it.name.contains(sourceless) && it.name != "ben
main.kotlin.srcDirs = ['src']
test.kotlin.srcDirs = ['test']
main.resources.srcDirs = ['resources']
test.resources.srcDirs = ['test-resources']
}
}

Expand Down
55 changes: 33 additions & 22 deletions core/kotlinx-coroutines-core/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package kotlinx.coroutines.experimental

import java.util.*
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.experimental.AbstractCoroutineContextElement
import kotlin.coroutines.experimental.ContinuationInterceptor
Expand Down Expand Up @@ -40,6 +41,17 @@ internal val DEBUG = run {
}
}

@Suppress("UNCHECKED_CAST")
internal val coroutineContextThreadLocal: CoroutineContextThreadLocal<Any?>? = run {
val services = ServiceLoader.load(CoroutineContextThreadLocal::class.java).toMutableList()
if (DEBUG) services.add(0, DebugThreadName)
when (services.size) {
0 -> null
1 -> services.single() as CoroutineContextThreadLocal<Any?>
else -> CoroutineContextThreadLocalList((services as List<CoroutineContextThreadLocal<Any?>>).toTypedArray())
}
}

private val COROUTINE_ID = AtomicLong()

// for tests only
Expand Down Expand Up @@ -89,29 +101,33 @@ public actual fun newCoroutineContext(context: CoroutineContext, parent: Job? =
* Executes a block using a given coroutine context.
*/
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, block: () -> T): T {
val oldName = context.updateThreadContext()
val oldValue = coroutineContextThreadLocal?.updateThreadContext(context)
try {
return block()
} finally {
restoreThreadContext(oldName)
coroutineContextThreadLocal?.restoreThreadContext(context, oldValue)
}
}

@PublishedApi
internal fun CoroutineContext.updateThreadContext(): String? {
if (!DEBUG) return null
val coroutineId = this[CoroutineId] ?: return null
val coroutineName = this[CoroutineName]?.name ?: "coroutine"
val currentThread = Thread.currentThread()
val oldName = currentThread.name
currentThread.name = buildString(oldName.length + coroutineName.length + 10) {
append(oldName)
append(" @")
append(coroutineName)
append('#')
append(coroutineId.id)
private object DebugThreadName : CoroutineContextThreadLocal<String?> {
override fun updateThreadContext(context: CoroutineContext): String? {
val coroutineId = context[CoroutineId] ?: return null
val coroutineName = context[CoroutineName]?.name ?: "coroutine"
val currentThread = Thread.currentThread()
val oldName = currentThread.name
currentThread.name = buildString(oldName.length + coroutineName.length + 10) {
append(oldName)
append(" @")
append(coroutineName)
append('#')
append(coroutineId.id)
}
return oldName
}

override fun restoreThreadContext(context: CoroutineContext, oldValue: String?) {
if (oldValue != null) Thread.currentThread().name = oldValue
}
return oldName
}

internal actual val CoroutineContext.coroutineName: String? get() {
Expand All @@ -121,12 +137,7 @@ internal actual val CoroutineContext.coroutineName: String? get() {
return "$coroutineName#${coroutineId.id}"
}

@PublishedApi
internal fun restoreThreadContext(oldName: String?) {
if (oldName != null) Thread.currentThread().name = oldName
}

private class CoroutineId(val id: Long) : AbstractCoroutineContextElement(CoroutineId) {
internal data class CoroutineId(val id: Long) : AbstractCoroutineContextElement(CoroutineId) {
companion object Key : CoroutineContext.Key<CoroutineId>
override fun toString(): String = "CoroutineId($id)"
}
93 changes: 93 additions & 0 deletions core/kotlinx-coroutines-core/src/CoroutineContextThreadLocal.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.experimental

import kotlin.coroutines.experimental.*

/**
* An extension point to define elements in [CoroutineContext] that are installed into thread local
* variables every time the coroutine from the specified context in resumed on a thread.
*
* Implementations on this interface are looked up via [java.util.ServiceLoader].
*
* Example usage looks like this:
*
* ```
* // declare custom coroutine context element
* class MyElement : AbstractCoroutineContextElement(Key) {
* companion object Key : CoroutineContext.Key<MyElement>
* // some state is kept here
* }
*
* // declare thread local variable
* private val myThreadLocal = ThreadLocal<MyElement?>()
*
* // declare extension point implementation
* class MyCoroutineContextThreadLocal : CoroutineContextThreadLocal<MyElement?> {
* // this is invoked before coroutine is resumed on current thread
* override fun updateThreadContext(context: CoroutineContext): MyElement? {
* val oldValue = myThreadLocal.get()
* myThreadLocal.set(context[MyElement])
* return oldValue
* }
*
* // this is invoked after coroutine has suspended on current thread
* override fun restoreThreadContext(context: CoroutineContext, oldValue: MyElement?) {
* myThreadLocal.set(oldValue)
* }
* }
* ```
*
* Now, `MyCoroutineContextThreadLocal` fully qualified class named shall be registered via
* `META-INF/services/kotlinx.coroutines.experimental.CoroutineContextThreadLocal` file.
*/
public interface CoroutineContextThreadLocal<T> {
/**
* Updates context of the current thread.
* This function is invoked before the coroutine in the specified [context] is resumed in the current thread.
* The result of this function is the old value that will be passed to [restoreThreadContext].
*/
public fun updateThreadContext(context: CoroutineContext): T

/**
* Restores context of the current thread.
* This function is invoked after the coroutine in the specified [context] is suspended in the current thread.
* The value of [oldValue] is the result of the previous invocation of [updateThreadContext].
*/
public fun restoreThreadContext(context: CoroutineContext, oldValue: T)
}

/**
* This class is used when multiple [CoroutineContextThreadLocal] are installed.
*/
internal class CoroutineContextThreadLocalList(
private val impls: Array<CoroutineContextThreadLocal<Any?>>
) : CoroutineContextThreadLocal<Any?> {
init {
require(impls.size > 1)
}

private val threadLocalStack = ThreadLocal<ArrayList<Any?>?>()

override fun updateThreadContext(context: CoroutineContext): Any? {
val stack = threadLocalStack.get() ?: ArrayList<Any?>().also {
threadLocalStack.set(it)
}
val lastIndex = impls.lastIndex
for (i in 0 until lastIndex) {
stack.add(impls[i].updateThreadContext(context))
}
return impls[lastIndex].updateThreadContext(context)
}

override fun restoreThreadContext(context: CoroutineContext, oldValue: Any?) {
val stack = threadLocalStack.get()!! // must be there
val lastIndex = impls.lastIndex
impls[lastIndex].restoreThreadContext(context, oldValue)
for (i in lastIndex - 1 downTo 0) {
impls[i].restoreThreadContext(context, stack.removeAt(stack.lastIndex))
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
kotlinx.coroutines.experimental.MyCoroutineContextThreadLocal
kotlinx.coroutines.experimental.ValidatingCoroutineContextThreadLocal
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.experimental

import org.junit.Test
import kotlin.coroutines.experimental.*
import kotlin.test.*

class CoroutineContextThreadLocalTest : TestBase() {
@Test
fun testExample() = runTest {
val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
val mainThread = Thread.currentThread()
val element = MyElement()
assertEquals(null, myThreadLocal.get())
val job = launch(element) {
assertTrue(mainThread != Thread.currentThread())
assertSame(element, coroutineContext[MyElement])
assertSame(element, myThreadLocal.get())
withContext(mainDispatcher) {
assertSame(mainThread, Thread.currentThread())
assertSame(element, coroutineContext[MyElement])
assertSame(element, myThreadLocal.get())
}
assertTrue(mainThread != Thread.currentThread())
assertSame(element, coroutineContext[MyElement])
assertSame(element, myThreadLocal.get())
}
assertEquals(null, myThreadLocal.get())
job.join()
assertEquals(null, myThreadLocal.get())
}
}

// declare custom coroutine context element
class MyElement : AbstractCoroutineContextElement(Key) {
companion object Key : CoroutineContext.Key<MyElement>
// some state is kept here
}

// declare thread local variable
private val myThreadLocal = ThreadLocal<MyElement?>()

// declare extension point implementation
class MyCoroutineContextThreadLocal : CoroutineContextThreadLocal<MyElement?> {
// this is invoked before coroutine is resumed on current thread
override fun updateThreadContext(context: CoroutineContext): MyElement? {
val oldValue = myThreadLocal.get()
myThreadLocal.set(context[MyElement])
return oldValue
}

// this is invoked after coroutine has suspended on current thread
override fun restoreThreadContext(context: CoroutineContext, oldValue: MyElement?) {
myThreadLocal.set(oldValue)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.experimental

import kotlin.coroutines.experimental.*

private val currentCoroutineId = ThreadLocal<CoroutineId?>()

internal class ValidatingCoroutineContextThreadLocal : CoroutineContextThreadLocal<CoroutineId?> {
override fun updateThreadContext(context: CoroutineContext): CoroutineId? {
val id = context[CoroutineId] ?: error("Tests should be run in debug mode (enable assertions?)")
val top = currentCoroutineId.get()
require( top != id) {
"Thread ${Thread.currentThread().name} already has coroutine context for coroutine $context"
}
currentCoroutineId.set(id)
return top
}

override fun restoreThreadContext(context: CoroutineContext, oldValue: CoroutineId?) {
val id = context[CoroutineId]
val top = currentCoroutineId.get()
require(top == id) {
"Thread ${Thread.currentThread().name} does not have coroutine context for coroutine $context, but has for coroutine id $top"
}
currentCoroutineId.set(oldValue)
}
}
3 changes: 2 additions & 1 deletion integration/kotlinx-coroutines-quasar/src/Quasar.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ fun <T> runFiberBlocking(block: suspend () -> T): T =
private class CoroutineAsync<T>(
private val block: suspend () -> T
) : FiberAsync<T, Throwable>(), Continuation<T> {
override val context: CoroutineContext = Fiber.currentFiber().scheduler.executor.asCoroutineDispatcher()
override val context: CoroutineContext =
newCoroutineContext(Fiber.currentFiber().scheduler.executor.asCoroutineDispatcher())
override fun resume(value: T) { asyncCompleted(value) }
override fun resumeWithException(exception: Throwable) { asyncFailed(exception) }

Expand Down

0 comments on commit 5eb01e9

Please sign in to comment.