diff --git a/kotlinx-coroutines-core/jvm/test/MutexCancellationStressTest.kt b/kotlinx-coroutines-core/jvm/test/MutexCancellationStressTest.kt index eb6360dac0..20798b837d 100644 --- a/kotlinx-coroutines-core/jvm/test/MutexCancellationStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/MutexCancellationStressTest.kt @@ -8,7 +8,11 @@ import kotlinx.coroutines.internal.* import kotlinx.coroutines.selects.* import kotlinx.coroutines.sync.* import org.junit.* +import org.junit.Test import java.util.concurrent.* +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.* class MutexCancellationStressTest : TestBase() { @Test @@ -18,13 +22,16 @@ class MutexCancellationStressTest : TestBase() { val mutexOwners = Array(mutexJobNumber) { "$it" } val dispatcher = Executors.newFixedThreadPool(mutexJobNumber + 2).asCoroutineDispatcher() var counter = 0 - val counterLocal = Array(mutexJobNumber) { LocalAtomicInt(0) } - val completed = LocalAtomicInt(0) + val counterLocal = Array(mutexJobNumber) { AtomicInteger(0) } + val completed = AtomicBoolean(false) val mutexJobLauncher: (jobNumber: Int) -> Job = { jobId -> val coroutineName = "MutexJob-$jobId" - launch(dispatcher + CoroutineName(coroutineName)) { - while (completed.value == 0) { + // ATOMIC to always have a chance to proceed + launch(dispatcher + CoroutineName(coroutineName), CoroutineStart.ATOMIC) { + while (!completed.get()) { + // Stress out holdsLock mutex.holdsLock(mutexOwners[(jobId + 1) % mutexJobNumber]) + // Stress out lock-like primitives if (mutex.tryLock(mutexOwners[jobId])) { counterLocal[jobId].incrementAndGet() counter++ @@ -47,18 +54,20 @@ class MutexCancellationStressTest : TestBase() { val mutexJobs = (0 until mutexJobNumber).map { mutexJobLauncher(it) }.toMutableList() val checkProgressJob = launch(dispatcher + CoroutineName("checkProgressJob")) { var lastCounterLocalSnapshot = (0 until mutexJobNumber).map { 0 } - while (completed.value == 0) { - delay(1000) + while (!completed.get()) { + delay(500) + // If we've caught the completion after delay, then there is a chance no progress were made whatsoever, bail out + if (completed.get()) return@launch val c = counterLocal.map { it.value } for (i in 0 until mutexJobNumber) { - assert(c[i] > lastCounterLocalSnapshot[i]) { "No progress in MutexJob-$i" } + assert(c[i] > lastCounterLocalSnapshot[i]) { "No progress in MutexJob-$i, last observed state: ${c[i]}" } } lastCounterLocalSnapshot = c } } val cancellationJob = launch(dispatcher + CoroutineName("cancellationJob")) { var cancellingJobId = 0 - while (completed.value == 0) { + while (!completed.get()) { val jobToCancel = mutexJobs.removeFirst() jobToCancel.cancelAndJoin() mutexJobs += mutexJobLauncher(cancellingJobId) @@ -66,11 +75,11 @@ class MutexCancellationStressTest : TestBase() { } } delay(2000L * stressTestMultiplier) - completed.value = 1 + completed.set(true) cancellationJob.join() mutexJobs.forEach { it.join() } checkProgressJob.join() - check(counter == counterLocal.sumOf { it.value }) + assertEquals(counter, counterLocal.sumOf { it.value }) dispatcher.close() } }