Skip to content

Commit

Permalink
Add switch points into cancellation handlers (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
alefedor authored Dec 7, 2020
1 parent 0f93582 commit 3bc3c60
Show file tree
Hide file tree
Showing 18 changed files with 250 additions and 92 deletions.
15 changes: 10 additions & 5 deletions src/jvm/main/org/jetbrains/kotlinx/lincheck/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -190,20 +190,25 @@ internal class StoreExceptionHandler : AbstractCoroutineContextElement(Coroutine
}

@Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")
fun <T> CancellableContinuation<T>.cancelByLincheck(promptCancellation: Boolean): Boolean {
internal fun <T> CancellableContinuation<T>.cancelByLincheck(promptCancellation: Boolean): CancellationResult {
val exceptionHandler = context[CoroutineExceptionHandler] as StoreExceptionHandler
exceptionHandler.exception = null
val cancelled = cancel(cancellationByLincheckException)
exceptionHandler.exception?.let {
throw it.cause!! // let's throw the original exception, ignoring the internal coroutines details
}
if (!cancelled && promptCancellation) {
context[Job]!!.cancel() // we should always put a job into the context for prompt cancellation
return true
return when {
cancelled -> CancellationResult.CANCELLED_BEFORE_RESUMPTION
promptCancellation -> {
context[Job]!!.cancel() // we should always put a job into the context for prompt cancellation
CancellationResult.CANCELLED_AFTER_RESUMPTION
}
else -> CancellationResult.CANCELLATION_FAILED
}
return cancelled
}

internal enum class CancellationResult { CANCELLED_BEFORE_RESUMPTION, CANCELLED_AFTER_RESUMPTION, CANCELLATION_FAILED }

@Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")
private val cancelCompletedResultMethod = DispatchedTask::class.declaredFunctions.find { it.name == "cancelCompletedResult" }!!.javaMethod!!

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.lincheck.runner

import kotlinx.coroutines.*
import org.jetbrains.kotlinx.lincheck.*
import org.jetbrains.kotlinx.lincheck.CancellationResult.*
import org.jetbrains.kotlinx.lincheck.execution.*
import org.jetbrains.kotlinx.lincheck.runner.FixedActiveThreadsExecutor.TestThread
import org.jetbrains.kotlinx.lincheck.runner.UseClocks.*
Expand Down Expand Up @@ -181,12 +182,11 @@ internal open class ParallelThreadsRunner(
val finalResult = if (res === COROUTINE_SUSPENDED) {
val t = Thread.currentThread() as TestThread
val cont = t.cont.also { t.cont = null }
if (actor.cancelOnSuspension && cont !== null && cont.cancelByLincheck(actor.promptCancellation)) {
if (actor.cancelOnSuspension && cont !== null && cancelByLincheck(cont, actor.promptCancellation) != CANCELLATION_FAILED) {
if (!trySetCancelledStatus(iThread, actorId)) {
// already resumed, increment `completedOrSuspendedThreads` back
completedOrSuspendedThreads.incrementAndGet()
}
afterCoroutineCancelled(iThread)
Cancelled
} else waitAndInvokeFollowUp(iThread, actorId)
} else createLincheckResult(res)
Expand Down Expand Up @@ -226,14 +226,19 @@ internal open class ParallelThreadsRunner(
return suspensionPointResults[iThread][actorId]
}

/**
* This method is used for communication between `ParallelThreadsRunner` and `ManagedStrategy` via overriding,
* so that runner do not know about managed strategy details.
*/
internal open fun <T> cancelByLincheck(cont: CancellableContinuation<T>, promptCancellation: Boolean): CancellationResult =
cont.cancelByLincheck(promptCancellation)

override fun afterCoroutineSuspended(iThread: Int) {
completedOrSuspendedThreads.incrementAndGet()
}

override fun afterCoroutineResumed(iThread: Int) {}

override fun afterCoroutineCancelled(iThread: Int) {}

// We cannot use `completionStatuses` here since
// they are set _before_ the result is published.
override fun isCoroutineResumed(iThread: Int, actorId: Int) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
*/
package org.jetbrains.kotlinx.lincheck.strategy.managed

import kotlinx.coroutines.*
import org.jetbrains.kotlinx.lincheck.*
import org.jetbrains.kotlinx.lincheck.CancellationResult.*
import org.jetbrains.kotlinx.lincheck.execution.*
import org.jetbrains.kotlinx.lincheck.runner.*
import org.jetbrains.kotlinx.lincheck.strategy.*
Expand Down Expand Up @@ -286,7 +288,7 @@ abstract class ManagedStrategy(
if (loopDetector.visitCodeLocation(iThread, codeLocation)) {
failIfObstructionFreedomIsRequired {
// Log the last event that caused obstruction freedom violation
traceCollector?.passCodeLocation(codeLocation, tracePoint)
traceCollector?.passCodeLocation(tracePoint)
"Obstruction-freedom is required but an active lock has been found"
}
checkLiveLockHappened(loopDetector.totalOperations)
Expand All @@ -297,7 +299,7 @@ abstract class ManagedStrategy(
val reason = if (isLoop) SwitchReason.ACTIVE_LOCK else SwitchReason.STRATEGY_SWITCH
switchCurrentThread(iThread, reason)
}
traceCollector?.passCodeLocation(codeLocation, tracePoint)
traceCollector?.passCodeLocation(tracePoint)
// continue the operation
}

Expand Down Expand Up @@ -462,7 +464,7 @@ abstract class ManagedStrategy(
internal fun beforeLockRelease(iThread: Int, codeLocation: Int, tracePoint: MonitorExitTracePoint?, monitor: Any): Boolean {
if (inIgnoredSection(iThread)) return true
monitorTracker.releaseMonitor(monitor)
traceCollector?.passCodeLocation(codeLocation, tracePoint)
traceCollector?.passCodeLocation(tracePoint)
return false
}

Expand All @@ -485,7 +487,7 @@ abstract class ManagedStrategy(
*/
@Suppress("UNUSED_PARAMETER")
internal fun afterUnpark(iThread: Int, codeLocation: Int, tracePoint: UnparkTracePoint?, thread: Any) {
traceCollector?.passCodeLocation(codeLocation, tracePoint)
traceCollector?.passCodeLocation(tracePoint)
}

/**
Expand Down Expand Up @@ -515,7 +517,7 @@ abstract class ManagedStrategy(
monitorTracker.notifyAll(monitor)
else
monitorTracker.notify(monitor)
traceCollector?.passCodeLocation(codeLocation, tracePoint)
traceCollector?.passCodeLocation(tracePoint)
}

/**
Expand Down Expand Up @@ -586,7 +588,7 @@ abstract class ManagedStrategy(
*/
@Suppress("UNUSED_PARAMETER")
internal fun beforeMethodCall(iThread: Int, codeLocation: Int, tracePoint: MethodCallTracePoint) {
if (isTestThread(iThread)) {
if (isTestThread(iThread) && !inIgnoredSection(iThread)) {
check(collectTrace) { "This method should be called only when logging is enabled" }
val callStackTrace = callStackTrace[iThread]
val suspendedMethodStack = suspendedFunctionsStack[iThread]
Expand All @@ -611,7 +613,7 @@ abstract class ManagedStrategy(
* @param tracePoint the corresponding trace point for the invocation
*/
internal fun afterMethodCall(iThread: Int, tracePoint: MethodCallTracePoint) {
if (isTestThread(iThread)) {
if (isTestThread(iThread) && !inIgnoredSection(iThread)) {
check(collectTrace) { "This method should be called only when logging is enabled" }
val callStackTrace = callStackTrace[iThread]
if (tracePoint.wasSuspended) {
Expand All @@ -631,11 +633,26 @@ abstract class ManagedStrategy(
* @param constructorId which constructor to use for creating code location
* @return the created interleaving point
*/
fun createTracePoint(constructorId: Int): TracePoint {
fun createTracePoint(constructorId: Int): TracePoint = doCreateTracePoint(tracePointConstructors[constructorId])

/**
* Creates a new [CoroutineCancellationTracePoint].
* This method is similar to [createTracePoint] method, but also adds the new trace point to the trace.
*/
internal fun createAndLogCancellationTracePoint(): CoroutineCancellationTracePoint? {
if (collectTrace) {
val cancellationTracePoint = doCreateTracePoint(::CoroutineCancellationTracePoint)
traceCollector?.passCodeLocation(cancellationTracePoint)
return cancellationTracePoint
}
return null
}

private fun <T : TracePoint> doCreateTracePoint(constructor: (iThread: Int, actorId: Int, CallStackTrace) -> T): T {
val iThread = currentThreadNumber()
// use any actor id for non-test threads
val actorId = if (!isTestThread(iThread)) Int.MIN_VALUE else currentActorId[iThread]
return tracePointConstructors[constructorId](iThread, actorId, callStackTrace.getOrNull(iThread)?.toList() ?: emptyList())
return constructor(iThread, actorId, callStackTrace.getOrNull(iThread)?.toList() ?: emptyList())
}

/**
Expand Down Expand Up @@ -688,17 +705,12 @@ abstract class ManagedStrategy(
_trace += FinishThreadTracePoint(iThread)
}

fun passCodeLocation(codeLocation: Int, tracePoint: TracePoint?) {
// Ignore coroutine suspensions - they are processed in another place.
if (codeLocation == COROUTINE_SUSPENSION_CODE_LOCATION) return
fun passCodeLocation(tracePoint: TracePoint?) {
_trace += tracePoint!!
}

fun addStateRepresentation(iThread: Int) {
// enter ignored section, because stateRepresentation invokes transformed method with switch points
enterIgnoredSection(iThread)
val stateRepresentation = runner.constructStateRepresentation()!!
leaveIgnoredSection(iThread)
// use call stack trace of the previous trace point
val callStackTrace = _trace.last().callStackTrace.toList()
_trace += StateRepresentationTracePoint(iThread, currentActorId[iThread], stateRepresentation, callStackTrace)
Expand Down Expand Up @@ -743,6 +755,34 @@ private class ManagedStrategyRunner(
super.afterCoroutineCancelled(iThread)
managedStrategy.afterCoroutineCancelled(iThread)
}

override fun constructStateRepresentation(): String? {
// Enter ignored section, because Runner will call transformed state representation method
val iThread = managedStrategy.currentThreadNumber()
managedStrategy.enterIgnoredSection(iThread)
val stateRepresentation = super.constructStateRepresentation()
managedStrategy.leaveIgnoredSection(iThread)
return stateRepresentation
}

override fun <T> cancelByLincheck(cont: CancellableContinuation<T>, promptCancellation: Boolean): CancellationResult {
// Create a cancellation trace point before `cancel`, so that cancellation trace point
// precede the events in `onCancellation` handler.
val cancellationTracePoint = managedStrategy.createAndLogCancellationTracePoint()
try {
// Call the `cancel` method.
val cancellationResult = super.cancelByLincheck(cont, promptCancellation)
// Pass the result to `cancellationTracePoint`.
cancellationTracePoint?.initializeCancellationResult(cancellationResult)
// Invoke `strategy.afterCoroutineCancelled` if the coroutine was cancelled successfully.
if (cancellationResult != CANCELLATION_FAILED)
managedStrategy.afterCoroutineCancelled(managedStrategy.currentThreadNumber())
return cancellationResult
} catch(e: Throwable) {
cancellationTracePoint?.initializeException(e)
throw e // throw further
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ internal class ManagedStrategyTransformer(
mv = WaitNotifyTransformer(mname, GeneratorAdapter(mv, access, mname, desc))
mv = ParkUnparkTransformer(mname, GeneratorAdapter(mv, access, mname, desc))
mv = LocalObjectManagingTransformer(mname, GeneratorAdapter(mv, access, mname, desc))
mv = CancellabilitySupportMethodTransformer(mname, GeneratorAdapter(mv, access, mname, desc))
mv = SharedVariableAccessMethodTransformer(mname, GeneratorAdapter(mv, access, mname, desc))
mv = TimeStubTransformer(GeneratorAdapter(mv, access, mname, desc))
mv = RandomTransformer(GeneratorAdapter(mv, access, mname, desc))
Expand Down Expand Up @@ -949,31 +948,6 @@ internal class ManagedStrategyTransformer(
}
}

/**
* Removes switch points in CancellableContinuationImpl.cancel, so that they will not be reported
* when a continuation is cancelled by lincheck
*/
private inner class CancellabilitySupportMethodTransformer(methodName: String, mv: GeneratorAdapter) : ManagedStrategyMethodVisitor(methodName, mv) {
private val isCancel = className == "kotlinx/coroutines/CancellableContinuationImpl" &&
(methodName == "cancel" || methodName == "cancelCompletedResult")

override fun visitCode() {
if (isCancel)
invokeBeforeIgnoredSectionEntering()
mv.visitCode()
}

override fun visitInsn(opcode: Int) {
if (isCancel) {
when (opcode) {
ARETURN, DRETURN, FRETURN, IRETURN, LRETURN, RETURN -> invokeAfterIgnoredSectionLeaving()
else -> { }
}
}
mv.visitInsn(opcode)
}
}

/**
* Track local objects for odd switch points elimination.
* A local object is an object that can be possible viewed only from one thread.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
*/
package org.jetbrains.kotlinx.lincheck.strategy.managed

import org.jetbrains.kotlinx.lincheck.*
import org.jetbrains.kotlinx.lincheck.CancellationResult.*
import java.math.*
import java.util.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*

class Trace(val trace: List<TracePoint>, val verboseTrace: Boolean)
data class Trace(val trace: List<TracePoint>, val verboseTrace: Boolean)

/**
* Essentially, a trace is a list of trace points, which represent
Expand Down Expand Up @@ -216,6 +217,31 @@ internal class UnparkTracePoint(
override fun toStringImpl(): String = "UNPARK at " + stackTraceElement.shorten()
}

internal class CoroutineCancellationTracePoint(
iThread: Int, actorId: Int,
callStackTrace: CallStackTrace,
) : TracePoint(iThread, actorId, callStackTrace) {
private lateinit var cancellationResult: CancellationResult
private var exception: Throwable? = null

fun initializeCancellationResult(cancellationResult: CancellationResult) {
this.cancellationResult = cancellationResult
}

fun initializeException(e: Throwable) {
this.exception = e;
}

override fun toStringImpl(): String {
if (exception != null) return "EXCEPTION WHILE CANCELLATION"
return when (cancellationResult) {
CANCELLED_BEFORE_RESUMPTION -> "CANCELLED BEFORE RESUMPTION"
CANCELLED_AFTER_RESUMPTION -> "PROMPT CANCELLED AFTER RESUMPTION"
CANCELLATION_FAILED -> "CANCELLATION ATTEMPT FAILED"
}
}
}

/**
* Removes package info in the stack trace element representation.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,6 @@ private class ActorResultNode(iThread: Int, last: TraceNode?, verboseTrace: Bool
override val shouldBeExpanded: Boolean = false

override fun addRepresentationTo(traceRepresentation: MutableList<TraceEventRepresentation>): TraceNode? {
if (result == Cancelled)
traceRepresentation.add(TraceEventRepresentation(iThread, traceIndentation() + "CONTINUATION CANCELLED"))
if (result != null)
traceRepresentation.add(TraceEventRepresentation(iThread, traceIndentation() + "result: $result"))
return next
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package org.jetbrains.kotlinx.lincheck.test

import org.jetbrains.kotlinx.lincheck.*
import org.jetbrains.kotlinx.lincheck.strategy.*
import org.jetbrains.kotlinx.lincheck.strategy.managed.*
import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions
import org.jetbrains.kotlinx.lincheck.strategy.stress.*
import org.jetbrains.kotlinx.lincheck.verifier.*
Expand All @@ -42,6 +43,7 @@ abstract class AbstractLincheckTest(
"This test should fail, but no error has been occurred (see the logs for details)"
}
} else {
failure.trace?.let { checkTraceHasNoLincheckEvents(it.toString()) }
assert(expectedFailures.contains(failure::class)) {
"This test has failed with an unexpected error: \n $failure"
}
Expand Down Expand Up @@ -73,4 +75,10 @@ abstract class AbstractLincheckTest(
}
}

private const val TIMEOUT = 100_000L
private const val TIMEOUT = 100_000L

fun checkTraceHasNoLincheckEvents(trace: String) {
val testPackageOccurrences = trace.split("org.jetbrains.kotlinx.lincheck.test.").size - 1
val lincheckPackageOccurrences = trace.split("org.jetbrains.kotlinx.lincheck.").size - 1
check(testPackageOccurrences == lincheckPackageOccurrences) { "Internal Lincheck events were found in the trace" }
}
Loading

0 comments on commit 3bc3c60

Please sign in to comment.