Skip to content

Commit

Permalink
First draft for inlining pallas wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertMensing committed Jan 7, 2025
1 parent edf1c41 commit 828e10a
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 124 deletions.
18 changes: 14 additions & 4 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,7 @@ final class Procedure[G](
val inline: Boolean = false,
val pure: Boolean = false,
val vesuv_entry: Boolean = false,
val pallasWrapper: Boolean = false,
)(val blame: Blame[CallableFailure])(implicit val o: Origin)
extends GlobalDeclaration[G] with AbstractMethod[G] with ProcedureImpl[G]
@scopes[LabelDecl]
Expand Down Expand Up @@ -3530,9 +3531,6 @@ final class LLVMFunctionDefinition[G](
// If this function is a wrapper function for an expression of a
// pallas specification of a function F, then this field references F.
val pallasExprWrapperFor: Option[Ref[G, LLVMFunctionDefinition[G]]],
// Indicates that a new argument has to be added to pass the value for \result
// to the expression-wrapper.
val needsWrapperResultArg: Boolean = false,
)(val blame: Blame[CallableFailure])(implicit val o: Origin)
extends LLVMCallable[G]
with Applicable[G]
Expand Down Expand Up @@ -3673,10 +3671,22 @@ final class LLVMGlobalSpecification[G](val value: String)(
var data: Option[Seq[GlobalDeclaration[G]]] = None
}

// Node that represents the \result-construct in a Pallas contract.
/*
Nodes that represents the \result-construct in a Pallas contract.
- Keep a reference to the function whose result they represent.
- The LLVMResult-Node is the node that is generated in the C++-part
of Pallas. In LangLLVMToCol, it is transformed into an
LLVMIntermediaryResult, which in turn is converted when the Pallas
expression-wrappers are inlined.
*/
final case class LLVMResult[G](func: Ref[G, LLVMFunctionDefinition[G]])(
implicit val o: Origin
) extends LLVMExpr[G] with LLVMResultImpl[G]
final case class LLVMIntermediaryResult[G](
applicable: Ref[G, Procedure[G]],
sretArg: Option[Ref[G, Variable[G]]],
)(implicit val o: Origin)
extends LLVMExpr[G] with LLVMIntermediaryResultImpl[G]

@family
sealed trait LLVMMemoryOrdering[G]
Expand Down
19 changes: 19 additions & 0 deletions src/col/vct/col/ast/lang/llvm/LLVMIntermediaryResultImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package vct.col.ast.lang.llvm

import vct.col.ast.node.NodeFamilyImpl
import vct.col.ast.ops.LLVMIntermediaryResultOps
import vct.col.ast.{LLVMIntermediaryResult, Type}
import vct.col.print.Precedence

trait LLVMIntermediaryResultImpl[G]
extends NodeFamilyImpl[G] with LLVMIntermediaryResultOps[G] {
this: LLVMIntermediaryResult[G] =>

override def t: Type[G] =
sretArg match {
case Some(rArg) => rArg.decl.t
case None => applicable.decl.returnType
}

override def precedence: Int = Precedence.ATOMIC
}
7 changes: 6 additions & 1 deletion src/col/vct/col/ast/lang/llvm/LLVMResultImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import vct.col.print.Precedence

trait LLVMResultImpl[G] extends NodeFamilyImpl[G] with LLVMResultOps[G] {
this: LLVMResult[G] =>
override def t: Type[G] = func.decl.returnType

override def t: Type[G] =
func.decl.returnInParam match {
case Some((_, t)) => t
case None => func.decl.returnType
}

override def precedence: Int = Precedence.ATOMIC
}
1 change: 1 addition & 0 deletions src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,7 @@ abstract class CoercingRewriter[Pre <: Generation]()
case LLVMRawVectorValue(_, _) => e
case LLVMZeroedAggregateValue(_) => e
case LLVMResult(_) => e
case LLVMIntermediaryResult(_, _) => e
case PVLEndpointExpr(_, _) => e
case EndpointExpr(ref, expr) => e
case ChorExpr(expr) => ChorExpr(bool(expr))
Expand Down
4 changes: 0 additions & 4 deletions src/llvm/lib/Passes/Function/FunctionDeclarer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ FDResult FunctionDeclarer::run(Function &F, FunctionAnalysisManager &FAM) {
pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F);
}

llvmFuncDef->set_needs_wrapper_result_arg(false);
if (utils::isPallasExprWrapper(F)) {
auto mapperResult = FAM.getResult<pallas::ExprWrapperMapper>(F);
auto *wrapperParent =
Expand All @@ -121,9 +120,6 @@ FDResult FunctionDeclarer::run(Function &F, FunctionAnalysisManager &FAM) {
auto colParent = FAM.getResult<FunctionDeclarer>(*wrapperParent);
llvmFuncDef->mutable_pallas_expr_wrapper_for()->set_id(
colParent.getFunctionId());
if (mapperResult.getContext() == PallasWrapperContext::FuncContractPost) {
llvmFuncDef->set_needs_wrapper_result_arg(true);
}
}

try {
Expand Down
1 change: 0 additions & 1 deletion src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ void llvm2col::transformPallasSpecResult(llvm::CallInst &callInstruction,

} else {
// Case 2: Result is returned as a sret parameter
// Implement & and add example!
if (llvmSpecFunc->arg_size() != 1 ||
!llvmSpecFunc->getArg(0)->hasStructRetAttr()) {
pallas::ErrorReporter::addError(
Expand Down
3 changes: 3 additions & 0 deletions src/main/vct/main/stages/Transformation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import vct.rewrite.{
ExplicitResourceValues,
GenerateSingleOwnerPermissions,
HeapVariableToRef,
InlinePallasWrappers,
InlineTrivialLets,
LowerLocalHeapVariables,
MonomorphizeClass,
Expand Down Expand Up @@ -353,6 +354,8 @@ case class SilverTransformation(
// Replace leftover SYCL types
ReplaceSYCLTypes,
CFloatIntCoercion,
// Inline pallas-specifications
InlinePallasWrappers,

// BIP transformations
ComputeBipGlue,
Expand Down
113 changes: 113 additions & 0 deletions src/rewrite/vct/rewrite/InlinePallasWrappers.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package vct.rewrite

import vct.col.ast._
import vct.col.origin.{LabelContext, Origin, PreferredName}
import vct.col.ref.Ref
import vct.col.rewrite.{Generation, Rewriter, RewriterBuilder, Rewritten}
import vct.col.util.AstBuildHelpers.assignLocal
import vct.col.util.StatementToExpression
import vct.result.Message
import vct.result.VerificationError.SystemError
import vct.rewrite.InlinePallasWrappers.{
InlineArgAssignOrigin,
WrapperInlineFailed,
}

case object InlinePallasWrappers extends RewriterBuilder {
override def key: String = "inlinePallasWrapper"
override def desc: String =
"Inline calls to wrapper-function in pallas specifications."

val InlineArgAssignOrigin: Origin = Origin(Seq(
PreferredName(Seq("inlineArgAssign")),
LabelContext("Assign of argument during inlining"),
))

private def WrapperInliningOrigin(wrapperDef: Origin, inv: Node[_]): Origin =
Origin(
(LabelContext("inlining of ") +: inv.o.originContents) ++
(LabelContext("definition") +: wrapperDef.originContents)
)

case class WrapperInlineFailed(inv: ProcedureInvocation[_], msg: String = "")
extends SystemError {
override def text: String = {
Message.messagesInContext((
inv.o,
"Inlining of wrapper-function in pallas specification failed. " + msg,
))
}
}
}

case class InlinePallasWrappers[Pre <: Generation]() extends Rewriter[Pre] {

override def dispatch(decl: Declaration[Pre]): Unit = {

// Drop all definitions of wrapper functions, since they will be inlined.
decl match {
case proc: Procedure[Pre] if proc.pallasWrapper =>
case other => rewriteDefault(other)
}
}

override def dispatch(node: Expr[Pre]): Expr[Rewritten[Pre]] = {

node match {
case res: LLVMIntermediaryResult[Pre] =>
implicit val o: Origin = res.o
res.sretArg match {
case Some(Ref(retArg)) => Local[Post](ref = succ(retArg))
case None => Result[Post](applicable = succ(res.applicable.decl))
}
case inv: ProcedureInvocation[Pre] if inv.ref.decl.pallasWrapper =>
// TODO: Implement inlining of pallas wrappers
val wFunc = inv.ref.decl

if (wFunc.body.isEmpty) {
throw WrapperInlineFailed(inv, "Cannot inline function without body")
}
if (wFunc.args.size != inv.args.size) {
throw WrapperInlineFailed(
inv,
"Number of arguments differs between definition and invocation.",
)
}

// Declare variables to substitute the vars from the function definition.
val newArgs = localHeapVariables.scope {
variables.scope {
wFunc.args.map(arg => new Variable[Pre](arg.t)(arg.o))
}
}

val assigns = newArgs.zip(inv.args).map { case (v, e) =>
assignLocal(new Local[Pre](v.ref)(v.o), e)(InlineArgAssignOrigin)
}

/*
TODO: The function body is a scope, hence the assignments are not in the
same scope as the rest of the instructions of the function-body
*/
val bodyWithAssign = Block(assigns ++ wFunc.body)(inv.o)

val inlinedBody = StatementToExpression.toExpression(
this,
(s: String) => WrapperInlineFailed(inv, s),
bodyWithAssign,
None,
)

inlinedBody match {
case Some(e) => e
case None =>
throw WrapperInlineFailed(
inv,
"Wrapper could not be converted to expression.",
)
}
case other => rewriteDefault(other)
}

}
}
Loading

0 comments on commit 828e10a

Please sign in to comment.