From e0026e8fae7f66256df9411c1e520ec4e7de1f50 Mon Sep 17 00:00:00 2001 From: Maksim Zuev <maksim.zuev@jetbrains.com> Date: Tue, 27 Oct 2020 23:42:08 +0300 Subject: [PATCH] Call before --- .../lincheck/RecoverabilityTransformer.kt | 94 +++++++++++++++++++ .../lincheck/annotations/Recoverable.kt | 3 +- .../lincheck/runner/ParallelThreadsRunner.kt | 2 +- .../lincheck/test/verifier/nlr/CounterTest.kt | 53 +++++++---- 4 files changed, 130 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt diff --git a/src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt b/src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt new file mode 100644 index 000000000..46017ee86 --- /dev/null +++ b/src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt @@ -0,0 +1,94 @@ +package org.jetbrains.kotlinx.lincheck + +import org.jetbrains.kotlinx.lincheck.TransformationClassLoader.ASM_API +import org.jetbrains.kotlinx.lincheck.annotations.Recoverable +import org.jetbrains.kotlinx.lincheck.nvm.CrashError +import org.objectweb.asm.* +import org.objectweb.asm.commons.GeneratorAdapter +import org.objectweb.asm.commons.Method + + +class RecoverabilityTransformer(cv: ClassVisitor) : ClassVisitor(ASM_API, cv) { + private lateinit var name: String + + override fun visit( + version: Int, + access: Int, + name: String?, + signature: String?, + superName: String?, + interfaces: Array<out String>? + ) { + super.visit(version, access, name, signature, superName, interfaces) + this.name = name!! + } + + + override fun visitMethod( + access: Int, + name: String?, + descriptor: String?, + signature: String?, + exceptions: Array<out String>? + ) = RecoverableMethodTransformer( + super.visitMethod(access, name, descriptor, signature, exceptions), + access, + name, + descriptor, + this.name + ) +} + +class RecoverableMethodTransformer( + mv: MethodVisitor, + access: Int, + name: String?, + private val descriptor: String?, + private val className: String +) : GeneratorAdapter(ASM_API, mv, access, name, descriptor) { + private var shouldTransform = false + lateinit var beforeName: String + lateinit var recoverName: String + + override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor { + val av = super.visitAnnotation(descriptor, visible) + if (descriptor != Type.getDescriptor(Recoverable::class.java)) return av + shouldTransform = true + return object : AnnotationVisitor(ASM_API, av) { + override fun visit(name: String?, value: Any?) { + super.visit(name, value) + if (name == "recoverMethod") { + recoverName = value as String + } else if (name == "beforeMethod") { + beforeName = value as String + } + } + } + } + + override fun visitCode() { + super.visitCode() + if (!shouldTransform) return + val loop = Label() + val startLabel = Label() + val endLabel = Label() + val catchLabel = Label() + visitTryCatchBlock(startLabel, endLabel, catchLabel, Type.getInternalName(CrashError::class.java)) + + val resultIndex = newLocal(Type.BOOLEAN_TYPE) + push(false) + storeLocal(resultIndex) + + visitLabel(loop) + visitLabel(startLabel) + visitVarInsn(Opcodes.ALOAD, 0) + visitVarInsn(Opcodes.ALOAD, 1) // get first parameter + invokeVirtual(Type.getType("L$className;"), Method(beforeName, "(I)V")) + push(true) + storeLocal(resultIndex) + visitLabel(endLabel) + visitLabel(catchLabel) + loadLocal(resultIndex) + ifZCmp(EQ, loop) + } +} diff --git a/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt b/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt index c44a0c322..a76751115 100644 --- a/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt +++ b/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt @@ -8,5 +8,6 @@ package org.jetbrains.kotlinx.lincheck.annotations @Retention(AnnotationRetention.RUNTIME) @Target(AnnotationTarget.FUNCTION) annotation class Recoverable( - val recoverMethod: String + val recoverMethod: String, + val beforeMethod: String ) diff --git a/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt b/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt index 5039e7285..594a3e959 100644 --- a/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt +++ b/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt @@ -293,7 +293,7 @@ internal open class ParallelThreadsRunner( override fun needsTransformation() = true override fun createTransformer(cv: ClassVisitor): ClassVisitor { - return CancellabilitySupportClassTransformer(cv) + return RecoverabilityTransformer(CancellabilitySupportClassTransformer(cv)) } override fun getStateRepresentation(): String? = diff --git a/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt b/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt index 55ebcee7e..c4ae7bf1d 100644 --- a/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt +++ b/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt @@ -1,19 +1,25 @@ package org.jetbrains.kotlinx.lincheck.test.verifier.nlr -import org.jetbrains.kotlinx.lincheck.Options +import org.jetbrains.kotlinx.lincheck.LinChecker import org.jetbrains.kotlinx.lincheck.annotations.Operation import org.jetbrains.kotlinx.lincheck.annotations.Param +import org.jetbrains.kotlinx.lincheck.annotations.Recoverable +import org.jetbrains.kotlinx.lincheck.nvm.Persistent import org.jetbrains.kotlinx.lincheck.paramgen.ThreadIdGen -import org.jetbrains.kotlinx.lincheck.test.AbstractLincheckTest +import org.jetbrains.kotlinx.lincheck.strategy.stress.StressCTest import org.jetbrains.kotlinx.lincheck.verifier.VerifierState -import org.jetbrains.kotlinx.lincheck.verifier.linearizability.LinearizabilityVerifier +import org.junit.Test private const val THREADS_NUMBER = 2 /** * @see <a href="https://www.cs.bgu.ac.il/~hendlerd/papers/NRL.pdf">Nesting-Safe Recoverable Linearizability</a> */ -class CounterTest : AbstractLincheckTest() { +@StressCTest( + sequentialSpecification = SequentialCounter::class, + threads = THREADS_NUMBER +) +class CounterTest { private val counter = NRLCounter(THREADS_NUMBER + 2) @Operation @@ -22,10 +28,8 @@ class CounterTest : AbstractLincheckTest() { @Operation fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId) - override fun <O : Options<O, *>> O.customize() { - sequentialSpecification(SequentialCounter::class.java) - threads(THREADS_NUMBER) - } + @Test + fun test() = LinChecker.check(this::class.java) } class SequentialCounter : VerifierState() { @@ -41,27 +45,36 @@ class SequentialCounter : VerifierState() { class NRLCounter(threadsCount: Int) : VerifierState() { private val R = List(threadsCount) { NRLReadWriteObject<Int>(1).also { it.write(0, 0) } } - private val Response = MutableList(threadsCount) { 0 } - private val LineInstruction = MutableList(threadsCount) { 0 } + private val Response = MutableList(threadsCount) { Persistent(0) } + private val CheckPointer = MutableList(threadsCount) { Persistent(0) } + private val CurrentValue = MutableList(threadsCount) { Persistent(0) } override fun extractState() = R.sumBy { it.read()!! } - fun get(p: Int): Int { LineInstruction[p] = 43 - val returnValue = R.sumBy { it.read()!! }; LineInstruction[p] = 44 - Response[p] = returnValue; LineInstruction[p] = 45 - // flush + fun get(p: Int): Int { + val returnValue = R.sumBy { it.read()!! } + Response[p].write(p, returnValue) + Response[p].flush(p) return returnValue } fun getRecover(p: Int) = get(p) - fun increment(p: Int) { LineInstruction[p] = 52 - val newValue = 1 + R[p].read()!!; LineInstruction[p] = 53 - R[p].write(0, newValue); LineInstruction[p] = 54 - // flush + @Recoverable(beforeMethod = "incrementBefore", recoverMethod = "") + fun increment(p: Int) { + R[p].write(0, 1 + CurrentValue[p].read(p)!!) + CheckPointer[p].write(p, 1) + CheckPointer[p].flush(p) + } + + private fun incrementRecover(p: Int) { + if (CheckPointer[p].read(p) == 0) return increment(p) } - fun incrementRecover(p: Int) { - if (LineInstruction[p] < 54) return increment(p) // it is enough to have only one marker at line 54 + fun incrementBefore(p: Int) { + CurrentValue[p].write(p, R[p].read()!!) + CheckPointer[p].write(p, 0) + CurrentValue[p].flush(p) + CurrentValue[p].flush(p) } }