From ffffac8c10298af06d4ebd4bedc11d01b9cc4642 Mon Sep 17 00:00:00 2001 From: Robert <27515600+RobertMensing@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:10:17 +0100 Subject: [PATCH] First draft of supporting \result in Pallas specifications --- src/col/vct/col/ast/Node.scala | 11 ++ .../col/ast/lang/llvm/LLVMResultImpl.scala | 13 ++ .../vct/col/typerules/CoercingRewriter.scala | 1 + .../Passes/Function/ExprWrapperMapper.h | 51 ++++++++ .../PallasFunctionContractDeclarerPass.h | 12 -- .../Transform/Instruction/OtherOpTransform.h | 19 +++ src/llvm/include/Util/Constants.h | 3 + src/llvm/include/Util/PallasMD.h | 44 +++++++ .../lib/Passes/Function/ExprWrapperMapper.cpp | 88 ++++++++++++++ .../lib/Passes/Function/FunctionDeclarer.cpp | 20 ++++ .../PallasFunctionContractDeclarerPass.cpp | 19 ++- src/llvm/lib/Passes/Function/PureAssigner.cpp | 3 +- src/llvm/lib/Plugin.cpp | 4 + .../Instruction/OtherOpTransform.cpp | 83 +++++++++++++ src/llvm/lib/Util/PallasMD.cpp | 35 ++++++ .../vct/rewrite/lang/LangLLVMToCol.scala | 111 ++++++++++++++---- .../vct/rewrite/lang/LangSpecificToCol.scala | 2 +- 17 files changed, 470 insertions(+), 49 deletions(-) create mode 100644 src/col/vct/col/ast/lang/llvm/LLVMResultImpl.scala create mode 100644 src/llvm/include/Passes/Function/ExprWrapperMapper.h create mode 100644 src/llvm/include/Util/PallasMD.h create mode 100644 src/llvm/lib/Passes/Function/ExprWrapperMapper.cpp create mode 100644 src/llvm/lib/Util/PallasMD.cpp diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 5dec0d9e09..350060eb0f 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -3524,6 +3524,12 @@ final class LLVMFunctionDefinition[G]( val functionBody: Option[Statement[G]], val contract: LLVMFunctionContract[G], val pure: Boolean = false, + // If this function is a wrapper function for an expression of a + // pallas specification of a function F, then this field references F. + val pallasExprWrapperFor: Option[Ref[G, LLVMFunctionDefinition[G]]], + // Indicates that a new argument has to be added to pass the value for \result + // to the expression-wrapper. + val needsWrapperResultArg: Boolean = false, )(val blame: Blame[CallableFailure])(implicit val o: Origin) extends LLVMCallable[G] with Applicable[G] @@ -3664,6 +3670,11 @@ final class LLVMGlobalSpecification[G](val value: String)( var data: Option[Seq[GlobalDeclaration[G]]] = None } +// Node that represents the \result-construct in a Pallas contract. +final case class LLVMResult[G](func: Ref[G, LLVMFunctionDefinition[G]])( + implicit val o: Origin +) extends LLVMExpr[G] with LLVMResultImpl[G] + @family sealed trait LLVMMemoryOrdering[G] extends NodeFamily[G] with LLVMMemoryOrderingImpl[G] diff --git a/src/col/vct/col/ast/lang/llvm/LLVMResultImpl.scala b/src/col/vct/col/ast/lang/llvm/LLVMResultImpl.scala new file mode 100644 index 0000000000..734be782d3 --- /dev/null +++ b/src/col/vct/col/ast/lang/llvm/LLVMResultImpl.scala @@ -0,0 +1,13 @@ +package vct.col.ast.lang.llvm + +import vct.col.ast.node.NodeFamilyImpl +import vct.col.ast.ops.LLVMResultOps +import vct.col.ast.{LLVMResult, Type} +import vct.col.print.Precedence + +trait LLVMResultImpl[G] extends NodeFamilyImpl[G] with LLVMResultOps[G] { + this: LLVMResult[G] => + override def t: Type[G] = func.decl.returnType + + override def precedence: Int = Precedence.ATOMIC +} diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index bd2749d87c..82bc394ad9 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -2147,6 +2147,7 @@ abstract class CoercingRewriter[Pre <: Generation]() case LLVMVectorValue(_, _) => e case LLVMRawVectorValue(_, _) => e case LLVMZeroedAggregateValue(_) => e + case LLVMResult(_) => e case PVLEndpointExpr(_, _) => e case EndpointExpr(ref, expr) => e case ChorExpr(expr) => ChorExpr(bool(expr)) diff --git a/src/llvm/include/Passes/Function/ExprWrapperMapper.h b/src/llvm/include/Passes/Function/ExprWrapperMapper.h new file mode 100644 index 0000000000..3f2acc9638 --- /dev/null +++ b/src/llvm/include/Passes/Function/ExprWrapperMapper.h @@ -0,0 +1,51 @@ +#ifndef PALLAS_EXPRWRAPPERMAPPER_H +#define PALLAS_EXPRWRAPPERMAPPER_H + +#include "vct/col/ast/col.pb.h" +#include +#include +#include + +/** + * Analysis-pass that maps functions that represent expression wrappers in a + * Pallas specification to the function to whose specification they belong. + */ +namespace pallas { + +enum PallasWrapperContext { + FuncContractPre, + FuncContractPost +}; + +class EWMResult { + private: + llvm::Function *parentFunc; + std::optional context; + + public: + explicit EWMResult(llvm::Function *parentFunc, + std::optional ctx); + + llvm::Function *getParentFunc(); + + std::optional getContext(); +}; + +class ExprWrapperMapper : public llvm::AnalysisInfoMixin { + friend llvm::AnalysisInfoMixin; + static llvm::AnalysisKey Key; + + public: + using Result = EWMResult; + + /** + * Maps functions that represent a Pallas expression wrapper to the function + * to whose specification they belong to. + * If a function does not belong to the contract of any function, + * the result contains a nullpointer. + */ + Result run(llvm::Function &F, llvm::FunctionAnalysisManager &FAM); +}; + +} // namespace pallas +#endif // PALLAS_EXPRWRAPPERMAPPER_H diff --git a/src/llvm/include/Passes/Function/PallasFunctionContractDeclarerPass.h b/src/llvm/include/Passes/Function/PallasFunctionContractDeclarerPass.h index c823f32fce..7579a8a88b 100644 --- a/src/llvm/include/Passes/Function/PallasFunctionContractDeclarerPass.h +++ b/src/llvm/include/Passes/Function/PallasFunctionContractDeclarerPass.h @@ -108,18 +108,6 @@ class PallasFunctionContractDeclarerPass */ bool hasConflictingContract(Function &f); - /** - * Checks if the given function has a metadata-node that is labeled as a - * Pallas function contract. - */ - bool hasPallasContract(const Function &f); - - /** - * Checks if the given function has a metadata-node that is labeled as a - * VCLLVM contract. - */ - bool hasVcllvmContract(const Function &f); - /** * Checks if the given metadata-node is a wellformed encoding of a * pallas source-location. diff --git a/src/llvm/include/Transform/Instruction/OtherOpTransform.h b/src/llvm/include/Transform/Instruction/OtherOpTransform.h index 6fdfade2b2..1cc7742c2b 100644 --- a/src/llvm/include/Transform/Instruction/OtherOpTransform.h +++ b/src/llvm/include/Transform/Instruction/OtherOpTransform.h @@ -38,6 +38,25 @@ void transformFCmp(llvm::FCmpInst &fcmpInstruction, col::Block &colBlock, pallas::FunctionCursor &funcCursor); bool checkCallSupport(llvm::CallInst &callInstruction); + +/** + * Transforms a call to a function form the Pallas specification library to the + * appropriate specification construct. + */ +void transformPallasSpecLibCall(llvm::CallInst &callInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + +/** + * Transform the given call-instruction to the result-function of the pallas + * specification library. + * Assumes that the provided function-call is indeed a call to a result-function + * of the pallas specification library. + */ +void transformPallasSpecResult(llvm::CallInst &callInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor); + } // namespace llvm2col #endif // PALLAS_OTHEROPTRANSFORM_H diff --git a/src/llvm/include/Util/Constants.h b/src/llvm/include/Util/Constants.h index 4afe72d4f9..5cc36663a1 100644 --- a/src/llvm/include/Util/Constants.h +++ b/src/llvm/include/Util/Constants.h @@ -13,6 +13,9 @@ const std::string PALLAS_ENSURES = "pallas.ensures"; const std::string PALLAS_WRAPPER_FUNC = "pallas.exprWrapper"; const std::string PALLAS_SRC_LOC_ID = "pallas.srcLoc"; +const std::string PALLAS_SPEC_LIB_MARKER = "pallas.specLib"; +const std::string PALLAS_SPEC_RESULT = "pallas.result"; + // Legacy VCLLVM constants const std::string VC_PREFIX = "VC."; diff --git a/src/llvm/include/Util/PallasMD.h b/src/llvm/include/Util/PallasMD.h new file mode 100644 index 0000000000..580ed43652 --- /dev/null +++ b/src/llvm/include/Util/PallasMD.h @@ -0,0 +1,44 @@ +#ifndef PALLAS_MD_H +#define PALLAS_MD_H + +#include +#include +#include + +/** + * Utils for working with the metadata-node of pallas specifications. + */ +namespace pallas::utils { + +/** + * Checks if the given function is labeled as a function from the pallas + * specification-library. + * If so, it returns an optinal that contains the string-identifier of the kind + * of spec-livb function. + * If it is not a function from the specification library, an empty optional is + * returned. + * @param f The function to check + */ +std::optional isPallasSpecLib(const llvm::Function &f); + +/** + * Checks if the given function has a metadata-node that is labeled as a + * Pallas function contract. + */ +bool hasPallasContract(const llvm::Function &f); + +/** + * Checks if the given function has a metadata-node that is labeled as a + * VCLLVM contract. + */ +bool hasVcllvmContract(const llvm::Function &f); + +/** + * Checks if the given llvm function is marked as an expression wrapper of a + * pallas specification. + */ +bool isPallasExprWrapper(const llvm::Function &f); + +} // namespace pallas::utils + +#endif // PALLAS_MD_H diff --git a/src/llvm/lib/Passes/Function/ExprWrapperMapper.cpp b/src/llvm/lib/Passes/Function/ExprWrapperMapper.cpp new file mode 100644 index 0000000000..5860d19186 --- /dev/null +++ b/src/llvm/lib/Passes/Function/ExprWrapperMapper.cpp @@ -0,0 +1,88 @@ +#include "Passes/Function/ExprWrapperMapper.h" +#include "Passes/Function/FunctionDeclarer.h" + +#include "Origin/OriginProvider.h" +#include "Util/Constants.h" +#include "Util/Exceptions.h" +#include "Util/PallasMD.h" + +#include + +namespace pallas { +const std::string SOURCE_LOC = "Passes::Function::ExprWrapperMapper"; + +using namespace llvm; +namespace col = vct::col::ast; + +/* + * EWMResult + */ + +EWMResult::EWMResult(llvm::Function *parentFunc, + std::optional ctx) + : parentFunc(parentFunc), context(ctx) {} + +llvm::Function *EWMResult::getParentFunc() { return parentFunc; } + +std::optional EWMResult::getContext() { return context; } + +/* + * ExpressionWrapperMapper + */ + +AnalysisKey ExprWrapperMapper::Key; + +ExprWrapperMapper::Result ExprWrapperMapper::run(Function &F, + FunctionAnalysisManager &FAM) { + + auto *llvmModule = F.getParent(); + + // Check all functions in the current module + for (Function &parentF : llvmModule->functions()) { + // Check if the function has a pallas-contract + if (!utils::hasPallasContract(parentF)) { + continue; + } + auto *contract = parentF.getMetadata(constants::PALLAS_FUNC_CONTRACT); + + // Look at all of the clauses and check if they reference the + // wrapper-function + auto numOps = contract->getNumOperands(); + unsigned int clauseIdx = 2; + for (clauseIdx = 2; clauseIdx < numOps; ++clauseIdx) { + // Try to get the third operand as a function + auto *clause = + dyn_cast(contract->getOperand(clauseIdx).get()); + if (clause == nullptr || clause->getNumOperands() < 3) + continue; + auto *clauseWrapperMD = + dyn_cast(clause->getOperand(2).get()); + if (clauseWrapperMD == nullptr) + continue; + auto *clauseWrapper = + dyn_cast_if_present(clauseWrapperMD->getValue()); + if (clauseWrapper == nullptr) + continue; + // Check if the wrapper-function in the clause is the function that + // we are looking for. + if (clauseWrapper == &F) { + // Determine the context in which the wrapper is used. + std::optional ctx = std::nullopt; + if (auto *fClauseTMD = + dyn_cast(clause->getOperand(0).get())) { + auto clauseTStr = fClauseTMD->getString().str(); + if (clauseTStr == pallas::constants::PALLAS_REQUIRES) { + ctx = PallasWrapperContext::FuncContractPre; + } else if (clauseTStr == + pallas::constants::PALLAS_ENSURES) { + ctx = PallasWrapperContext::FuncContractPost; + } + } + return EWMResult(&parentF, ctx); + } + } + } + return EWMResult(nullptr, std::nullopt); +} + +} // namespace pallas diff --git a/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp b/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp index b288834db7..ce7c8baffb 100644 --- a/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp +++ b/src/llvm/lib/Passes/Function/FunctionDeclarer.cpp @@ -1,9 +1,11 @@ #include "Passes/Function/FunctionDeclarer.h" +#include "Passes/Function/ExprWrapperMapper.h" #include "Origin/OriginProvider.h" #include "Passes/Module/RootContainer.h" #include "Transform/Transform.h" #include "Util/Exceptions.h" +#include "Util/PallasMD.h" namespace pallas { const std::string SOURCE_LOC = "Passes::Function::FunctionDeclarer"; @@ -110,6 +112,20 @@ FDResult FunctionDeclarer::run(Function &F, FunctionAnalysisManager &FAM) { pallas::ErrorReporter::addError(SOURCE_LOC, errorStream.str(), F); } + llvmFuncDef->set_needs_wrapper_result_arg(false); + if (utils::isPallasExprWrapper(F)) { + auto mapperResult = FAM.getResult(F); + auto *wrapperParent = + mapperResult.getParentFunc(); + + auto colParent = FAM.getResult(*wrapperParent); + llvmFuncDef->mutable_pallas_expr_wrapper_for()->set_id( + colParent.getFunctionId()); + if (mapperResult.getContext() == PallasWrapperContext::FuncContractPost) { + llvmFuncDef->set_needs_wrapper_result_arg(true); + } + } + if (F.isDeclaration()) { // Defined outside of this module so we don't know if it's pure or what // its contract is @@ -130,6 +146,10 @@ FDResult FunctionDeclarer::run(Function &F, FunctionAnalysisManager &FAM) { */ PreservedAnalyses FunctionDeclarerPass::run(Function &F, FunctionAnalysisManager &FAM) { + + // TODO: Check if the function is part of the spec-lib library. + // If so, skip it. + FDResult result = FAM.getResult(F); // Just makes sure we analyse every function return PreservedAnalyses::all(); diff --git a/src/llvm/lib/Passes/Function/PallasFunctionContractDeclarerPass.cpp b/src/llvm/lib/Passes/Function/PallasFunctionContractDeclarerPass.cpp index d822191dc2..f2ec9baaac 100644 --- a/src/llvm/lib/Passes/Function/PallasFunctionContractDeclarerPass.cpp +++ b/src/llvm/lib/Passes/Function/PallasFunctionContractDeclarerPass.cpp @@ -5,6 +5,7 @@ #include "Passes/Function/FunctionDeclarer.h" #include "Util/Constants.h" #include "Util/Exceptions.h" +#include "Util/PallasMD.h" #include #include @@ -30,9 +31,10 @@ PallasFunctionContractDeclarerPass::run(Function &f, return PreservedAnalyses::all(); // Skip, if f has a non-empty vcllvm-contract, or no contract at all // If it does not have a contract, we need an empty VCLLVM contract instead - // of an empty Pallas contract. Otherwise the mechanism for loading - // contracts from a PVL-file does not get invoked. - if (hasVcllvmContract(f) || !hasPallasContract(f)) + // of an empty Pallas contract. Otherwise the mechanism for loading + // contracts from a PVL-file does not get invoked. + if (pallas::utils::hasVcllvmContract(f) || + !pallas::utils::hasPallasContract(f)) return PreservedAnalyses::all(); // Setup a fresh Pallas-contract @@ -357,7 +359,8 @@ void PallasFunctionContractDeclarerPass::extendPredicate( } bool PallasFunctionContractDeclarerPass::hasConflictingContract(Function &f) { - bool conflict = hasPallasContract(f) && hasVcllvmContract(f); + bool conflict = pallas::utils::hasPallasContract(f) && + pallas::utils::hasVcllvmContract(f); if (conflict) { pallas::ErrorReporter::addError( SOURCE_LOC, @@ -366,14 +369,6 @@ bool PallasFunctionContractDeclarerPass::hasConflictingContract(Function &f) { return conflict; } -bool PallasFunctionContractDeclarerPass::hasPallasContract(const Function &f) { - return f.hasMetadata(pallas::constants::PALLAS_FUNC_CONTRACT); -} - -bool PallasFunctionContractDeclarerPass::hasVcllvmContract(const Function &f) { - return f.hasMetadata(pallas::constants::METADATA_CONTRACT_KEYWORD); -} - bool PallasFunctionContractDeclarerPass::isWellformedPallasLocation( const MDNode *mdNode) { diff --git a/src/llvm/lib/Passes/Function/PureAssigner.cpp b/src/llvm/lib/Passes/Function/PureAssigner.cpp index d80a534757..1b9de59143 100644 --- a/src/llvm/lib/Passes/Function/PureAssigner.cpp +++ b/src/llvm/lib/Passes/Function/PureAssigner.cpp @@ -3,6 +3,7 @@ #include "Passes/Function/FunctionDeclarer.h" #include "Util/Constants.h" #include "Util/Exceptions.h" +#include "Util/PallasMD.h" #include @@ -54,7 +55,7 @@ PreservedAnalyses PureAssignerPass::run(Function &F, } // Check if the function is marked as a pallas wrapper-function - if (F.hasMetadata(pallas::constants::PALLAS_WRAPPER_FUNC)) { + if (utils::isPallasExprWrapper(F)) { pureAnnotationCount++; isPure = true; } diff --git a/src/llvm/lib/Plugin.cpp b/src/llvm/lib/Plugin.cpp index 064a98dcfd..0ec330bbc3 100644 --- a/src/llvm/lib/Plugin.cpp +++ b/src/llvm/lib/Plugin.cpp @@ -1,3 +1,5 @@ + +#include "Passes/Function/ExprWrapperMapper.h" #include "Passes/Function/FunctionBodyTransformer.h" #include "Passes/Function/FunctionContractDeclarer.h" #include "Passes/Function/FunctionDeclarer.h" @@ -27,6 +29,8 @@ llvm::PassPluginLibraryInfo getPallasPluginInfo() { [&] { return pallas::FunctionDeclarer(); }); FAM.registerPass( [&] { return pallas::FunctionContractDeclarer(); }); + FAM.registerPass( + [&] { return pallas::ExprWrapperMapper(); }); }); PB.registerPipelineParsingCallback( [](StringRef Name, llvm::ModulePassManager &MPM, diff --git a/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp b/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp index 8490b1598c..b2fd1a1a9e 100644 --- a/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp +++ b/src/llvm/lib/Transform/Instruction/OtherOpTransform.cpp @@ -1,9 +1,16 @@ #include "Transform/Instruction/OtherOpTransform.h" +#include #include +#include +#include +#include +#include "Passes/Function/ExprWrapperMapper.h" #include "Transform/BlockTransform.h" #include "Transform/Transform.h" +#include "Util/Constants.h" #include "Util/Exceptions.h" +#include "Util/PallasMD.h" const std::string SOURCE_LOC = "Transform::Instruction::OtherOp"; @@ -250,6 +257,14 @@ void llvm2col::transformCallExpr(llvm::CallInst &callInstruction, // TODO: Deal with intrinsic functions return; } + + // If it is a call to a function from the pallas specification library, + // we transform it into the appropriate col-node. + if (pallas::utils::isPallasSpecLib(*callInstruction.getCalledFunction())) { + transformPallasSpecLibCall(callInstruction, colBlock, funcCursor); + return; + } + // allocate expression to host the function call in advance col::Expr *functionCallExpr; // if void function add an eval expression @@ -280,3 +295,71 @@ void llvm2col::transformCallExpr(llvm::CallInst &callInstruction, *invocation->add_args()); } } + +void llvm2col::transformPallasSpecLibCall(llvm::CallInst &callInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + auto specLibType = + pallas::utils::isPallasSpecLib(*callInstruction.getCalledFunction()) + .value(); + + if (specLibType == pallas::constants::PALLAS_SPEC_RESULT) { + transformPallasSpecResult(callInstruction, colBlock, funcCursor); + } else { + pallas::ErrorReporter::addError( + SOURCE_LOC, "Unsupported Pallas specification function", + callInstruction); + } +} + +void llvm2col::transformPallasSpecResult(llvm::CallInst &callInstruction, + col::Block &colBlock, + pallas::FunctionCursor &funcCursor) { + auto *llvmSpecFunc = callInstruction.getCalledFunction(); + bool isRegularReturn = !llvmSpecFunc->getReturnType()->isVoidTy(); + + // Get the function to whose contract this call instuction belongs to. + auto *wrapperFunc = callInstruction.getFunction(); + auto *llvmParentFunc = funcCursor.getFunctionAnalysisManager() + .getResult(*wrapperFunc) + .getParentFunc(); + if (llvmParentFunc == nullptr) { + pallas::ErrorReporter::addError( + SOURCE_LOC, + "Encountered call to spec-lib that cannot be associated " + "with a function", + callInstruction); + return; + } + auto &colParentFunc = funcCursor.getFDResult(*llvmParentFunc); + + if (isRegularReturn) { + // Case 1: Result is returned as regular return-value + // %2 = call i32 @pallas.result.0() + + // Check that the function signature is wellformed + if (!llvmSpecFunc->arg_empty()) { + pallas::ErrorReporter::addError( + SOURCE_LOC, "Malformed pallas spec-lib result-function", + callInstruction); + } + + // Build the assignment-expression + col::Assign &assignment = funcCursor.createAssignmentAndDeclaration( + callInstruction, colBlock); + auto *assignExpr = assignment.mutable_value(); + auto *resultNode = assignExpr->mutable_llvm_result(); + resultNode->set_allocated_origin( + llvm2col::generateFunctionCallOrigin(callInstruction)); + // Set ref to the function to which this contract is attached to + resultNode->mutable_func()->set_id(colParentFunc.getFunctionId()); + + } else { + // Case 2: Result is returned as a sret parameter + // Implement & and add example! + pallas::ErrorReporter::addError(SOURCE_LOC, "Unsupported", + callInstruction); + } + + // TODO: Handle cases, where the result is returned in other ways +} \ No newline at end of file diff --git a/src/llvm/lib/Util/PallasMD.cpp b/src/llvm/lib/Util/PallasMD.cpp new file mode 100644 index 0000000000..446f8e8da0 --- /dev/null +++ b/src/llvm/lib/Util/PallasMD.cpp @@ -0,0 +1,35 @@ +#include "Util/PallasMD.h" +#include "Util/Constants.h" + +#include +#include + +namespace pallas::utils { + +std::optional isPallasSpecLib(const llvm::Function &f) { + + auto *mdMarker = f.getMetadata(constants::PALLAS_SPEC_LIB_MARKER); + if (mdMarker == nullptr || mdMarker->getNumOperands() != 1) + return {}; + + auto *mdTypeStr = + llvm::dyn_cast(mdMarker->getOperand(0).get()); + if (mdTypeStr == nullptr) + return {}; + + return mdTypeStr->getString().str(); +} + +bool hasPallasContract(const llvm::Function &f) { + return f.hasMetadata(pallas::constants::PALLAS_FUNC_CONTRACT); +} + +bool hasVcllvmContract(const llvm::Function &f) { + return f.hasMetadata(pallas::constants::METADATA_CONTRACT_KEYWORD); +} + +bool isPallasExprWrapper(const llvm::Function &f) { + return f.hasMetadata(pallas::constants::PALLAS_WRAPPER_FUNC); +} + +} // namespace pallas::utils diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index b73eb7fe41..fd9a06739e 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -1,13 +1,16 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging -import vct.col.ast._ +import hre.util.ScopedStack +import vct.col.ast.{Expr, _} import vct.col.origin.{ AssertFailed, Blame, DiagnosticOrigin, + LabelContext, Origin, PanicBlame, + PreferredName, TypeName, UnreachableReachedError, } @@ -76,6 +79,10 @@ case object LangLLVMToCol { override def blame(error: AssertFailed): Unit = unreachable.blame.blame(UnreachableReachedError(unreachable)) } + + val pallasResArgOrigin: Origin = Origin( + Seq(PreferredName(Seq("resArg")), LabelContext("result arg")) + ) } case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) @@ -108,6 +115,10 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) .ArrayBuffer() private val elidedBackEdges: mutable.Set[LabelDecl[Pre]] = mutable.Set() + // If the LLVM-function that is currently being transformed, + // this contains the argument that is used to pass the value of \result + private val wrapperRetArg: ScopedStack[Option[Variable[Post]]] = ScopedStack() + def gatherBackEdges(program: Program[Pre]): Unit = { program.collect { case loop: LLVMLoop[Pre] => elidedBackEdges.add(loop.header.decl) @@ -362,6 +373,7 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) val newArgs = func.importedArguments.getOrElse(func.args).map { it => it.rewriteDefault() } + val retArg = getPallasSpecRetArg(func) rw.globalDeclarations.declare( new Procedure[Post]( returnType = rw @@ -371,20 +383,22 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) func.args.zip(newArgs).foreach { case (a, b) => rw.variables.succeed(a, b) } - }._1, + }._1 ++ retArg, outArgs = Nil, typeArgs = Nil, body = - func.functionBody match { - case None => None - case Some(functionBody) => - if (func.pure) - Some(GotoEliminator(functionBody match { - case scope: Scope[Pre] => scope; - case other => throw UnexpectedLLVMNode(other) - }).eliminate()) - else - Some(rw.dispatch(functionBody)) + wrapperRetArg.having(retArg) { + func.functionBody match { + case None => None + case Some(functionBody) => + if (func.pure) + Some(GotoEliminator(functionBody match { + case scope: Scope[Pre] => scope; + case other => throw UnexpectedLLVMNode(other) + }).eliminate()) + else + Some(rw.dispatch(functionBody)) + } }, contract = func.contract match { @@ -396,10 +410,25 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) pure = func.pure, )(func.blame) ) + } llvmFunctionMap.update(func, procedure) } + private def getPallasSpecRetArg( + wFunc: LLVMFunctionDefinition[Pre] + ): Option[Variable[Post]] = { + if (!wFunc.needsWrapperResultArg) { None } + wFunc.pallasExprWrapperFor match { + case Some(pFunc) => + Some( + new Variable(pFunc.decl.returnType)(pallasResArgOrigin) + .rewriteDefault() + ) + case None => None + } + } + def rewriteAmbiguousFunctionInvocation( inv: LLVMAmbiguousFunctionInvocation[Pre] ): Invocation[Post] = { @@ -438,19 +467,23 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) inv: LLVMFunctionInvocation[Pre] ): ProcedureInvocation[Post] = { implicit val o: Origin = inv.o + new ProcedureInvocation[Post]( ref = new LazyRef[Post, Procedure[Post]](llvmFunctionMap(inv.ref.decl)), - args = inv.args.zipWithIndex.map { - // TODO: This is really ugly, can we do the type inference in the resolve step and then do coercions to do this? - case (a, i) => - val requiredType = localVariableInferredType - .getOrElse(inv.ref.decl.args(i), inv.ref.decl.args(i).t) - if ( - a.t != requiredType && a.t.asPointer.isDefined && - requiredType.asPointer.isDefined - ) { Cast(a, TypeValue(requiredType)) } - else { a } - }.map(rw.dispatch), + args = extendWithPallasSpecArgs( + inv, + inv.args.zipWithIndex.map { + // TODO: This is really ugly, can we do the type inference in the resolve step and then do coercions to do this? + case (a, i) => + val requiredType = localVariableInferredType + .getOrElse(inv.ref.decl.args(i), inv.ref.decl.args(i).t) + if ( + a.t != requiredType && a.t.asPointer.isDefined && + requiredType.asPointer.isDefined + ) { Cast(a, TypeValue(requiredType)) } + else { a } + }.map(rw.dispatch), + ), givenMap = inv.givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, @@ -462,6 +495,29 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) )(inv.blame) } + /** If the given invocation invokes a wrapper-function for a Pallas + * specification, the list of passed arguments is extended to pass the value + * of \result into the wrapper-function. + */ + private def extendWithPallasSpecArgs( + inv: LLVMFunctionInvocation[Pre], + args: Seq[Expr[Post]], + ): Seq[Expr[Post]] = { + val wFunc = inv.ref.decl + wFunc.pallasExprWrapperFor match { + case Some(pFunc) => + var newArgs = args + if (wFunc.needsWrapperResultArg) { + newArgs = + newArgs :+ new Result[Post](llvmFunctionMap.ref(pFunc.decl))( + pallasResArgOrigin + ) + } + newArgs + case None => args + } + } + def rewriteGlobal(decl: LLVMGlobalSpecification[Pre]): Unit = { implicit val o: Origin = decl.o decl.data.get.foreach { decl => @@ -905,6 +961,15 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) ) } + def rewriteResult(res: LLVMResult[Pre]): Local[Post] = { + implicit val o: Origin = res.o + if (wrapperRetArg.isEmpty || wrapperRetArg.top.isEmpty) { + throw UnexpectedLLVMNode(res) + } + new Local[Post](ref = wrapperRetArg.top.get.ref) + // new Result[Post](applicable = llvmFunctionMap.ref(res.func.decl)) + } + def result(ref: RefLLVMFunctionDefinition[Pre])( implicit o: Origin ): Expr[Post] = Result[Post](llvmFunctionMap.ref(ref.decl)) diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index c1af3d68bd..004cd749fd 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -394,7 +394,7 @@ case class LangSpecificToCol[Pre <: Generation]( case zext: LLVMZeroExtend[Pre] => llvm.rewriteZeroExtend(zext) case trunc: LLVMTruncate[Pre] => llvm.rewriteTruncate(trunc) case fpext: LLVMFloatExtend[Pre] => llvm.rewriteFloatExtend(fpext) - + case result: LLVMResult[Pre] => llvm.rewriteResult(result) case other => rewriteDefault(other) }