From e0026e8fae7f66256df9411c1e520ec4e7de1f50 Mon Sep 17 00:00:00 2001
From: Maksim Zuev <maksim.zuev@jetbrains.com>
Date: Tue, 27 Oct 2020 23:42:08 +0300
Subject: [PATCH] Call before

---
 .../lincheck/RecoverabilityTransformer.kt     | 94 +++++++++++++++++++
 .../lincheck/annotations/Recoverable.kt       |  3 +-
 .../lincheck/runner/ParallelThreadsRunner.kt  |  2 +-
 .../lincheck/test/verifier/nlr/CounterTest.kt | 53 +++++++----
 4 files changed, 130 insertions(+), 22 deletions(-)
 create mode 100644 src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt

diff --git a/src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt b/src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt
new file mode 100644
index 000000000..46017ee86
--- /dev/null
+++ b/src/main/java/org/jetbrains/kotlinx/lincheck/RecoverabilityTransformer.kt
@@ -0,0 +1,94 @@
+package org.jetbrains.kotlinx.lincheck
+
+import org.jetbrains.kotlinx.lincheck.TransformationClassLoader.ASM_API
+import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
+import org.jetbrains.kotlinx.lincheck.nvm.CrashError
+import org.objectweb.asm.*
+import org.objectweb.asm.commons.GeneratorAdapter
+import org.objectweb.asm.commons.Method
+
+
+class RecoverabilityTransformer(cv: ClassVisitor) : ClassVisitor(ASM_API, cv) {
+    private lateinit var name: String
+
+    override fun visit(
+        version: Int,
+        access: Int,
+        name: String?,
+        signature: String?,
+        superName: String?,
+        interfaces: Array<out String>?
+    ) {
+        super.visit(version, access, name, signature, superName, interfaces)
+        this.name = name!!
+    }
+
+
+    override fun visitMethod(
+        access: Int,
+        name: String?,
+        descriptor: String?,
+        signature: String?,
+        exceptions: Array<out String>?
+    ) = RecoverableMethodTransformer(
+        super.visitMethod(access, name, descriptor, signature, exceptions),
+        access,
+        name,
+        descriptor,
+        this.name
+    )
+}
+
+class RecoverableMethodTransformer(
+    mv: MethodVisitor,
+    access: Int,
+    name: String?,
+    private val descriptor: String?,
+    private val className: String
+) : GeneratorAdapter(ASM_API, mv, access, name, descriptor) {
+    private var shouldTransform = false
+    lateinit var beforeName: String
+    lateinit var recoverName: String
+
+    override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor {
+        val av = super.visitAnnotation(descriptor, visible)
+        if (descriptor != Type.getDescriptor(Recoverable::class.java)) return av
+        shouldTransform = true
+        return object : AnnotationVisitor(ASM_API, av) {
+            override fun visit(name: String?, value: Any?) {
+                super.visit(name, value)
+                if (name == "recoverMethod") {
+                    recoverName = value as String
+                } else if (name == "beforeMethod") {
+                    beforeName = value as String
+                }
+            }
+        }
+    }
+
+    override fun visitCode() {
+        super.visitCode()
+        if (!shouldTransform) return
+        val loop = Label()
+        val startLabel = Label()
+        val endLabel = Label()
+        val catchLabel = Label()
+        visitTryCatchBlock(startLabel, endLabel, catchLabel, Type.getInternalName(CrashError::class.java))
+
+        val resultIndex = newLocal(Type.BOOLEAN_TYPE)
+        push(false)
+        storeLocal(resultIndex)
+
+        visitLabel(loop)
+        visitLabel(startLabel)
+        visitVarInsn(Opcodes.ALOAD, 0)
+        visitVarInsn(Opcodes.ALOAD, 1) // get first parameter
+        invokeVirtual(Type.getType("L$className;"), Method(beforeName, "(I)V"))
+        push(true)
+        storeLocal(resultIndex)
+        visitLabel(endLabel)
+        visitLabel(catchLabel)
+        loadLocal(resultIndex)
+        ifZCmp(EQ, loop)
+    }
+}
diff --git a/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt b/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt
index c44a0c322..a76751115 100644
--- a/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt
+++ b/src/main/java/org/jetbrains/kotlinx/lincheck/annotations/Recoverable.kt
@@ -8,5 +8,6 @@ package org.jetbrains.kotlinx.lincheck.annotations
 @Retention(AnnotationRetention.RUNTIME)
 @Target(AnnotationTarget.FUNCTION)
 annotation class Recoverable(
-    val recoverMethod: String
+    val recoverMethod: String,
+    val beforeMethod: String
 )
diff --git a/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt b/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt
index 5039e7285..594a3e959 100644
--- a/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt
+++ b/src/main/java/org/jetbrains/kotlinx/lincheck/runner/ParallelThreadsRunner.kt
@@ -293,7 +293,7 @@ internal open class ParallelThreadsRunner(
     override fun needsTransformation() = true
 
     override fun createTransformer(cv: ClassVisitor): ClassVisitor {
-        return CancellabilitySupportClassTransformer(cv)
+        return RecoverabilityTransformer(CancellabilitySupportClassTransformer(cv))
     }
 
     override fun getStateRepresentation(): String? =
diff --git a/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt b/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt
index 55ebcee7e..c4ae7bf1d 100644
--- a/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt
+++ b/src/test/java/org/jetbrains/kotlinx/lincheck/test/verifier/nlr/CounterTest.kt
@@ -1,19 +1,25 @@
 package org.jetbrains.kotlinx.lincheck.test.verifier.nlr
 
-import org.jetbrains.kotlinx.lincheck.Options
+import org.jetbrains.kotlinx.lincheck.LinChecker
 import org.jetbrains.kotlinx.lincheck.annotations.Operation
 import org.jetbrains.kotlinx.lincheck.annotations.Param
+import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
+import org.jetbrains.kotlinx.lincheck.nvm.Persistent
 import org.jetbrains.kotlinx.lincheck.paramgen.ThreadIdGen
-import org.jetbrains.kotlinx.lincheck.test.AbstractLincheckTest
+import org.jetbrains.kotlinx.lincheck.strategy.stress.StressCTest
 import org.jetbrains.kotlinx.lincheck.verifier.VerifierState
-import org.jetbrains.kotlinx.lincheck.verifier.linearizability.LinearizabilityVerifier
+import org.junit.Test
 
 private const val THREADS_NUMBER = 2
 
 /**
  * @see  <a href="https://www.cs.bgu.ac.il/~hendlerd/papers/NRL.pdf">Nesting-Safe Recoverable Linearizability</a>
  */
-class CounterTest : AbstractLincheckTest() {
+@StressCTest(
+    sequentialSpecification = SequentialCounter::class,
+    threads = THREADS_NUMBER
+)
+class CounterTest {
     private val counter = NRLCounter(THREADS_NUMBER + 2)
 
     @Operation
@@ -22,10 +28,8 @@ class CounterTest : AbstractLincheckTest() {
     @Operation
     fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId)
 
-    override fun <O : Options<O, *>> O.customize() {
-        sequentialSpecification(SequentialCounter::class.java)
-        threads(THREADS_NUMBER)
-    }
+    @Test
+    fun test() = LinChecker.check(this::class.java)
 }
 
 class SequentialCounter : VerifierState() {
@@ -41,27 +45,36 @@ class SequentialCounter : VerifierState() {
 
 class NRLCounter(threadsCount: Int) : VerifierState() {
     private val R = List(threadsCount) { NRLReadWriteObject<Int>(1).also { it.write(0, 0) } }
-    private val Response = MutableList(threadsCount) { 0 }
-    private val LineInstruction = MutableList(threadsCount) { 0 }
+    private val Response = MutableList(threadsCount) { Persistent(0) }
+    private val CheckPointer = MutableList(threadsCount) { Persistent(0) }
+    private val CurrentValue = MutableList(threadsCount) { Persistent(0) }
 
     override fun extractState() = R.sumBy { it.read()!! }
 
-    fun get(p: Int): Int {                         LineInstruction[p] = 43
-        val returnValue = R.sumBy { it.read()!! }; LineInstruction[p] = 44
-        Response[p] = returnValue;                 LineInstruction[p] = 45
-        // flush
+    fun get(p: Int): Int {
+        val returnValue = R.sumBy { it.read()!! }
+        Response[p].write(p, returnValue)
+        Response[p].flush(p)
         return returnValue
     }
 
     fun getRecover(p: Int) = get(p)
 
-    fun increment(p: Int) {                        LineInstruction[p] = 52
-        val newValue = 1 + R[p].read()!!;          LineInstruction[p] = 53
-        R[p].write(0, newValue);                LineInstruction[p] = 54
-        // flush
+    @Recoverable(beforeMethod = "incrementBefore", recoverMethod = "")
+    fun increment(p: Int) {
+        R[p].write(0, 1 + CurrentValue[p].read(p)!!)
+        CheckPointer[p].write(p, 1)
+        CheckPointer[p].flush(p)
+    }
+
+    private fun incrementRecover(p: Int) {
+        if (CheckPointer[p].read(p) == 0) return increment(p)
     }
 
-    fun incrementRecover(p: Int) {
-        if (LineInstruction[p] < 54) return increment(p) // it is enough to have only one marker at line 54
+    fun incrementBefore(p: Int) {
+        CurrentValue[p].write(p, R[p].read()!!)
+        CheckPointer[p].write(p, 0)
+        CurrentValue[p].flush(p)
+        CurrentValue[p].flush(p)
     }
 }