Skip to content

Commit

Permalink
Call before
Browse files Browse the repository at this point in the history
  • Loading branch information
zuevmaxim committed Oct 27, 2020
1 parent 9f62e11 commit e0026e8
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package org.jetbrains.kotlinx.lincheck

import org.jetbrains.kotlinx.lincheck.TransformationClassLoader.ASM_API
import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
import org.jetbrains.kotlinx.lincheck.nvm.CrashError
import org.objectweb.asm.*
import org.objectweb.asm.commons.GeneratorAdapter
import org.objectweb.asm.commons.Method


class RecoverabilityTransformer(cv: ClassVisitor) : ClassVisitor(ASM_API, cv) {
private lateinit var name: String

override fun visit(
version: Int,
access: Int,
name: String?,
signature: String?,
superName: String?,
interfaces: Array<out String>?
) {
super.visit(version, access, name, signature, superName, interfaces)
this.name = name!!
}


override fun visitMethod(
access: Int,
name: String?,
descriptor: String?,
signature: String?,
exceptions: Array<out String>?
) = RecoverableMethodTransformer(
super.visitMethod(access, name, descriptor, signature, exceptions),
access,
name,
descriptor,
this.name
)
}

class RecoverableMethodTransformer(
mv: MethodVisitor,
access: Int,
name: String?,
private val descriptor: String?,
private val className: String
) : GeneratorAdapter(ASM_API, mv, access, name, descriptor) {
private var shouldTransform = false
lateinit var beforeName: String
lateinit var recoverName: String

override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor {
val av = super.visitAnnotation(descriptor, visible)
if (descriptor != Type.getDescriptor(Recoverable::class.java)) return av
shouldTransform = true
return object : AnnotationVisitor(ASM_API, av) {
override fun visit(name: String?, value: Any?) {
super.visit(name, value)
if (name == "recoverMethod") {
recoverName = value as String
} else if (name == "beforeMethod") {
beforeName = value as String
}
}
}
}

override fun visitCode() {
super.visitCode()
if (!shouldTransform) return
val loop = Label()
val startLabel = Label()
val endLabel = Label()
val catchLabel = Label()
visitTryCatchBlock(startLabel, endLabel, catchLabel, Type.getInternalName(CrashError::class.java))

val resultIndex = newLocal(Type.BOOLEAN_TYPE)
push(false)
storeLocal(resultIndex)

visitLabel(loop)
visitLabel(startLabel)
visitVarInsn(Opcodes.ALOAD, 0)
visitVarInsn(Opcodes.ALOAD, 1) // get first parameter
invokeVirtual(Type.getType("L$className;"), Method(beforeName, "(I)V"))
push(true)
storeLocal(resultIndex)
visitLabel(endLabel)
visitLabel(catchLabel)
loadLocal(resultIndex)
ifZCmp(EQ, loop)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ package org.jetbrains.kotlinx.lincheck.annotations
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.FUNCTION)
annotation class Recoverable(
val recoverMethod: String
val recoverMethod: String,
val beforeMethod: String
)
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ internal open class ParallelThreadsRunner(
override fun needsTransformation() = true

override fun createTransformer(cv: ClassVisitor): ClassVisitor {
return CancellabilitySupportClassTransformer(cv)
return RecoverabilityTransformer(CancellabilitySupportClassTransformer(cv))
}

override fun getStateRepresentation(): String? =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
package org.jetbrains.kotlinx.lincheck.test.verifier.nlr

import org.jetbrains.kotlinx.lincheck.Options
import org.jetbrains.kotlinx.lincheck.LinChecker
import org.jetbrains.kotlinx.lincheck.annotations.Operation
import org.jetbrains.kotlinx.lincheck.annotations.Param
import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
import org.jetbrains.kotlinx.lincheck.nvm.Persistent
import org.jetbrains.kotlinx.lincheck.paramgen.ThreadIdGen
import org.jetbrains.kotlinx.lincheck.test.AbstractLincheckTest
import org.jetbrains.kotlinx.lincheck.strategy.stress.StressCTest
import org.jetbrains.kotlinx.lincheck.verifier.VerifierState
import org.jetbrains.kotlinx.lincheck.verifier.linearizability.LinearizabilityVerifier
import org.junit.Test

private const val THREADS_NUMBER = 2

/**
* @see <a href="https://www.cs.bgu.ac.il/~hendlerd/papers/NRL.pdf">Nesting-Safe Recoverable Linearizability</a>
*/
class CounterTest : AbstractLincheckTest() {
@StressCTest(
sequentialSpecification = SequentialCounter::class,
threads = THREADS_NUMBER
)
class CounterTest {
private val counter = NRLCounter(THREADS_NUMBER + 2)

@Operation
Expand All @@ -22,10 +28,8 @@ class CounterTest : AbstractLincheckTest() {
@Operation
fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId)

override fun <O : Options<O, *>> O.customize() {
sequentialSpecification(SequentialCounter::class.java)
threads(THREADS_NUMBER)
}
@Test
fun test() = LinChecker.check(this::class.java)
}

class SequentialCounter : VerifierState() {
Expand All @@ -41,27 +45,36 @@ class SequentialCounter : VerifierState() {

class NRLCounter(threadsCount: Int) : VerifierState() {
private val R = List(threadsCount) { NRLReadWriteObject<Int>(1).also { it.write(0, 0) } }
private val Response = MutableList(threadsCount) { 0 }
private val LineInstruction = MutableList(threadsCount) { 0 }
private val Response = MutableList(threadsCount) { Persistent(0) }
private val CheckPointer = MutableList(threadsCount) { Persistent(0) }
private val CurrentValue = MutableList(threadsCount) { Persistent(0) }

override fun extractState() = R.sumBy { it.read()!! }

fun get(p: Int): Int { LineInstruction[p] = 43
val returnValue = R.sumBy { it.read()!! }; LineInstruction[p] = 44
Response[p] = returnValue; LineInstruction[p] = 45
// flush
fun get(p: Int): Int {
val returnValue = R.sumBy { it.read()!! }
Response[p].write(p, returnValue)
Response[p].flush(p)
return returnValue
}

fun getRecover(p: Int) = get(p)

fun increment(p: Int) { LineInstruction[p] = 52
val newValue = 1 + R[p].read()!!; LineInstruction[p] = 53
R[p].write(0, newValue); LineInstruction[p] = 54
// flush
@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "")
fun increment(p: Int) {
R[p].write(0, 1 + CurrentValue[p].read(p)!!)
CheckPointer[p].write(p, 1)
CheckPointer[p].flush(p)
}

private fun incrementRecover(p: Int) {
if (CheckPointer[p].read(p) == 0) return increment(p)
}

fun incrementRecover(p: Int) {
if (LineInstruction[p] < 54) return increment(p) // it is enough to have only one marker at line 54
fun incrementBefore(p: Int) {
CurrentValue[p].write(p, R[p].read()!!)
CheckPointer[p].write(p, 0)
CurrentValue[p].flush(p)
CurrentValue[p].flush(p)
}
}

0 comments on commit e0026e8

Please sign in to comment.