From b4189c737ae63ff2687bf911c9c129e9f606893d Mon Sep 17 00:00:00 2001 From: acertain Date: Sat, 8 May 2021 19:27:00 -0700 Subject: [PATCH] ffi --- .gitignore | 2 +- build.gradle.kts | 5 + src/c/rts.c | 41 ++++ src/trufflestg/data/AlgData.kt | 6 +- src/trufflestg/data/Closure.kt | 2 - src/trufflestg/data/stg_data.kt | 191 +++++++++++++++--- src/trufflestg/frame/DataFrame.kt | 5 +- src/trufflestg/frame/frame_assembly.kt | 1 - src/trufflestg/jit/StgFCall.kt | 266 ++++++++++++++++++++----- src/trufflestg/jit/StgPrim.kt | 111 +++++++---- src/trufflestg/jit/code.kt | 5 +- src/trufflestg/jit/dispatch.kt | 9 +- src/trufflestg/language.kt | 3 +- src/trufflestg/stg/elab.kt | 46 +++-- 14 files changed, 535 insertions(+), 158 deletions(-) create mode 100644 src/c/rts.c diff --git a/.gitignore b/.gitignore index f91974f..76948d7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ **/*.iml **/.idea -**/build/* +build/* **/dist-newstyle/** **/dist/** *.ipr diff --git a/build.gradle.kts b/build.gradle.kts index 957492e..ed50b63 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -136,6 +136,8 @@ application { "--module-path=${compiler.asPath}", "--upgrade-module-path=${compiler.asPath}", "--add-opens=jdk.internal.vm.compiler/org.graalvm.compiler.truffle.runtime=ALL-UNNAMED", + "--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED", + "--add-opens=java.base/jdk.internal.ref=ALL-UNNAMED", "-Dtruffle.class.path.append=@TRUFFLESTG_APP_HOME@/lib/trufflestg-${project.version}.jar", "-Xss32m" ) @@ -150,6 +152,8 @@ val graalArgs = listOf( "--upgrade-module-path=${compiler.asPath}", // "-XX:-UseJVMCIClassLoader", // "-Dgraalvm.locatorDisabled=true", + "--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED", + "--add-opens=java.base/jdk.internal.ref=ALL-UNNAMED", "--add-opens=jdk.internal.vm.compiler/org.graalvm.compiler.truffle.runtime=ALL-UNNAMED", "-Dtruffle.class.path.append=build/libs/trufflestg-${project.version}.jar", "-Xss32m" @@ -252,6 +256,7 @@ tasks.withType { from(files("etc/symlinks","etc/permissions")) { rename("(.*)","META-INF/$1") } + from("src/c/librts.so") } tasks.withType { diff --git a/src/c/rts.c b/src/c/rts.c new file mode 100644 index 0000000..605ecad --- /dev/null +++ b/src/c/rts.c @@ -0,0 +1,41 @@ +#include +#include +#include + +// TODO +extern int64_t base_GHCziTopHandler_runIO_closure[] = {}; + +// TODO: figure out how to deal with foreign exports +// look at Module.foreignStubs +extern int64_t base_ControlziConcurrent_zdfstableZZC0ZZCbaseZZCControlzziConcurrentZZCforkOSzzuentry_closure[] = {}; + +extern bool keepCAFs = false; + +extern void* stable_ptr_table = NULL; +extern unsigned int n_capabilities = 4; +extern int64_t RtsFlags[1048576] = {0}; // is actually a struct, allocate 8MB to be safe + +void unblockUserSignals(void) {} +void blockUserSignals(void) {} + +// used for profiling, so we don't need it i think +void startTimer(void) {} +void stopTimer(void) {} + + +// TODO +void debugBelch() {} +void errorBelch() {} +void _assertFail() {} +void rts_mkStablePtr() {} +void rts_unlock() {} +void rts_lock() {} +void rts_checkSchedStatus() {} +void rts_mkInt32() {} +void rts_getInt32() {} +void rts_evalIO() {} +void rts_apply() {} +void foreignExportStablePtr() {} +void getProcessElapsedTime() {} + +void __dummy_hfehfewhoihjofijwoiefhua() {} diff --git a/src/trufflestg/data/AlgData.kt b/src/trufflestg/data/AlgData.kt index cf70adc..1ad030d 100644 --- a/src/trufflestg/data/AlgData.kt +++ b/src/trufflestg/data/AlgData.kt @@ -153,9 +153,6 @@ class DynamicClassDataConInfo( // TODO: use CompilerDirectives.isExact once its released if (klass!!.isInstance(x)) CompilerDirectives.castExact(x, klass) else null - // TODO: unbox fields, and set field type when we know it can't be a thunk (bang patterns) - // ghc-wpc isn't telling us about bang patterns though :( - // maybe could look at worker type? private val klass: Class? = if (size == 0) null else run { // TODO: better mangling fun mangle(x: String) = x.replace("""[\[\]]""".toRegex(), "_") @@ -169,6 +166,8 @@ class DynamicClassDataConInfo( is Stg.PrimRep.IntRep -> stgIntFieldInfo is Stg.PrimRep.WordRep -> stgWordFieldInfo is Stg.PrimRep.LiftedRep -> objectFieldInfo + // TODO: set field type when we know it (bang patterns) + // need to add bang pattern & type info to ghc-wpc else -> { // println("todo PrimRep $it") objectFieldInfo @@ -195,7 +194,6 @@ class DynamicClassDataConInfo( val field = kls.getDeclaredField("_info") field.set(null, this) val modifiers = Field::class.java.getDeclaredField("modifiers") - // TODO: make sure the final from here gets picked up by the jit modifiers.isAccessible = true modifiers.setInt(field, field.modifiers and Modifier.FINAL) modifiers.isAccessible = false diff --git a/src/trufflestg/data/Closure.kt b/src/trufflestg/data/Closure.kt index 1ca1e98..76303fa 100644 --- a/src/trufflestg/data/Closure.kt +++ b/src/trufflestg/data/Closure.kt @@ -2,8 +2,6 @@ package trufflestg.data import com.oracle.truffle.api.CallTarget import trufflestg.array_utils.* -import trufflestg.frame.DataFrame -import trufflestg.jit.CallUtils import trufflestg.jit.ClosureRootNode import trufflestg.panic import com.oracle.truffle.api.CompilerDirectives diff --git a/src/trufflestg/data/stg_data.kt b/src/trufflestg/data/stg_data.kt index 1defd6d..0e8adcd 100644 --- a/src/trufflestg/data/stg_data.kt +++ b/src/trufflestg/data/stg_data.kt @@ -1,27 +1,28 @@ +@file:Suppress("JAVA_MODULE_DOES_NOT_EXPORT_PACKAGE") package trufflestg.data -import trufflestg.panic -import trufflestg.stg.Stg import com.oracle.truffle.api.CompilerDirectives +import com.oracle.truffle.api.interop.InteropLibrary +import com.oracle.truffle.api.interop.TruffleObject +import com.oracle.truffle.api.interop.UnsupportedMessageException +import com.oracle.truffle.api.library.ExportLibrary +import com.oracle.truffle.api.library.ExportMessage +import jdk.internal.misc.Unsafe import org.intelligence.diagnostics.Severity import org.intelligence.diagnostics.error import org.intelligence.pretty.Pretty -import java.lang.ref.WeakReference -import kotlin.reflect.KClass +import trufflestg.array_utils.write +import trufflestg.jit.asCString +import java.lang.reflect.Constructor import java.nio.ByteBuffer +import java.nio.ByteOrder -// JvmField should make PE faster and might help graal - -// everything here is temporary until i implement unboxed fields etc // TODO: sealed (or just abstract?) class for possible haskell values // TODO: CompilerDirectives.ValueType // VoidRep object VoidInh -object NullAddr - - data class FullName( val unitId: String, val module: String, @@ -34,37 +35,158 @@ data class FullName( // note that currently ghc only has Int#, not Int32# etc, so we only need StgInt // FIXME ghc actually does have various Int*# variants, but they aren't used in Int32 etc? what's going on? @CompilerDirectives.ValueType -data class StgInt(@JvmField val x: Long) { +@ExportLibrary(InteropLibrary::class) +data class StgInt(@JvmField val x: Long) : TruffleObject { fun toInt(): Int = x.toInt() operator fun compareTo(y: StgInt): Int = x.compareTo(y.x) + @ExportMessage fun isNumber(): Boolean = true + @ExportMessage fun fitsInLong(): Boolean = true + @ExportMessage fun asLong(): Long = x + + @ExportMessage fun fitsInByte(): Boolean = false + @ExportMessage fun fitsInShort(): Boolean = false + @ExportMessage fun fitsInInt(): Boolean = false + @ExportMessage fun fitsInFloat(): Boolean = false + @ExportMessage fun fitsInDouble(): Boolean = false + + @ExportMessage fun asByte(): Byte { throw UnsupportedMessageException.create() } + @ExportMessage fun asShort(): Short { throw UnsupportedMessageException.create() } + @ExportMessage fun asInt(): Int { throw UnsupportedMessageException.create() } + @ExportMessage fun asFloat(): Float { throw UnsupportedMessageException.create() } + @ExportMessage fun asDouble(): Double { throw UnsupportedMessageException.create() } + fun unbox(): Long = x companion object { @JvmStatic fun box(x: Long): StgInt = StgInt(x) } } + @CompilerDirectives.ValueType -data class StgWord(@JvmField val x: ULong) { +@OptIn(ExperimentalUnsignedTypes::class) +@ExportLibrary(InteropLibrary::class) +data class StgWord(@JvmField val x: ULong) : TruffleObject { fun asChar(): Int = x.toInt() + @ExportMessage fun isNumber(): Boolean = true + @ExportMessage fun fitsInLong(): Boolean = true + // TODO: is this right?? + @ExportMessage fun asLong(): Long = x.toLong() + + @ExportMessage fun fitsInByte(): Boolean = false + @ExportMessage fun fitsInShort(): Boolean = false + @ExportMessage fun fitsInInt(): Boolean = false + @ExportMessage fun fitsInFloat(): Boolean = false + @ExportMessage fun fitsInDouble(): Boolean = false + + @ExportMessage fun asByte(): Byte { throw UnsupportedMessageException.create() } + @ExportMessage fun asShort(): Short { throw UnsupportedMessageException.create() } + @ExportMessage fun asInt(): Int { throw UnsupportedMessageException.create() } + @ExportMessage fun asFloat(): Float { throw UnsupportedMessageException.create() } + @ExportMessage fun asDouble(): Double { throw UnsupportedMessageException.create() } + fun unbox(): Long = x.toLong() companion object { @JvmStatic fun box(x: Long): StgWord = StgWord(x.toULong()) } } + @CompilerDirectives.ValueType -data class StgDouble(@JvmField val x: Double) {} +data class StgDouble(@JvmField val x: Double) -// for now an Addr# is an offset into an array -// TODO: make arr private & add more methods to avoid errors -class StgAddr( - @JvmField val arr: ByteArray, - @JvmField val offset: Int -) { - operator fun get(ix: Int): Byte = arr[offset + ix] - operator fun set(ix: StgInt, y: Byte) { arr[offset + ix.x.toInt()] = y } +val unsafe: Unsafe = Unsafe.getUnsafe() - fun asArray(): ByteArray = arr.copyOfRange(offset, arr.size) - fun asBuffer(): ByteBuffer = ByteBuffer.wrap(arr, offset, arr.size - offset) +val directBufferCtor: Constructor<*> = run { + val x = Class.forName("java.nio.DirectByteBuffer").getDeclaredConstructor(java.lang.Long.TYPE, java.lang.Integer.TYPE) + x.isAccessible = true + x +} +fun newDirectByteBuffer(addr: Long, cap: Int): ByteBuffer { + return (directBufferCtor.newInstance(addr, cap) as ByteBuffer).order(ByteOrder.LITTLE_ENDIAN) } +sealed class StgAddr { + abstract fun getRange(x: IntRange): ByteArray + abstract fun write(offset: Int, data: ByteArray) + + // x is always in bytes + abstract fun readByte(x: Int): Byte + abstract fun readInt(x: Int): Int + abstract fun readLong(x: Int): Long + abstract fun writeByte(x: Int, y: Byte) + abstract fun writeInt(x: Int, y: Int) + abstract fun writeLong(x: Int, y: Long) + + abstract fun addOffset(x: Int): StgAddr + + abstract fun asCString(): String + + companion object { + fun fromArray(x: ByteArray) = StgArrayOffsetAddr(x, 0) + val nullAddr: StgAddr = StgFFIAddr(0L) + } + + // for now an Addr# is an offset into an array + // TODO: make arr private & add more methods to avoid errors + class StgArrayOffsetAddr( + @JvmField val arr: ByteArray, + @JvmField val offset: Int + ): StgAddr() { + override fun equals(other: Any?): Boolean = other is StgArrayOffsetAddr && arr === other.arr && offset == other.offset + + override fun getRange(x: IntRange): ByteArray = arr.copyOfRange(offset + x.first, offset + x.last + 1) + override fun write(offset: Int, data: ByteArray) = arr.write(offset, data) + override fun readByte(x: Int): Byte = arr[offset + x] + override fun readInt(x: Int): Int = ByteBuffer.wrap(arr).getInt(offset + x) + override fun readLong(x: Int): Long = ByteBuffer.wrap(arr).getLong(offset + x) + override fun writeByte(x: Int, y: Byte) { arr[offset + x] = y } + override fun writeInt(x: Int, y: Int) = arr.write(offset + x, y) + override fun writeLong(x: Int, y: Long) = arr.write(offset + x, y) + + override fun addOffset(x: Int): StgAddr = StgArrayOffsetAddr(arr, offset + x) + + override fun asCString(): String = asArray().asCString() + + fun asArray() = arr.copyOfRange(offset, arr.size) + + operator fun get(ix: Int): Byte = arr[offset + ix] + operator fun set(ix: StgInt, y: Byte) { arr[offset + ix.x.toInt()] = y } + + } + + @ExportLibrary(InteropLibrary::class) + class StgFFIAddr(@JvmField val addr: Long): StgAddr(), TruffleObject { + override fun equals(other: Any?): Boolean = other is StgFFIAddr && addr == other.addr + override fun toString(): String = "StgFFIAddr($addr)" + + @ExportMessage fun isPointer(): Boolean = true + @ExportMessage fun asPointer(): Long = addr + + // TODO: unsafe's doc implies it's undefined to use it on addresses not from the jvm, figure out if that is actually true and if so replace it with something else + + override fun getRange(x: IntRange): ByteArray { + val size = x.last - x.first + 1 + val arr = ByteArray(size) + unsafe.copyMemory(null, addr, arr, Unsafe.ARRAY_BYTE_BASE_OFFSET.toLong(), size.toLong()) + return arr + } + override fun write(offset: Int, data: ByteArray) { unsafe.copyMemory(data, (Unsafe.ARRAY_BYTE_BASE_OFFSET + offset).toLong(), null, addr, data.size.toLong()) } + + override fun readByte(x: Int): Byte = unsafe.getByte(addr + x) + override fun readInt(x: Int): Int = unsafe.getInt(addr + x) + override fun readLong(x: Int): Long = unsafe.getLong(addr + x) + + override fun writeByte(x: Int, y: Byte) = unsafe.putByte(addr + x, y) + override fun writeInt(x: Int, y: Int) = unsafe.putInt(addr + x, y) + override fun writeLong(x: Int, y: Long) = unsafe.putLong(addr + x, y) + + override fun addOffset(x: Int): StgAddr = StgFFIAddr(addr + x) + + override fun asCString(): String { + TODO("Not yet implemented") + } + } +} + + // afaict this guy can be mutable (& freeze & unfreeze operate on it)? +// :/ class StgArray( @JvmField val arr: Array ) { @@ -72,12 +194,21 @@ class StgArray( operator fun set(y: StgInt, value: Any) { arr[y.toInt()] = value } } -class StgMutableByteArray( - @JvmField val arr: ByteArray -) { - @JvmField var frozen: Boolean = false +sealed class StgByteArray { + abstract fun asAddr(): StgAddr + + class StgJvmByteArray(@JvmField val arr: ByteArray) : StgByteArray() { + override fun asAddr(): StgAddr = StgAddr.StgArrayOffsetAddr(arr, 0) + } + class StgPinnedByteArray(@JvmField val len: Long) : StgByteArray() { + @JvmField val addr: Long = unsafe.allocateMemory(len) + class StgFFIByteArrayCleanup(private val addr: Long) : Runnable { + override fun run() { unsafe.freeMemory(addr) } + } + init { jdk.internal.ref.Cleaner.create(this, StgFFIByteArrayCleanup(addr)) } - fun asBuffer(): ByteBuffer = ByteBuffer.wrap(arr) + override fun asAddr(): StgAddr = StgAddr.StgFFIAddr(addr) + } } //class StgByteArray( @@ -120,7 +251,7 @@ class WeakRef( // TODO: UnboxedTuple1 etc? @CompilerDirectives.ValueType class UnboxedTuple( - @JvmField val x: Array + @JvmField @CompilerDirectives.CompilationFinal(dimensions = 1) val x: Array ) { override fun toString(): String = "(# " + x.joinToString(", ") { it.toString() } + " #)" } @@ -130,8 +261,8 @@ class ThreadId( @JvmField val id: Long ) - // TODO: store a weakref here class StablePtr( @JvmField var x: Any? ) + diff --git a/src/trufflestg/frame/DataFrame.kt b/src/trufflestg/frame/DataFrame.kt index 6437292..067ed6f 100644 --- a/src/trufflestg/frame/DataFrame.kt +++ b/src/trufflestg/frame/DataFrame.kt @@ -1,10 +1,7 @@ package trufflestg.frame -import com.oracle.truffle.api.frame.FrameSlotTypeException - typealias Slot = Int -// these are all the distinctions the JVM cares about -// TODO: i'm currently only using getValue, and probably won't use any of the others + // TODO: make this an abstract class: casting to superclasses should be faster than casting to interfaces interface DataFrame { fun getValue(slot: Slot): Any? diff --git a/src/trufflestg/frame/frame_assembly.kt b/src/trufflestg/frame/frame_assembly.kt index 9386160..af6ca52 100644 --- a/src/trufflestg/frame/frame_assembly.kt +++ b/src/trufflestg/frame/frame_assembly.kt @@ -45,7 +45,6 @@ fun ClassNode.frameBody(types: Array, superCls: Type = type(Object::c iconst_0 istore_2 - // FIXME: is this wrong? i think it needs to unbox? (only works for object as is?) for (i in types.indices) { aload_0 aload_1 diff --git a/src/trufflestg/jit/StgFCall.kt b/src/trufflestg/jit/StgFCall.kt index 2236ca1..0ba7d2d 100644 --- a/src/trufflestg/jit/StgFCall.kt +++ b/src/trufflestg/jit/StgFCall.kt @@ -1,7 +1,9 @@ +@file:Suppress("JAVA_MODULE_DOES_NOT_EXPORT_PACKAGE") package trufflestg.jit import com.oracle.truffle.api.CompilerDirectives import com.oracle.truffle.api.TruffleStackTrace +import com.oracle.truffle.api.dsl.Specialization import com.oracle.truffle.api.exception.AbstractTruffleException import com.oracle.truffle.api.frame.VirtualFrame import com.oracle.truffle.api.interop.ExceptionType @@ -9,26 +11,147 @@ import com.oracle.truffle.api.interop.InteropLibrary import com.oracle.truffle.api.library.ExportLibrary import com.oracle.truffle.api.library.ExportMessage import com.oracle.truffle.api.nodes.Node +import com.oracle.truffle.api.source.Source +import jdk.internal.vm.annotation.Stable import trufflestg.Language -import trufflestg.array_utils.* +import trufflestg.array_utils.map import trufflestg.data.* import trufflestg.panic +import trufflestg.stg.CborModuleDir import trufflestg.stg.Stg +import java.io.File +import java.nio.file.Files +import java.nio.file.StandardCopyOption import java.security.MessageDigest + +fun getNFIType(x: Any): String = when (x) { + is StgAddr.StgFFIAddr -> "POINTER" + // FIXME: deal with Int8 etc somehow + is StgWord -> "UINT64" + is StgInt -> "SINT64" + else -> TODO("getNFIType $x") +// is TruffleObject -> when { +// InteropLibrary.isNumber(x) -> when {} +// else -> null +// } +// else -> null +} + +fun Stg.PrimRep?.asNFIType(): String = when(this) { + is Stg.PrimRep.AddrRep -> "POINTER" + is Stg.PrimRep.IntRep -> "SINT64" + is Stg.PrimRep.WordRep -> "UINT64" + else -> TODO("$this asNFIType()") +} + +fun dlopen(str: String): Any { +// println("dlopen $str") + val ctx = Language.currentContext().env + val src = Source.newBuilder("nfi", "load (RTLD_GLOBAL) \"${str}\"", "(haskell ffi call)").build() + return ctx.parseInternal(src).call() +} + +@CompilerDirectives.CompilationFinal var rtsLoaded: Boolean = false +fun loadRts() { + if (!rtsLoaded) { + CompilerDirectives.transferToInterpreterAndInvalidate() + rtsLoaded = true + + val libName = System.mapLibraryName("rts") + + val libFile = File.createTempFile("trufflestg-rts",libName) + libFile.deleteOnExit() + + println("Module: ${StgFCall::class.java.module}") + + val s = StgFCall::class.java.getResourceAsStream("/$libName") + if (s === null) { + throw Exception("Can't find $libName: maybe trufflestg doesn't support ffi calls on your OS?") + } + Files.copy(s, libFile.toPath(), StandardCopyOption.REPLACE_EXISTING) + + println(libFile.path) + dlopen(libFile.path) + dlopen("/home/zcarterc/Sync/Code/cadenza/src/c/librts.so") + } +} + + + +val realFFICalls: MutableSet = mutableSetOf() + class StgFCall( + val type: Stg.Type, val x: Stg.ForeignCall, @field:Children val args: Array ) : Code(null) { val name: String = (x.ctarget as Stg.CCallTarget.StaticTarget).string - @field:Child var opNode: StgPrimOp? = primFCalls[name]?.let { it() } + @Child var opNode: StgPrimOp? = primFCalls[name]?.let { it() } + @Child var realFCall: StgRealFCall? = if (opNode === null) StgRealFCall(type, x) else null override fun execute(frame: VirtualFrame): Any { val xs = map(args) { it.execute(frame) } - if (opNode != null) { - return opNode!!.run(frame, xs) + return if (opNode != null) { + opNode!!.run(frame, xs) } else { - panic{"foreign call nyi $x ${xs.contentToString()}"} + realFCall!!.execute(xs) + } + } +} + +class StgRealFCall( + val type: Stg.Type, + val x: Stg.ForeignCall +) : Node() { + val name: String = (x.ctarget as Stg.CCallTarget.StaticTarget).string + private val retType: Stg.PrimRep? = when (type) { + is Stg.Type.UnboxedTuple -> when (type.rep.size) { + 0 -> null + 1 -> type.rep[0] + else -> panic{"foreign call $x $type returning a multiple element unboxed tuple"} + } + is Stg.Type.PolymorphicRep -> panic("foreign call without known return kind") + is Stg.Type.SingleValue -> type.rep + } + + // TODO: should limit be 1? + @Child var interop: InteropLibrary = InteropLibrary.getFactory().createDispatched(1) + + @CompilerDirectives.CompilationFinal var boundFn: Any? = null + + fun execute(xs: Array): Any { + val ys = xs.dropLast(1).toTypedArray() + if (boundFn === null) { + CompilerDirectives.transferToInterpreterAndInvalidate() + + if (name !in realFFICalls) { + realFFICalls.add(name) + println("trufflestg: new real ffi call: $name $x ${xs.contentToString()}") + } + + loadRts() + val lib = (rootNode as ClosureRootNode).module.moduleDir.nativeLib() + + val interopSlow = InteropLibrary.getUncached() + val fn = interopSlow.readMember(lib, name) + + if (xs.last() !is VoidInh) { + panic { "foreign call with non-RealWorld last argument, TODO" } + } + + val tyString = "(${ys.joinToString(",") { getNFIType(it) }}):${retType.asNFIType()}" + boundFn = interopSlow.invokeMember(fn, "bind", tyString) + } + + val r = interop.execute(boundFn, *ys) + return when (retType) { + is Stg.PrimRep.AddrRep -> UnboxedTuple(arrayOf(StgAddr.StgFFIAddr(interop.asPointer(r)))) + is Stg.PrimRep.IntRep -> UnboxedTuple(arrayOf(StgInt(interop.asLong(r)))) + is Stg.PrimRep.WordRep -> UnboxedTuple(arrayOf(StgWord(interop.asLong(r).toULong()))) // TODO: is this right? + else -> { + panic{"foreign call return nyi $x ${xs.contentToString()} $retType"} + } } } } @@ -54,14 +177,40 @@ class TruffleStgExitException(val status: Int) : AbstractTruffleException() { // mutable state var md5: MessageDigest? = null -object GlobalStore { - var GHCConcSignalSignalHandlerStore: Any? = null -} -var globalHeap = ByteArray(0) +val GlobalStore: MutableMap = mutableMapOf() -@OptIn(kotlin.ExperimentalUnsignedTypes::class) -val primFCalls: Map StgPrimOp> = mapOf( +@OptIn(ExperimentalUnsignedTypes::class) +val primFCalls: Map StgPrimOp> = + listOf( + "GHCConcSignalSignalHandlerStore", + "GHCConcWindowsPendingDelaysStore", + "GHCConcWindowsIOManagerThreadStore", + "GHCConcWindowsProddingStore", + "SystemEventThreadEventManagerStore", + "SystemEventThreadIOManagerThreadStore", + "SystemTimerThreadEventManagerStore", + "SystemTimerThreadIOManagerThreadStore", + "LibHSghcFastStringTable", + "LibHSghcPersistentLinkerState", + "LibHSghcInitLinkerDone", + "LibHSghcGlobalDynFlags", + "LibHSghcStaticOptions", + "LibHSghcStaticOptionsReady", + "MaxStoreKey" + ).associate { + ("getOrSet$it") to wrap2 { a: StablePtr, _: VoidInh -> + synchronized(GlobalStore) { + val x = GlobalStore[it] + UnboxedTuple(arrayOf(if (x === null) { + GlobalStore[it] = a + a + } else { + x + })) + } + } + } + mapOf( // TODO: do something safer! (use x to store MessageDigest object?) "__hsbase_MD5Init" to wrap2Boundary { x: StgAddr, y: VoidInh -> if (md5 != null) panic("") @@ -69,63 +218,64 @@ val primFCalls: Map StgPrimOp> = mapOf( y }, "__hsbase_MD5Update" to wrap4 { _: StgAddr, y: StgAddr, z: StgInt, v: VoidInh -> - val c = y.arr.copyOfRange(y.offset, y.offset + z.x.toInt()) + val c = y.getRange(0 until z.x.toInt()) md5!!.update(c) v }, "__hsbase_MD5Final" to wrap3Boundary { out: StgAddr, _: StgAddr, v: VoidInh -> - out.arr.write(out.offset, md5!!.digest()) + out.write(0, md5!!.digest()) md5 = null v }, "errorBelch2" to wrap3Boundary { x: StgAddr, y: StgAddr, v: VoidInh -> // TODO: this is supposed to be printf - System.err.println("errorBelch2: ${x.asArray().asCString()} ${y.asArray().asCString()}") + // TODO: use Language.currentContext().env.{in,out,err} instead of System.blah everywhere + System.err.println("errorBelch2: ${x.asCString()} ${y.asCString()}") v }, "debugBelch2" to wrap3Boundary { x: StgAddr, y: StgAddr, v: VoidInh -> - System.err.println("debugBelch2: ${x.asArray().asCString()} ${y.asArray().asCString()}"); v + System.err.println("debugBelch2: ${x.asCString()} ${y.asCString()}"); v }, "rts_setMainThread" to wrap2 { x: WeakRef, _: VoidInh -> UnboxedTuple(arrayOf()) }, - "getOrSetGHCConcSignalSignalHandlerStore" to wrap2 { a: StablePtr, _: VoidInh -> - synchronized(GlobalStore) { - val x = GlobalStore.GHCConcSignalSignalHandlerStore - UnboxedTuple(arrayOf(if (x == null) { - GlobalStore.GHCConcSignalSignalHandlerStore = a - a - } else { x })) - } - }, + // TODO "isatty" to wrap2 { x: StgInt, _: VoidInh -> UnboxedTuple(arrayOf(StgInt(if (x.x in 0..2) 1L else 0L))) }, - "fdReady" to { object : StgPrimOp(5) { - // FIXME: implement this, might need to use JNI - override fun run(frame: VirtualFrame, args: Array): Any { - val fd = (args[0] as StgInt).x - if (fd == 1L || fd == 0L) return UnboxedTuple(arrayOf(StgInt(1L))) - panic("todo: fdReady ${args[0]}") - } - } }, +// "fdReady" to { object : StgPrimOp(5) { +// // FIXME: implement this, might need to use JNI +// override fun run(frame: VirtualFrame, args: Array): Any { +// val fd = (args[0] as StgInt).x +// if (fd == 1L || fd == 0L) return UnboxedTuple(arrayOf(StgInt(1L))) +// panic("todo: fdReady ${args[0]}") +// } +// } }, "rintDouble" to wrap2 { x: StgDouble, _: VoidInh -> UnboxedTuple(arrayOf(StgDouble(Math.rint(x.x)))) }, "rtsSupportsBoundThreads" to wrap1 { _: VoidInh -> UnboxedTuple(arrayOf(StgInt(0L))) }, - "ghczuwrapperZC20ZCbaseZCSystemziPosixziInternalsZCwrite" to wrap4Boundary { x: StgInt, y: StgAddr, z: StgWord, _: VoidInh -> - // stdout - if (x.x == 1L) { - val s = y.asArray().copyOfRange(0, z.x.toInt()) - print(String(s)) - UnboxedTuple(arrayOf(StgInt(z.x.toLong()))) - } else { - panic("nyi ghczuwrapperZC20ZCbaseZCSystemziPosixziInternalsZCwrite") - } - }, - "ghczuwrapperZC22ZCbaseZCSystemziPosixziInternalsZCread" to wrap4Boundary { x: StgInt, y: StgAddr, z: StgWord, _: VoidInh -> - UnboxedTuple(arrayOf(StgInt(posix.read(x.x.toInt(), y.asBuffer(), z.x.toLong())))) - }, + "initGCStatistics" to wrap1 { v: VoidInh -> v }, + +// "ghczuwrapperZC20ZCbaseZCSystemziPosixziInternalsZCwrite" to wrap4Boundary { x: StgInt, y: StgAddr, z: StgWord, _: VoidInh -> +// // stdout +// if (x.x == 1L) { +// val s = y.getRange(0 until z.x.toInt()) +// print(String(s)) +// UnboxedTuple(arrayOf(StgInt(z.x.toLong()))) +// } else { +// panic("nyi ghczuwrapperZC20ZCbaseZCSystemziPosixziInternalsZCwrite") +// } +// }, +// "ghczuwrapperZC22ZCbaseZCSystemziPosixziInternalsZCread" to wrap4Boundary { x: StgInt, y: StgAddr, z: StgWord, _: VoidInh -> +// UnboxedTuple(arrayOf(StgInt(posix.read(x.x.toInt(), y.asBuffer(), z.x.toLong())))) +// }, + + // TODO + "lockFile" to { object : StgPrimOp(5) { + override fun run(frame: VirtualFrame, args: Array): Any = UnboxedTuple(arrayOf(StgInt(0L))) + } }, + "unlockFile" to wrap2 { x: StgInt, _: VoidInh -> UnboxedTuple(arrayOf(StgInt(0L))) }, "shutdownHaskellAndExit" to wrap3Boundary { x: StgInt, _: StgInt, _: VoidInh -> throw TruffleStgExitException(x.x.toInt()) }, @@ -135,9 +285,9 @@ val primFCalls: Map StgPrimOp> = mapOf( UnboxedTuple(arrayOf()) }, "localeEncoding" to wrap1Boundary { _: VoidInh -> - UnboxedTuple(arrayOf(StgAddr("UTF-8".toByteArray() + zeroBytes, 0))) + UnboxedTuple(arrayOf(StgAddr.fromArray("UTF-8".toByteArray() + zeroBytes))) }, - "stg_sig_install" to wrap4 { x: StgInt, y: StgInt, z: NullAddr, _: VoidInh -> + "stg_sig_install" to wrap4 { x: StgInt, y: StgInt, z: StgAddr, _: VoidInh -> // TODO UnboxedTuple(arrayOf(StgInt(-1))) }, @@ -147,17 +297,23 @@ val primFCalls: Map StgPrimOp> = mapOf( *Language.currentContext().env.applicationArguments ) - val ixs = args.map { a -> - val ix = globalHeap.size - globalHeap += (a.toByteArray() + 0x00) - ix + val bs = args.map { it.toByteArray() + 0x00 } + + val n = bs.sumBy { it.size + 8 } + val addr = unsafe.allocateMemory(n.toLong()) + val buf = newDirectByteBuffer(addr, n) + + val ixs = bs.map { a -> + val ix = buf.position() + buf.put(a) + ix + addr } - val argvIx = globalHeap.size - ixs.forEach { globalHeap += it.toLong().toByteArray() } + val argvIx = buf.position() + addr + ixs.forEach { buf.putLong(it) } - argc.arr.write(argc.offset, ixs.size.toByteArray()) - argv.arr.write(argv.offset, argvIx.toLong().toByteArray()) + argc.writeLong(0, ixs.size.toLong()) + argv.writeLong(0, argvIx) UnboxedTuple(arrayOf()) }, "u_towupper" to wrap2 { x: StgInt, _: VoidInh -> diff --git a/src/trufflestg/jit/StgPrim.kt b/src/trufflestg/jit/StgPrim.kt index 9f66b86..3b937f2 100644 --- a/src/trufflestg/jit/StgPrim.kt +++ b/src/trufflestg/jit/StgPrim.kt @@ -12,7 +12,6 @@ import trufflestg.array_utils.write import trufflestg.data.* import trufflestg.panic import trufflestg.todo -import java.nio.ByteBuffer import kotlin.math.absoluteValue import kotlin.math.pow @@ -32,8 +31,9 @@ class StgPrim( val xs = args(frame) if (opNode != null) { if (args.size != opNode!!.arity) panic{"StgPrim $op bad arity"} +// print("${(rootNode as ClosureRootNode).toString()} $op ${xs.contentToString()} ") val x = opNode!!.run(frame, xs) - // println("$op ${xs.contentToString()} = $x") +// println("= $x") return x } panic{"$op nyi"} @@ -79,11 +79,12 @@ val primOps: Map StgPrimOp> = mapOf( "-#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x - y.x) }, "negateInt#" to wrap1 { x: StgInt -> StgInt(-x.x) }, "*#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x * y.x) }, - "<=#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x <= y.x) 1L else 0L) }, ">#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x > y.x) 1L else 0L) }, - "<#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x < y.x) 1L else 0L) }, ">=#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x >= y.x) 1L else 0L) }, "==#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x == y.x) 1L else 0L) }, + "/=#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x != y.x) 1L else 0L) }, + "<#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x < y.x) 1L else 0L) }, + "<=#" to wrap2 { x: StgInt, y: StgInt -> StgInt(if (x.x <= y.x) 1L else 0L) }, "andI#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x and y.x) }, "orI#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x or y.x) }, @@ -106,9 +107,9 @@ val primOps: Map StgPrimOp> = mapOf( ), StgInt(hi), StgInt(lo))) }, - "quotRemInt#" to wrap2 { x: StgInt, y: StgInt -> - UnboxedTuple(arrayOf(StgInt(x.x / y.x), StgInt(x.x % y.x))) - }, + "quotInt#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x / y.x) }, + "remInt#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x % y.x) }, + "quotRemInt#" to wrap2 { x: StgInt, y: StgInt -> UnboxedTuple(arrayOf(StgInt(x.x / y.x), StgInt(x.x % y.x))) }, "newMVar#" to wrap1 { _: VoidInh -> UnboxedTuple(arrayOf(StgMVar(false, null))) }, "putMVar#" to wrap3 { x: StgMVar, y: Any, z: VoidInh -> @@ -127,8 +128,22 @@ val primOps: Map StgPrimOp> = mapOf( "newArray#" to wrap3 { x: StgInt, y: Any, _: VoidInh -> UnboxedTuple(arrayOf(StgArray(Array(x.toInt()) { y }))) }, "readArray#" to wrap3 { x: StgArray, y: StgInt, _: VoidInh -> UnboxedTuple(arrayOf(x[y])) }, "writeArray#" to wrap4 { x: StgArray, y: StgInt, z: Any, w: VoidInh -> x[y] = z; w }, + "sizeofArray#" to wrap1 { x: StgArray -> StgInt(x.arr.size.toLong()) }, + "sizeofMutableArray#" to wrap1 { x: StgArray -> StgInt(x.arr.size.toLong()) }, + "indexArray#" to wrap2 { x: StgArray, y: StgInt -> UnboxedTuple(arrayOf(x[y])) }, + "unsafeFreezeArray#" to wrap2 { arr: StgArray, _: VoidInh -> UnboxedTuple(arrayOf(arr)) }, + + + "eqAddr#" to wrap2 { x: Any, y: Any -> + StgInt(if (when (x) { + is StgAddr -> y is StgAddr && x == y + is StablePtr -> y is StablePtr && x === y + else -> panic{"eqAddr# of $x and $y"} + }) 1L else 0L) + }, - "eqAddr#" to wrap2 { x: Any, y: Any -> StgInt(if (x === y) 1L else 0L) }, + // TODO: is this sufficient? + "reallyUnsafePtrEquality#" to wrap2 { x: Any, y: Any -> StgInt(if (x === y) 1L else 0L) }, "myThreadId#" to wrap1 { _: VoidInh -> UnboxedTuple(arrayOf(ThreadId(Thread.currentThread().id))) }, @@ -156,6 +171,7 @@ val primOps: Map StgPrimOp> = mapOf( "int2Word#" to wrap1 { x: StgInt -> StgWord(x.x.toULong()) }, "word2Int#" to wrap1 { x: StgWord -> StgInt(x.x.toLong()) }, "narrow8Word#" to wrap1 { x: StgWord -> StgWord(x.x.toUByte().toULong()) }, + "narrow32Word#" to wrap1 { x: StgWord -> StgWord(x.x.toUInt().toULong()) }, "narrow32Int#" to wrap1 { x: StgInt -> StgInt(x.x.toInt().toLong()) }, "gtWord#" to wrap2 { x: StgWord, y: StgWord -> StgInt(if (x.x > y.x) 1L else 0L) }, @@ -165,6 +181,8 @@ val primOps: Map StgPrimOp> = mapOf( "ltWord#" to wrap2 { x: StgWord, y: StgWord -> StgInt(if (x.x < y.x) 1L else 0L) }, "leWord#" to wrap2 { x: StgWord, y: StgWord -> StgInt(if (x.x <= y.x) 1L else 0L) }, + "clz#" to wrap1 { x: StgWord -> StgWord(java.lang.Long.numberOfLeadingZeros(x.x.toLong()).toULong()) }, + "plusWord#" to wrap2 { x: StgWord, y: StgWord -> StgWord(x.x + y.x) }, "minusWord#" to wrap2 { x: StgWord, y: StgWord -> StgWord(x.x - y.x) }, @@ -261,7 +279,7 @@ val primOps: Map StgPrimOp> = mapOf( } } // TODO: profile which cons we've seen? - // TODO: do a virtual call to getInfo().tag for big types? + // TODO: do a virtual call to getInfo().tag / getTag() for big types? ty.cons.forEachIndexed { ix, c -> if (c.size != 0) { if (c.tryIs(x) !== null) { return StgInt(ix.toLong()) } @@ -271,40 +289,46 @@ val primOps: Map StgPrimOp> = mapOf( } } }, + // TODO: make all prims & fcalls taking a StgAddr have a cache or such + // should decide if i want to use TruffleLibrary or do it myself + // maybe just use is & as to impl the prims? + // just reads a byte - // TODO: make sure this is right (maybe should be unsigned)? - "indexCharOffAddr#" to wrap2 { x: StgAddr, y: StgInt -> - val off = x.offset + y.x.toInt() - if (off >= x.arr.size) { StgWord(0UL) } - else StgWord(x.arr[off].toUByte().toULong()) }, + "indexCharOffAddr#" to wrap2 { x: StgAddr, y: StgInt -> when (x) { + is StgAddr.StgArrayOffsetAddr -> { + val off = x.offset + y.x.toInt() + if (off >= x.arr.size) { StgWord(0UL) } + else StgWord(x.arr[off].toUByte().toULong()) + } + is StgAddr.StgFFIAddr -> StgWord(x.readByte(y.toInt()).toUByte().toULong()) + } }, "readInt8OffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgInt(x[y.toInt()].toLong()))) }, + UnboxedTuple(arrayOf(StgInt(x.readByte(y.toInt()).toLong()))) }, "readInt32OffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgInt(ByteBuffer.wrap(x.arr).getInt(x.offset+4*y.toInt()).toLong()))) }, + UnboxedTuple(arrayOf(StgInt(x.readInt(4*y.toInt()).toLong()))) }, "readWord8OffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgWord(x[y.toInt()].toULong()))) }, - "readAddrOffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> - val l = ByteBuffer.wrap(x.arr).getLong(x.offset + 8*y.toInt()).toInt() - // FIXME: need to not have globalHeap reallocate... - val w = StgAddr(globalHeap, l) - UnboxedTuple(arrayOf(w)) - }, - "writeWord8OffAddr#" to wrap4 { x: StgAddr, y: StgInt, z: StgWord, v: VoidInh -> x[y] = z.x.toByte(); v }, + UnboxedTuple(arrayOf(StgWord(x.readByte(y.toInt()).toUByte().toULong()))) }, + "readWord32OffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> + UnboxedTuple(arrayOf(StgWord(x.readInt(4*y.toInt()).toUInt().toULong()))) }, + + "writeWord8OffAddr#" to wrap4 { x: StgAddr, y: StgInt, z: StgWord, v: VoidInh -> x.writeByte(y.toInt(), z.x.toByte()); v }, + + "readAddrOffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> UnboxedTuple(arrayOf(StgAddr.StgFFIAddr(x.readLong(8*y.toInt())))) }, "readWideCharOffAddr#" to wrap3 { x: StgAddr, y: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgWord(ByteBuffer.wrap(x.arr).getInt(x.offset+4*y.toInt()).toUInt().toULong()))) }, + UnboxedTuple(arrayOf(StgWord(x.readInt(4*y.toInt()).toUInt().toULong()))) }, "writeWideCharOffAddr#" to wrap4 { x: StgAddr, y: StgInt, z: StgWord, v: VoidInh -> - x.arr.write(x.offset + 4*y.x.toInt(), z.x.toUInt()); v }, + x.writeInt(4*y.toInt(), z.x.toInt()); v }, - "writeWordArray#" to wrap4 { x: StgMutableByteArray, y: StgInt, z: StgWord, v: VoidInh -> - x.arr.write(8*y.toInt(), z.x.toLong()); v }, + "writeWordArray#" to wrap4 { x: StgByteArray, y: StgInt, z: StgWord, v: VoidInh -> + x.asAddr().writeLong(8*y.toInt(), z.x.toLong()); v }, - "indexWordArray#" to wrap2 { x: StgMutableByteArray, y: StgInt -> - StgWord(x.asBuffer().getLong(y.toInt()*8).toULong()) + "indexWordArray#" to wrap2 { x: StgByteArray, y: StgInt -> + StgWord(x.asAddr().readLong(y.toInt()*8).toULong()) }, - "plusAddr#" to wrap2 { x: StgAddr, y: StgInt -> StgAddr(x.arr, x.offset + y.x.toInt()) }, + "plusAddr#" to wrap2 { x: StgAddr, y: StgInt -> x.addOffset(y.toInt()) }, // TODO: make sure these are right "uncheckedIShiftL#" to wrap2 { x: StgInt, y: StgInt -> StgInt(x.x shl y.x.toInt()) }, @@ -341,24 +365,29 @@ val primOps: Map StgPrimOp> = mapOf( }, "newByteArray#" to wrap2 { x: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgMutableByteArray(ByteArray(x.toInt())))) + UnboxedTuple(arrayOf(StgByteArray.StgJvmByteArray(ByteArray(x.toInt())))) }, - // TODO: actually pin & align? only matters if we want to use native code - // FIXME: or for readAddrOffAddr, or if casted to Addr? + // TODO: possibly could avoid pinning here (do it lazily) when the jvm supports pinning itself + // FIXME: also might need to pin for readAddrOffAddr, or if casted to Addr? "newAlignedPinnedByteArray#" to wrap3 { x: StgInt, alignment: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgMutableByteArray(ByteArray(x.toInt())))) }, + UnboxedTuple(arrayOf(StgByteArray.StgPinnedByteArray(x.x))) }, "newPinnedByteArray#" to wrap2 { x: StgInt, _: VoidInh -> - UnboxedTuple(arrayOf(StgMutableByteArray(ByteArray(x.toInt())))) }, - "unsafeFreezeByteArray#" to wrap2 { arr: StgMutableByteArray, _: VoidInh -> - arr.frozen = true - UnboxedTuple(arrayOf(arr)) }, - "byteArrayContents#" to wrap1 { arr: StgMutableByteArray -> StgAddr(arr.arr, 0) }, + UnboxedTuple(arrayOf(StgByteArray.StgPinnedByteArray(x.x))) }, + "unsafeFreezeByteArray#" to wrap2 { arr: StgByteArray, _: VoidInh -> UnboxedTuple(arrayOf(arr)) }, + "byteArrayContents#" to wrap1 { arr: StgByteArray -> when (arr) { + is StgByteArray.StgJvmByteArray -> StgAddr.StgArrayOffsetAddr(arr.arr, 0) + is StgByteArray.StgPinnedByteArray -> StgAddr.StgFFIAddr(arr.addr) + } }, - "sizeofByteArray#" to wrap1 { x: StgMutableByteArray -> StgInt(x.arr.size.toLong()) } + "sizeofByteArray#" to wrap1 { x: StgByteArray -> StgInt(when (x) { + is StgByteArray.StgJvmByteArray -> x.arr.size.toLong() + is StgByteArray.StgPinnedByteArray -> x.len + }) } ) // this makes a class per call site +// TODO: maybe use castExact here? inline fun wrap1(crossinline f: (X) -> Y): () -> StgPrimOp = { object : StgPrimOp1() { override fun execute(x: Any): Any = f((x as? X)!!) } } diff --git a/src/trufflestg/jit/code.kt b/src/trufflestg/jit/code.kt index e1a86c4..3e63f4b 100644 --- a/src/trufflestg/jit/code.kt +++ b/src/trufflestg/jit/code.kt @@ -18,6 +18,7 @@ import com.oracle.truffle.api.profiles.BranchProfile import com.oracle.truffle.api.source.SourceSection import trufflestg.array_utils.map import trufflestg.data.DataTypes +import java.lang.reflect.Modifier // utility @Suppress("NOTHING_TO_INLINE") @@ -153,11 +154,13 @@ abstract class CaseAlts : Node() { else -> panic{"case prim of $ty"} } + val exactType: Boolean = expectedType.modifiers and (Modifier.ABSTRACT or Modifier.INTERFACE) == 0 + init { assert(alts.all { expectedType.isInstance(it) }) } @ExplodeLoop override fun execute(frame: VirtualFrame, x: Any?): Any? { - val x2 = CompilerDirectives.castExact(x, expectedType) + val x2 = if (exactType) CompilerDirectives.castExact(x, expectedType) else expectedType.cast(x) alts.forEachIndexed { ix, y -> if (x2 == y) { profiles[ix].enter() diff --git a/src/trufflestg/jit/dispatch.kt b/src/trufflestg/jit/dispatch.kt index b319e7f..4c97225 100644 --- a/src/trufflestg/jit/dispatch.kt +++ b/src/trufflestg/jit/dispatch.kt @@ -58,13 +58,10 @@ class CallWhnf(@JvmField val argsSize: Int, val tail_call: Boolean): Node() { val f = if (fn is Thunk) { thunkProfile.enter() // TODO: concurrency (blackholes/synchronization) - if (seenThunkValue) { + if (!seenThunkClosure) { val v = fn.value_ if (v === null) { - if (!seenThunkClosure) { - invalidate(); seenThunkClosure = true - reportPolymorphicSpecialize() - } + invalidate(); seenThunkClosure = true; reportPolymorphicSpecialize() val c = fn.expectClosure() fn.clos = null val x = thunkDispatch.run(frame, c, arrayOf()) @@ -74,7 +71,7 @@ class CallWhnf(@JvmField val argsSize: Int, val tail_call: Boolean): Node() { } else { val c = fn.clos if (c === null) { - invalidate(); seenThunkValue = true + if (!seenThunkValue) { invalidate(); seenThunkValue = true } fn.expectValue() } else { fn.clos = null diff --git a/src/trufflestg/language.kt b/src/trufflestg/language.kt index f4082ad..6ae46cb 100644 --- a/src/trufflestg/language.kt +++ b/src/trufflestg/language.kt @@ -82,7 +82,8 @@ private fun toString(value: Any?): String = defaultMimeType = LANGUAGE_MIME_TYPE, characterMimeTypes = [LANGUAGE_MIME_TYPE], contextPolicy = ContextPolicy.SHARED, - fileTypeDetectors = [Language.Detector::class] + fileTypeDetectors = [Language.Detector::class], + dependentLanguages = ["nfi"] ) @ProvidedTags( CallTag::class, StatementTag::class, RootTag::class, RootBodyTag::class, ExpressionTag::class, diff --git a/src/trufflestg/stg/elab.kt b/src/trufflestg/stg/elab.kt index 7aafa8a..ff3c0b0 100644 --- a/src/trufflestg/stg/elab.kt +++ b/src/trufflestg/stg/elab.kt @@ -1,6 +1,7 @@ // convert from ghc's stg rep to a truffle tree a function at a time package trufflestg.stg +import com.oracle.truffle.api.CompilerDirectives import trufflestg.Language import trufflestg.data.* import trufflestg.panic @@ -39,14 +40,14 @@ fun Stg.Lit.compile(): Any = when (this) { is Stg.Lit.LitDouble -> StgDouble(x.x.toDouble() / x.y) // TODO: i guess???? is Stg.Lit.LitFloat -> TODO() is Stg.Lit.LitLabel -> TODO() - is Stg.Lit.LitNullAddr -> NullAddr + is Stg.Lit.LitNullAddr -> StgAddr.nullAddr is Stg.Lit.LitNumber -> when (x) { Stg.LitNumType.LitNumInt -> StgInt(y.toLong()) Stg.LitNumType.LitNumInt64 -> TODO() Stg.LitNumType.LitNumWord -> StgWord(y.toLong().toULong()) Stg.LitNumType.LitNumWord64 -> TODO() } - is Stg.Lit.LitString -> StgAddr(x, 0) + is Stg.Lit.LitString -> StgAddr.fromArray(x) } fun Stg.Expr.compile(ci: CompileInfo, fd: FrameDescriptor, tc: Boolean): Code = when(this) { @@ -110,7 +111,7 @@ fun Stg.Expr.compile(ci: CompileInfo, fd: FrameDescriptor, tc: Boolean): Code = } is Stg.Expr.StgLit -> Code.Lit(x.compile()) is Stg.Expr.StgOpApp -> when (op) { - is Stg.StgOp.StgFCallOp -> StgFCall(op.x, map(args) { it.compile(ci, fd) }) + is Stg.StgOp.StgFCallOp -> StgFCall(type, op.x, map(args) { it.compile(ci, fd) }) is Stg.StgOp.StgPrimCallOp -> StgPrimCall(op.x, map(args) { it.compile(ci, fd) }) is Stg.StgOp.StgPrimOp -> StgPrim(ci.module.tyCons[tn.orElse(null)], op.x, map(args) { it.compile(ci, fd) }) } @@ -273,7 +274,7 @@ sealed class TopLevel( binder: Stg.SBinder, string: ByteArray ) : TopLevel(binder, module) { - val x = StgAddr(string, 0) + val x = StgAddr.fromArray(string) override fun getValue(): Any = x } class DataCon( @@ -324,15 +325,20 @@ class CborModuleDir( loadedModules[name] = m return m } + + @CompilerDirectives.CompilationFinal private var nativeLib: Any? = null + fun nativeLib(): Any { + var x = nativeLib + if (x === null) { + CompilerDirectives.transferToInterpreterAndInvalidate() + val p = path.substringBeforeLast("stg/") + "libs.so" + x = dlopen(p) + nativeLib = x + } + return x + } } -// contents of magic GHC.Prim module -val prims: Map = mapOf( - "void#" to VoidInh, - "realWorld#" to VoidInh, - "coercionToken#" to VoidInh, -// "void#" to -) class Module( val language: Language, @@ -391,7 +397,10 @@ class Module( val n = x.second.name prims[x.second.name] ?: TODO("GHC.Prim.$n not implemented: $x") } else { - moduleDir[x.first.second.x]!![x.second.name]!! + val fn = FullName(x.first.first.x, x.first.second.x, x.second.name) +// if (fn.module == "System.Environment.ExecutablePath") println(fn) + if (fn in overrides) overrides[fn]!! + else moduleDir[x.first.second.x]!![x.second.name]!! } } id in top_bindings -> top_bindings[id]!!.getValue() @@ -405,3 +414,16 @@ class Module( } +// contents of magic GHC.Prim module +val prims: Map = mapOf( + "void#" to VoidInh, + "realWorld#" to VoidInh, + "coercionToken#" to VoidInh, +// "void#" to +) + +val overrides: Map = mapOf( +// FullName("base", "System.Environment.ExecutablePath", "getExecutablePath") to null +) + +