Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use continuation monad to avoid nesting #19863

Draft
wants to merge 2 commits into
base: main-2.x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/daml-lf/interpreter/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ da_scala_library(
srcs = glob(["src/main/**/*.scala"]),
scala_deps = [
"@maven//:io_spray_spray_json",
"@maven//:org_typelevel_cats_core",
"@maven//:org_scalaz_scalaz_core",
"@maven//:org_typelevel_paiges_core",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import scala.annotation.nowarn
import scala.collection.immutable.TreeSet
import scala.jdk.CollectionConverters._
import scala.math.Ordering.Implicits.infixOrderingOps
import cats.data.ContT
import cats.Defer

/** Speedy builtins represent LF functional forms. As such, they *always* have a non-zero arity.
*
Expand Down Expand Up @@ -133,6 +135,9 @@ private[speedy] sealed abstract class ScenarioBuiltin(arity: Int)

private[lf] object SBuiltin {

def executeExpressionK(machine: UpdateMachine, expr: SExpr): Cont[SValue] =
Cont.wrap1(executeExpression(machine, expr))

def executeExpression[Q](machine: Machine[Q], expr: SExpr)(
f: SValue => Control[Q]
): Control[Q] = {
Expand Down Expand Up @@ -1130,17 +1135,24 @@ private[lf] object SBuiltin {
): Boolean =
getInterfaceInstance(machine, interfaceId, templateId).nonEmpty

private[this] def ensureTemplateImplementsInterfaceK[Q](
machine: Machine[_],
ifaceId: TypeConName,
coid: V.ContractId,
tplId: TypeConName,
): Cont[Unit] = Cont.wrap0(ensureTemplateImplementsInterface(machine, ifaceId, coid, tplId))

// Precondition: the package of tplId is loaded in the machine
private[this] def ensureTemplateImplementsInterface[Q](
machine: Machine[_],
ifaceId: TypeConName,
coid: V.ContractId,
tplId: TypeConName,
)(k: => Control[Q]): Control[Q] = {
)(k: () => Control[Q]): Control[Q] = {
if (!interfaceInstanceExists(machine, ifaceId, tplId)) {
Control.Error(IE.ContractDoesNotImplementInterface(ifaceId, coid, tplId))
} else {
k
k()
}
}

Expand Down Expand Up @@ -1173,6 +1185,22 @@ private[lf] object SBuiltin {
}
}

private[this] implicit object ControlDeferInstance extends Defer[Control] {
override def defer[A](fa: => Control[A]): Control[A] = fa
}

type Cont[A] = ContT[Control, Question.Update, A]
object Cont {
def pure[A](a: A): Cont[A] = ContT.pure(a)
def wrap0(f: (() => Control[Question.Update]) => Control[Question.Update]): Cont[Unit] =
ContT(k => f(() => k(())))
def wrap1[A](f: (A => Control[Question.Update]) => Control[Question.Update]): Cont[A] = ContT(f)
def wrap2[A, B](
f: ((A, B) => Control[Question.Update]) => Control[Question.Update]
): Cont[(A, B)] = ContT(k => f((a, b) => k((a, b))))
def throwError[A](err: IE): Cont[A] = Cont.wrap1(_ => Control.Error(err))
}

/** Fetches the requested contract ID, casts its to the requested interface, computes its view and returns it (via the
* continuation) as an SAny.
*/
Expand All @@ -1181,17 +1209,15 @@ private[lf] object SBuiltin {
coid: V.ContractId,
ifaceId: TypeConName,
)(k: SAny => Control[Question.Update]): Control[Question.Update] = {
fetchAny(machine, None, coid, SValue.SValue.None) { (_, srcContract) =>
val (tplId, arg) = getSAnyContract(ArrayList.single(srcContract), 0)
ensureTemplateImplementsInterface(machine, ifaceId, coid, tplId) {
viewInterface(machine, ifaceId, tplId, arg) { srcView =>
executeExpression(machine, SEPreventCatch(srcView)) { _ =>
k(SAny(Ast.TTyCon(tplId), arg))
}
}
}
}
}
for {
pkgNameSrcContract <- fetchAnyK(machine, None, coid, SValue.SValue.None)
(_, srcContract) = pkgNameSrcContract
(tplId, arg) = getSAnyContract(ArrayList.single(srcContract), 0)
_ <- ensureTemplateImplementsInterfaceK(machine, ifaceId, coid, tplId)
srcView <- viewInterfaceK(machine, ifaceId, tplId, arg)
_ <- executeExpressionK(machine, SEPreventCatch(srcView))
} yield SAny(Ast.TTyCon(tplId), arg)
}.run(k)

/** Fetches the requested contract ID, upgrades it to the preferred template version for the same package name,
* and compares the computed views according to the old and the new versions. If the two views agree then caches
Expand All @@ -1202,83 +1228,100 @@ private[lf] object SBuiltin {
coid: V.ContractId,
interfaceId: TypeConName,
)(k: SAny => Control[Question.Update]): Control[Question.Update] = {

// Continuation called by two different branches of the expression below. Factorized out to avoid duplication.
def cacheContractAndReturnAny(
machine: UpdateMachine,
coid: V.ContractId,
dstTplId: Ref.ValueRef,
dstArg: SValue,
)(k: SAny => Control[Question.Update]): Control[Question.Update] = {
): Cont[SAny] = for {
// ensure the contract and its metadata are cached
getContractInfo(
_ <- getContractInfoK(
machine,
coid,
dstTplId,
dstArg,
SValue.SValue.None,
) { _ =>
k(SAny(Ast.TTyCon(dstTplId), dstArg))
}
}
)
} yield SAny(Ast.TTyCon(dstTplId), dstArg)

fetchAny(machine, None, coid, SValue.SValue.None) { (maybePkgName, srcContract) =>
maybePkgName match {
case None =>
crash(s"unexpected contract instance without packageName")
case Some(pkgName) =>
val (srcTplId, srcArg) = getSAnyContract(ArrayList.single(srcContract), 0)
ensureTemplateImplementsInterface(machine, interfaceId, coid, srcTplId) {
viewInterface(machine, interfaceId, srcTplId, srcArg) { srcView =>
resolvePackageName(machine, pkgName) { pkgId =>
val dstTplId = srcTplId.copy(packageId = pkgId)
machine.ensurePackageIsLoaded(
dstTplId.packageId,
language.Reference.Template(dstTplId),
) { () =>
ensureTemplateImplementsInterface(machine, interfaceId, coid, dstTplId) {
fromInterface(machine, srcTplId, srcArg, dstTplId) {
case None =>
Control.Error(IE.WronglyTypedContract(coid, dstTplId, srcTplId))
case Some(dstArg) =>
viewInterface(machine, interfaceId, dstTplId, dstArg) { dstView =>
executeExpression(machine, SEPreventCatch(srcView)) { srcViewValue =>
// If the destination and src templates are the same, we skip the computation
// of the destination template's view.
if (dstTplId == srcTplId)
cacheContractAndReturnAny(machine, coid, dstTplId, dstArg)(k)
else
executeExpression(machine, SEPreventCatch(dstView)) { dstViewValue =>
if (srcViewValue != dstViewValue) {
Control.Error(
IE.Dev(
NameOf.qualifiedNameOfCurrentFunc,
IE.Dev.Upgrade(
IE.Dev.Upgrade.ViewMismatch(
coid,
interfaceId,
srcTplId,
dstTplId,
srcView = srcViewValue.toUnnormalizedValue,
dstView = dstViewValue.toUnnormalizedValue,
)
),
{
for {
mbPkgNameSrcContract <- fetchAnyK(machine, None, coid, SValue.SValue.None)
(maybePkgName, srcContract) = mbPkgNameSrcContract
res <- maybePkgName match {
case None =>
crash(s"unexpected contract instance without packageName")
case Some(pkgName) =>
val (srcTplId, srcArg) = getSAnyContract(ArrayList.single(srcContract), 0)
for {
_ <- ensureTemplateImplementsInterfaceK(machine, interfaceId, coid, srcTplId)
srcView <- viewInterfaceK(machine, interfaceId, srcTplId, srcArg)
pkgId <- resolvePackageNameK(machine, pkgName)
dstTplId = srcTplId.copy(packageId = pkgId)
_ <- ensurePackageIsLoadedK(
machine,
dstTplId.packageId,
language.Reference.Template(dstTplId),
)
_ <- ensureTemplateImplementsInterfaceK(machine, interfaceId, coid, dstTplId)
mbDstArg <- fromInterfaceK(machine, srcTplId, srcArg, dstTplId)
res <- mbDstArg match {
case None =>
Cont.throwError(IE.WronglyTypedContract(coid, dstTplId, srcTplId))
case Some(dstArg) =>
for {
dstView <- viewInterfaceK(machine, interfaceId, dstTplId, dstArg)
srcViewValue <- executeExpressionK(machine, SEPreventCatch(srcView))
res <-
if (dstTplId == srcTplId)
cacheContractAndReturnAny(machine, coid, dstTplId, dstArg)
else
for {
dstViewValue <- executeExpressionK(machine, SEPreventCatch(dstView))
res <-
if (srcViewValue != dstViewValue)
Cont.throwError(
IE.Dev(
NameOf.qualifiedNameOfCurrentFunc,
IE.Dev.Upgrade(
IE.Dev.Upgrade.ViewMismatch(
coid,
interfaceId,
srcTplId,
dstTplId,
srcView = srcViewValue.toUnnormalizedValue,
dstView = dstViewValue.toUnnormalizedValue,
)
)
} else
cacheContractAndReturnAny(machine, coid, dstTplId, dstArg)(k)
}
}
}
}
}
}
),
)
)
else
cacheContractAndReturnAny(machine, coid, dstTplId, dstArg)
} yield res
} yield res
}
}
}
}
}
} yield res
}
} yield res
}.run(k)

}

private[this] def ensurePackageIsLoadedK(
machine: UpdateMachine,
packageId: PackageId,
ref: => language.Reference,
): Cont[Unit] =
Cont.wrap0(machine.ensurePackageIsLoaded(packageId, ref))

private[this] def resolvePackageNameK[Q](
machine: UpdateMachine,
pkgName: Ref.PackageName,
): Cont[PackageId] =
Cont.wrap1(resolvePackageName(machine, pkgName))

private[this] def resolvePackageName[Q](machine: UpdateMachine, pkgName: Ref.PackageName)(
k: PackageId => Control[Q]
): Control[Q] = {
Expand Down Expand Up @@ -1449,6 +1492,13 @@ private[lf] object SBuiltin {
}
}

private[this] def fromInterfaceK(
machine: UpdateMachine,
srcTplId: TypeConName,
srcArg: SRecord,
dstTplId: TypeConName,
): Cont[Option[SValue]] = Cont.wrap1(fromInterface(machine, srcTplId, srcArg, dstTplId))

private[this] def fromInterface[Q](
machine: Machine[Q],
srcTplId: TypeConName,
Expand Down Expand Up @@ -1575,6 +1625,13 @@ private[lf] object SBuiltin {
}
}

private[this] def viewInterfaceK[Q](
machine: Machine[_],
ifaceId: TypeConName,
templateId: TypeConName,
record: SValue,
): Cont[SExpr] = Cont.wrap1(viewInterface(machine, ifaceId, templateId, record))

private[this] def viewInterface[Q](
machine: Machine[_],
ifaceId: TypeConName,
Expand Down Expand Up @@ -2368,6 +2425,14 @@ private[lf] object SBuiltin {
}
}

private def fetchAnyK(
machine: UpdateMachine,
optTargetTemplateId: Option[TypeConName],
coid: V.ContractId,
keyOpt: SValue,
): Cont[(Option[Ref.PackageName], SValue)] =
Cont.wrap2(fetchAny(machine, optTargetTemplateId, coid, keyOpt))

// This is the core function which fetches a contract given it's coid.
// Regardless of it being a local, disclosed or global contract
private def fetchAny(
Expand Down Expand Up @@ -2486,6 +2551,15 @@ private[lf] object SBuiltin {
}
}

private def getContractInfoK(
machine: UpdateMachine,
coid: V.ContractId,
templateId: Identifier,
templateArg: SValue,
keyOpt: SValue,
): Cont[ContractInfo] =
Cont.wrap1(getContractInfo(machine, coid, templateId, templateArg, keyOpt))

// Get the contract info for a contract, computing if not in our cache
private def getContractInfo(
machine: UpdateMachine,
Expand Down
Loading