diff --git a/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp b/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp index 0200ba760..f63de5a3b 100644 --- a/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp +++ b/lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp @@ -5,6 +5,7 @@ #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Analysis/Utils.h" +#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h" #include "lib/Dialect/Mgmt/IR/MgmtOps.h" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project @@ -106,6 +107,26 @@ int getDimension(Value value, DataFlowSolver *solver) { return lattice->getValue().getDimension(); } +int getDimensionFromMgmtAttr(Value value) { + Attribute attr; + if (auto blockArg = dyn_cast(value)) { + auto *parentOp = blockArg.getOwner()->getParentOp(); + auto genericOp = dyn_cast(parentOp); + if (genericOp) { + attr = genericOp.getArgAttr(blockArg.getArgNumber(), + mgmt::MgmtDialect::kArgMgmtAttrName); + } + } else { + auto *parentOp = value.getDefiningOp(); + attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName); + } + if (!mlir::isa(attr)) { + assert(false && "MgmtAttr not found"); + } + auto mgmtAttr = mlir::cast(attr); + return mgmtAttr.getDimension(); +} + void annotateDimension(Operation *top, DataFlowSolver *solver) { auto getIntegerAttr = [&](int dimension) { return IntegerAttr::get(IntegerType::get(top->getContext(), 64), dimension); diff --git a/lib/Analysis/DimensionAnalysis/DimensionAnalysis.h b/lib/Analysis/DimensionAnalysis/DimensionAnalysis.h index af45f684d..3644a2819 100644 --- a/lib/Analysis/DimensionAnalysis/DimensionAnalysis.h +++ b/lib/Analysis/DimensionAnalysis/DimensionAnalysis.h @@ -93,6 +93,8 @@ class DimensionAnalysis // initialized DimensionState::DimensionType getDimension(Value value, DataFlowSolver *solver); +DimensionState::DimensionType getDimensionFromMgmtAttr(Value value); + void annotateDimension(Operation *top, DataFlowSolver *solver); } // namespace heir diff --git a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp index 40e83459a..0d324b150 100644 --- a/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp +++ b/lib/Analysis/LevelAnalysis/LevelAnalysis.cpp @@ -5,19 +5,17 @@ #include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" #include "lib/Analysis/Utils.h" +#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h" #include "lib/Dialect/Mgmt/IR/MgmtOps.h" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace heir { @@ -142,5 +140,25 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) { }); } +LevelState::LevelType getLevelFromMgmtAttr(Value value) { + Attribute attr; + if (auto blockArg = dyn_cast(value)) { + auto *parentOp = blockArg.getOwner()->getParentOp(); + auto genericOp = dyn_cast(parentOp); + if (genericOp) { + attr = genericOp.getArgAttr(blockArg.getArgNumber(), + mgmt::MgmtDialect::kArgMgmtAttrName); + } + } else { + auto *parentOp = value.getDefiningOp(); + attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName); + } + if (!mlir::isa(attr)) { + assert(false && "MgmtAttr not found"); + } + auto mgmtAttr = mlir::cast(attr); + return mgmtAttr.getLevel(); +} + } // namespace heir } // namespace mlir diff --git a/lib/Analysis/LevelAnalysis/LevelAnalysis.h b/lib/Analysis/LevelAnalysis/LevelAnalysis.h index 8d1679a54..1fec33fc1 100644 --- a/lib/Analysis/LevelAnalysis/LevelAnalysis.h +++ b/lib/Analysis/LevelAnalysis/LevelAnalysis.h @@ -93,6 +93,8 @@ class LevelAnalysis } }; +LevelState::LevelType getLevelFromMgmtAttr(Value value); + void annotateLevel(Operation *top, DataFlowSolver *solver); } // namespace heir diff --git a/lib/Analysis/NoiseAnalysis/BGV/BUILD b/lib/Analysis/NoiseAnalysis/BGV/BUILD new file mode 100644 index 000000000..a4d23f7b6 --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BGV/BUILD @@ -0,0 +1,40 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "NoiseByBoundCoeffModel", + srcs = [ + "NoiseByBoundCoeffModel.cpp", + "NoiseByBoundCoeffModelAnalysis.cpp", + ], + hdrs = [ + "NoiseByBoundCoeffModel.h", + ], + deps = [ + ":Noise", + "@heir//lib/Analysis/DimensionAnalysis", + "@heir//lib/Analysis/LevelAnalysis", + "@heir//lib/Analysis/NoiseAnalysis", + "@heir//lib/Dialect/Mgmt/IR:Dialect", + "@heir//lib/Dialect/Secret/IR:Dialect", + "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Parameters/BGV:Params", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "Noise", + srcs = ["Noise.cpp"], + hdrs = [ + "Noise.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/lib/Analysis/NoiseAnalysis/BGV/Noise.cpp b/lib/Analysis/NoiseAnalysis/BGV/Noise.cpp new file mode 100644 index 000000000..e527731aa --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BGV/Noise.cpp @@ -0,0 +1,30 @@ +#include "lib/Analysis/NoiseAnalysis/BGV/Noise.h" + +#include + +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace bgv { + +std::string NoiseState::toString() const { + switch (noiseType) { + case (NoiseType::UNINITIALIZED): + return "NoiseState(uninitialized)"; + case (NoiseType::SET): + return "NoiseState(" + std::to_string(log(getValue()) / log(2)) + ") "; + } +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const NoiseState &noise) { + return os << noise.toString(); +} + +Diagnostic &operator<<(Diagnostic &diagnostic, const NoiseState &noise) { + return diagnostic << noise.toString(); +} + +} // namespace bgv +} // namespace heir +} // namespace mlir diff --git a/lib/Analysis/NoiseAnalysis/BGV/Noise.h b/lib/Analysis/NoiseAnalysis/BGV/Noise.h new file mode 100644 index 000000000..2ea01a01a --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BGV/Noise.h @@ -0,0 +1,88 @@ +#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_ +#define INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_ + +#include + +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace bgv { + +// This class could be shared among all noise models that tracks the noise by a +// single value. Noise model could have different interpretation of the value. +// In BGV world, most noise model just use a single value, either as bound or +// as variance. +class NoiseState { + public: + enum NoiseType { + // A min value for the lattice, discarable when joined with anything else. + UNINITIALIZED, + // A known value for the lattice, when noise can be inferred. + SET, + }; + + static NoiseState uninitialized() { + return NoiseState(NoiseType::UNINITIALIZED, std::nullopt); + } + static NoiseState of(double value) { + return NoiseState(NoiseType::SET, value); + } + + /// Create an integer value range lattice value. + /// The default constructor must be equivalent to the "entry state" of the + /// lattice, i.e., an uninitialized noise. + NoiseState(NoiseType noiseType = NoiseType::UNINITIALIZED, + std::optional value = std::nullopt) + : noiseType(noiseType), value(value) {} + + bool isKnown() const { return noiseType == NoiseType::SET; } + + bool isInitialized() const { return noiseType != NoiseType::UNINITIALIZED; } + + const double &getValue() const { + assert(isKnown()); + return *value; + } + + bool operator==(const NoiseState &rhs) const { + return noiseType == rhs.noiseType && value == rhs.value; + } + + static NoiseState join(const NoiseState &lhs, const NoiseState &rhs) { + // Uninitialized noises correspond to values that are not secret, + // which may be the inputs to an encryption operation. + if (lhs.noiseType == NoiseType::UNINITIALIZED) { + return rhs; + } + if (rhs.noiseType == NoiseType::UNINITIALIZED) { + return lhs; + } + + assert(lhs.noiseType == NoiseType::SET && rhs.noiseType == NoiseType::SET); + return NoiseState::of(std::max(lhs.getValue(), rhs.getValue())); + } + + void print(llvm::raw_ostream &os) const { os << value; } + + std::string toString() const; + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const NoiseState &noise); + + friend Diagnostic &operator<<(Diagnostic &diagnostic, + const NoiseState &noise); + + private: + NoiseType noiseType; + // notice that when level becomes large (e.g. 17), the max Q could be like + // 3523 bits and could not be represented in double. + std::optional value; +}; + +} // namespace bgv +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_ diff --git a/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.cpp b/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.cpp new file mode 100644 index 000000000..1907be120 --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.cpp @@ -0,0 +1,268 @@ +#include "lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.h" + +#include +#include + +namespace mlir { +namespace heir { +namespace bgv { + +// the formulae below are mainly taken from KPZ21 +// "Revisiting Homomorphic Encryption Schemes for Finite Fields" +// https://ia.cr/2021/204 + +template +using Model = NoiseByBoundCoeffModel; + +template +double Model::toLogBound(const LocalParamType ¶m, + const StateType &noise) { + auto t = param.getSchemeParam()->getPlaintextModulus(); + // StateType only stores e in (m + t * e), so when we want to print the bound + // we need to multiply t back + return log(t * noise.getValue()) / log(2); +} + +template +double Model::toLogBudget(const LocalParamType ¶m, + const StateType &noise) { + return toLogTotal(param) - toLogBound(param, noise); +} + +template +double Model::toLogTotal(const LocalParamType ¶m) { + double total = 0; + auto logqi = param.getSchemeParam()->getLogqi(); + for (auto i = 0; i <= param.getCurrentLevel(); ++i) { + total += logqi[i]; + } + return total - 1.0; +} + +template +std::string Model::toLogBoundString(const LocalParamType ¶m, + const StateType &noise) { + auto logBound = toLogBound(param, noise); + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << logBound; + return stream.str(); +} + +template +std::string Model::toLogBudgetString(const LocalParamType ¶m, + const StateType &noise) { + auto logBudget = toLogBudget(param, noise); + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << logBudget; + return stream.str(); +} + +template +std::string Model::toLogTotalString(const LocalParamType ¶m) { + auto logTotal = toLogTotal(param); + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << logTotal; + return stream.str(); +} + +template +double Model::getExpansionFactor(const LocalParamType ¶m) { + auto n = param.getSchemeParam()->getRingDim(); + if constexpr (W) { + // worst-case + // well known from DPSZ12 + return n; + } else { + // average-case + // experimental result + // cite HPS19 and KPZ21 + return 2.0 * sqrt(n); + } +} + +template +double Model::getBoundErr(const LocalParamType ¶m) { + auto std0 = param.getSchemeParam()->getStd0(); + // probability of larger than 6 * std0 is less than 2^{-30} + auto assurance = 6; + auto boundErr = std0 * assurance; + return boundErr; +} + +template +double Model::getBoundKey(const LocalParamType ¶m) { + // assume UNIFORM_TERNARY + auto boundKey = 1.0; + return boundKey; +} + +template +typename Model::StateType Model::evalEncryptPk( + const LocalParamType ¶m) { + auto boundErr = getBoundErr(param); + auto boundKey = getBoundKey(param); + auto expansionFactor = getExpansionFactor(param); + + // public key (-as + t * e, a) + // public key encryption (-aus + t(u * e + e_0) + m, au + e_1) + // m + t * (u * e + e_1 * s + e_0) + // v_fresh = u * e + e_1 * s + e_0 + double fresh = boundErr * (1. + 2. * expansionFactor * boundKey); + return StateType::of(fresh); +} + +template +typename Model::StateType Model::evalEncryptSk( + const LocalParamType ¶m) { + auto boundErr = getBoundErr(param); + + // secret key s + // secret key encryption (-as + m + t * e, a) + // v_fresh = e + double fresh = boundErr; + return StateType::of(fresh); +} + +template +typename Model::StateType Model::evalEncrypt( + const LocalParamType ¶m) { + // P stands for public key encryption + if constexpr (P) { + return evalEncryptPk(param); + } else { + return evalEncryptSk(param); + } +} + +template +typename Model::StateType Model::evalConstant( + const LocalParamType ¶m) { + // constant is m + t * 0 + // v_constant = 0 + return StateType::of(0); +} + +template +typename Model::StateType Model::evalAdd(const StateType &lhs, + const StateType &rhs) { + // m_0 + tv_0 + m_1 + tv_1 <= [m_0 + m_1]_t + t(v_0 + v_1 + u) + // v_add = v_0 + v_1 + u + // where ||u|| <= 1 + return StateType::of(lhs.getValue() + rhs.getValue() + 1); +} + +template +typename Model::StateType Model::evalMul( + const LocalParamType &resultParam, const StateType &lhs, + const StateType &rhs) { + auto t = resultParam.getSchemeParam()->getPlaintextModulus(); + auto expansionFactor = getExpansionFactor(resultParam); + + // (m_0 + tv_0) * (m_1 + tv_1) <= + // [m_0 * m_1]_t + t(v_0 * m_1 + v_1 * m_0 + v_0 * v_1 + r_m) + // where m_0 * m_1 = [m_0 * m_1]_t + tr_m + // ||r_m|| <= delta * t / 2, delta is the expansion factor + // v_mul = v_0 * m_1 + v_1 * m_0 + v_0 * v_1 + r_m + // ||v_mul|| <= + // (delta * t / 2) * (2 * ||v_0|| * ||v_1|| + ||v_0|| + ||v_1|| + 1) + return StateType::of((expansionFactor * t / 2) * + (lhs.getValue() * rhs.getValue() * 2 + lhs.getValue() + + rhs.getValue() + 1)); +} + +template +typename Model::StateType Model::evalModReduce( + const LocalParamType &inputParam, const StateType &input) { + auto cv = inputParam.getDimension(); + // for cv > 2 the rounding error term is different! + // like (tau_0, tau_1, tau_2) and the error becomes + // tau_0 + tau_1 s + tau_2 s^2 + assert(cv == 2); + + auto currentLogqi = + inputParam.getSchemeParam()->getLogqi()[inputParam.getCurrentLevel()]; + + double modulus = pow(2.0, currentLogqi); + + auto expansionFactor = getExpansionFactor(inputParam); + auto boundKey = getBoundKey(inputParam); + + // modulus switching is essentially a scaling operation + // so the original error is scaled by the modulus + // ||v_scaled|| = ||v_input|| / modulus + auto scaled = input.getValue() / modulus; + // in the meantime, it will introduce an rounding error + // (tau_0, tau_1) to the (ct_0, ct_1) where ||tau_i|| < t / 2 + // so ||tau_0 + tau_1 * s|| <= t / 2 (1 + delta ||s||) + // ||v_added|| <= (1 + delta * Bkey) / 2 + auto added = (1.0 + expansionFactor * boundKey) / 2; + return StateType::of(scaled + added); +} + +template +typename Model::StateType Model::evalRelinearizeHYBRID( + const LocalParamType &inputParam, const StateType &input) { + // for v_input, after modup and moddown, it remains the same (with rounding). + // We only need to consider the error from key switching key + // and rounding error during moddown. + // Check the B.1.3 section of KPZ21 for more details. + + // also note that for cv > 3 (e.g. multiplication), we need to relinearize + // more terms like ct_3 and ct_4. + // this is a common path for mult relinearize and rotation relinearize + // so no assertion here for now. + + auto dnum = inputParam.getSchemeParam()->getDnum(); + auto expansionFactor = getExpansionFactor(inputParam); + auto boundErr = getBoundErr(inputParam); + auto boundKey = getBoundKey(inputParam); + + auto currentLevel = inputParam.getCurrentLevel(); + // modup from Ql to QlP, so one more digit + auto currentNumDigit = ceil(static_cast(currentLevel + 1) / dnum) + 1; + + // log(qiq_{i+1}...), the digit size for a certain digit + // we use log(pip_{i+1}...) as an approximation, + // as we often choose P > each digit + auto logqi = inputParam.getSchemeParam()->getLogqi(); + auto logDigitSize = std::accumulate(logqi.begin(), logqi.end(), 0); + // omega in literature + auto digitSize = pow(2.0, logDigitSize); + + // the HYBRID key switching error is + // sum over all digit (ct_2 * e_ksk) + // there are "currentNumDigit" digits + // and ||c_2|| <= digitSize / 2 + // ||c_2 * e_ksk|| <= delta * digitSize * Berr / 2 + auto boundKeySwitch = + currentNumDigit * digitSize * expansionFactor * boundErr / 2.0; + + // moddown by P + auto scaled = boundKeySwitch / digitSize; + + // moddown added noise, similar to modreduce above. + auto added = (1.0 + expansionFactor * boundKey) / 2; + + // for relinearization after multiplication, often scaled + added is far less + // than input. + return StateType::of(input.getValue() + scaled + added); +} + +template +typename Model::StateType Model::evalRelinearize( + const LocalParamType &inputParam, const StateType &input) { + // assume HYBRID + // if we further introduce BV to SchemeParam we can have alternative + // implementation. + return evalRelinearizeHYBRID(inputParam, input); +} + +// instantiate template class +template class NoiseByBoundCoeffModel; +template class NoiseByBoundCoeffModel; +template class NoiseByBoundCoeffModel; +template class NoiseByBoundCoeffModel; + +} // namespace bgv +} // namespace heir +} // namespace mlir diff --git a/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.h b/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.h new file mode 100644 index 000000000..fb324b1ac --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.h @@ -0,0 +1,78 @@ +#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISEBYBOUNDCOEFFMODEL_H_ +#define INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISEBYBOUNDCOEFFMODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "lib/Analysis/NoiseAnalysis/BGV/Noise.h" +#include "lib/Parameters/BGV/Params.h" + +namespace mlir { +namespace heir { +namespace bgv { + +// coefficient embedding noise model +// both average-case (from HPS19/KPZ21) and worst-case +// use template here just for the sake of code reuse +// W for worst-case, P for public key +template +class NoiseByBoundCoeffModel { + public: + // for KPZ21, NoiseState stores the bound ||e|| for error e + // instead of t * ||e||. + using StateType = NoiseState; + using SchemeParamType = SchemeParam; + using LocalParamType = LocalParam; + + private: + static double getExpansionFactor(const LocalParamType ¶m); + static double getBoundErr(const LocalParamType ¶m); + static double getBoundKey(const LocalParamType ¶m); + + static StateType evalEncryptPk(const LocalParamType ¶m); + static StateType evalEncryptSk(const LocalParamType ¶m); + static StateType evalRelinearizeHYBRID(const LocalParamType &inputParam, + const StateType &input); + + public: + static StateType evalEncrypt(const LocalParamType ¶m); + static StateType evalConstant(const LocalParamType ¶m); + static StateType evalAdd(const StateType &lhs, const StateType &rhs); + static StateType evalMul(const LocalParamType &resultParam, + const StateType &lhs, const StateType &rhs); + static StateType evalRelinearize(const LocalParamType &inputParam, + const StateType &input); + static StateType evalModReduce(const LocalParamType &inputParam, + const StateType &input); + + // logTotal: log(Ql / 2) + // logBound: bound on ||m + t * e|| predicted by the model + // logBudget: logTotal - logBound + // as ||m + t * e|| < Ql / 2 for correct decryption + static double toLogBound(const LocalParamType ¶m, const StateType &noise); + static std::string toLogBoundString(const LocalParamType ¶m, + const StateType &noise); + static double toLogBudget(const LocalParamType ¶m, + const StateType &noise); + static std::string toLogBudgetString(const LocalParamType ¶m, + const StateType &noise); + static double toLogTotal(const LocalParamType ¶m); + static std::string toLogTotalString(const LocalParamType ¶m); +}; + +// user-facing typedefs +using NoiseByBoundCoeffAverageCasePkModel = NoiseByBoundCoeffModel; +using NoiseByBoundCoeffWorstCasePkModel = NoiseByBoundCoeffModel; +using NoiseByBoundCoeffAverageCaseSkModel = + NoiseByBoundCoeffModel; +using NoiseByBoundCoeffWorstCaseSkModel = NoiseByBoundCoeffModel; + +} // namespace bgv +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISEBYBOUNDCOEFFMODEL_H_ diff --git a/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModelAnalysis.cpp b/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModelAnalysis.cpp new file mode 100644 index 000000000..11a566563 --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModelAnalysis.cpp @@ -0,0 +1,182 @@ +#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h" +#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h" +#include "lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.h" +#include "lib/Analysis/NoiseAnalysis/NoiseAnalysis.h" +#include "lib/Analysis/Utils.h" +#include "lib/Dialect/Mgmt/IR/MgmtOps.h" +#include "lib/Dialect/Secret/IR/SecretOps.h" +#include "lib/Dialect/TensorExt/IR/TensorExtOps.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project + +#define DEBUG_TYPE "NoiseAnalysis" + +namespace mlir { +namespace heir { + +// explicit specialization of NoiseAnalysis for NoiseByBoundCoeffModel +template +void NoiseAnalysis::setToEntryState(LatticeType *lattice) { + // At an entry point, we have no information about the noise. + this->propagateIfChanged(lattice, lattice->join(NoiseState::uninitialized())); +} + +// explicit specialization of NoiseAnalysis for NoiseByBoundCoeffModel +template +void NoiseAnalysis::visitExternalCall( + CallOpInterface call, ArrayRef argumentLattices, + ArrayRef resultLattices) { + auto callback = + std::bind(&NoiseAnalysis::propagateIfChangedWrapper, this, + std::placeholders::_1, std::placeholders::_2); + ::mlir::heir::visitExternalCall( + call, argumentLattices, resultLattices, callback); +} + +// explicit specialization of NoiseAnalysis for NoiseByBoundCoeffModel +template +LogicalResult NoiseAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + auto getLocalParam = [&](Value value) { + auto level = getLevelFromMgmtAttr(value); + auto dimension = getDimensionFromMgmtAttr(value); + return LocalParamType(&schemeParam, level, dimension); + }; + + auto propagate = [&](Value value, NoiseState noise) { + auto localParam = getLocalParam(value); + + LLVM_DEBUG(llvm::dbgs() << "Propagating " + << NoiseModel::toLogBoundString(localParam, noise) + << " to " << value << "\n"); + LatticeType *lattice = this->getLatticeElement(value); + auto changeResult = lattice->join(noise); + this->propagateIfChanged(lattice, changeResult); + }; + + auto getOperandNoises = [&](Operation *op, + SmallVectorImpl &noises) { + SmallVector secretOperands; + SmallVector nonSecretOperands; + this->getSecretOperands(op, secretOperands); + this->getNonSecretOperands(op, nonSecretOperands); + + for (auto *operand : secretOperands) { + noises.push_back(this->getLatticeElement(operand->get())->getValue()); + } + for (auto *operand : nonSecretOperands) { + (void)operand; + // at least one operand is secret + auto localParam = getLocalParam(secretOperands[0]->get()); + noises.push_back(NoiseModel::evalConstant(localParam)); + } + }; + + auto res = + llvm::TypeSwitch(*op) + .Case([&](auto genericOp) { + Block *body = genericOp.getBody(); + for (Value &arg : body->getArguments()) { + auto localParam = getLocalParam(arg); + NoiseState encrypted = NoiseModel::evalEncrypt(localParam); + propagate(arg, encrypted); + } + return success(); + }) + .template Case([&](auto mulOp) { + SmallVector secretResults; + this->getSecretResults(mulOp, secretResults); + if (secretResults.empty()) { + return success(); + } + + SmallVector operandNoises; + getOperandNoises(mulOp, operandNoises); + + auto localParam = getLocalParam(mulOp.getResult()); + NoiseState mult = NoiseModel::evalMul(localParam, operandNoises[0], + operandNoises[1]); + propagate(mulOp.getResult(), mult); + return success(); + }) + .template Case([&](auto addOp) { + SmallVector secretResults; + this->getSecretResults(addOp, secretResults); + if (secretResults.empty()) { + return success(); + } + + SmallVector operandNoises; + getOperandNoises(addOp, operandNoises); + NoiseState add = + NoiseModel::evalAdd(operandNoises[0], operandNoises[1]); + propagate(addOp.getResult(), add); + return success(); + }) + .template Case([&](auto rotateOp) { + // implicitly assumed secret + auto localParam = getLocalParam(rotateOp.getOperand(0)); + + // assume relinearize immediately after rotate + // when we support hoisting relinearize, we need to change + // this + NoiseState rotate = NoiseModel::evalRelinearize( + localParam, operands[0]->getValue()); + propagate(rotateOp.getResult(), rotate); + return success(); + }) + // NOTE: special case for ExtractOp... it is a mulconst+rotate + // if not annotated with slot_extract + // TODO(#1174): decide packing earlier in the pipeline instead of + // annotation + .template Case([&](auto extractOp) { + auto localParam = getLocalParam(extractOp.getOperand(0)); + + // extract = mul_plain 1 + rotate + // although the cleartext is 1, when encoded (i.e. CRT + // packing), the value multiplied to the ciphertext is not 1, + // If we can know the encoded value, we can bound it more precisely. + NoiseState one = NoiseModel::evalConstant(localParam); + NoiseState extract = + NoiseModel::evalMul(localParam, operands[0]->getValue(), one); + // assume relinearize immediately after rotate + // when we support hoisting relinearize, we need to change + // this + NoiseState rotate = + NoiseModel::evalRelinearize(localParam, extract); + propagate(extractOp.getResult(), extract); + return success(); + }) + .template Case([&](auto modReduceOp) { + auto localParam = getLocalParam(modReduceOp.getInput()); + + NoiseState modReduce = + NoiseModel::evalModReduce(localParam, operands[0]->getValue()); + propagate(modReduceOp.getResult(), modReduce); + return success(); + }) + .template Case([&](auto relinearizeOp) { + auto localParam = getLocalParam(relinearizeOp.getInput()); + + NoiseState relinearize = NoiseModel::evalRelinearize( + localParam, operands[0]->getValue()); + propagate(relinearizeOp.getResult(), relinearize); + return success(); + }) + .Default([&](auto &op) { return success(); }); + return res; +} + +// template instantiation +template class NoiseAnalysis; +template class NoiseAnalysis; +template class NoiseAnalysis; +template class NoiseAnalysis; + +} // namespace heir +} // namespace mlir diff --git a/lib/Analysis/NoiseAnalysis/BUILD b/lib/Analysis/NoiseAnalysis/BUILD new file mode 100644 index 000000000..7b11aab6f --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/BUILD @@ -0,0 +1,18 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "NoiseAnalysis", + srcs = [], + hdrs = [ + "NoiseAnalysis.h", + ], + deps = [ + "@heir//lib/Analysis/SecretnessAnalysis", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + ], +) diff --git a/lib/Analysis/NoiseAnalysis/NoiseAnalysis.h b/lib/Analysis/NoiseAnalysis/NoiseAnalysis.h new file mode 100644 index 000000000..937122359 --- /dev/null +++ b/lib/Analysis/NoiseAnalysis/NoiseAnalysis.h @@ -0,0 +1,64 @@ +#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_NOISEANALYSIS_H_ +#define INCLUDE_ANALYSIS_NOISEANALYSIS_NOISEANALYSIS_H_ + +#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h" +#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace heir { + +/// This lattice element represents the noise data of an SSA value. +template +class NoiseLattice : public dataflow::Lattice { + public: + using dataflow::Lattice::Lattice; +}; + +/// This analysis template takes a noise model as argument and computes the +/// noise data for each SSA value in the program. The exact instantiations of +/// member functions are dependent on the noise model hence explicit +/// specialization is required for each noise model. +template +class NoiseAnalysis + : public dataflow::SparseForwardDataFlowAnalysis< + NoiseLattice>, + public SecretnessAnalysisDependent> { + public: + friend class SecretnessAnalysisDependent>; + + using NoiseModel = NoiseModelT; + using NoiseState = typename NoiseModelT::StateType; + using LatticeType = NoiseLattice; + using SchemeParamType = typename NoiseModelT::SchemeParamType; + using LocalParamType = typename NoiseModelT::LocalParamType; + + using dataflow::SparseForwardDataFlowAnalysis< + LatticeType>::SparseForwardDataFlowAnalysis; + + NoiseAnalysis(DataFlowSolver &solver, const SchemeParamType &schemeParam) + : dataflow::SparseForwardDataFlowAnalysis(solver), + schemeParam(schemeParam) {} + + void setToEntryState(LatticeType *lattice) override; + + LogicalResult visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override; + + void visitExternalCall(CallOpInterface call, + ArrayRef argumentLattices, + ArrayRef resultLattices) override; + + void propagateIfChangedWrapper(AnalysisState *state, ChangeResult changed) { + this->propagateIfChanged(state, changed); + } + + private: + const SchemeParamType schemeParam; +}; + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_NOISEANALYSIS_H_ diff --git a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h index c3fc94c89..6482f65df 100644 --- a/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h +++ b/lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h @@ -190,6 +190,25 @@ class SecretnessAnalysisDependent { } } } + + /** + * @brief Selects the OpOperands of an operation that are not secret + * (secretness = false or unknown). + * + * This method iterates through the operands of the given operation and adds + * those that are not secret to the provided vector. + * + * @param op The operation to analyze. + * @param nonSecretOperands A vector to store the non-secret operands. + */ + void getNonSecretOperands(Operation *op, + SmallVectorImpl &nonSecretOperands) { + for (auto &operand : op->getOpOperands()) { + if (!isSecretInternal(op, operand.get())) { + nonSecretOperands.push_back(&operand); + } + } + } }; // Annotate the secretness of operation based on the secretness of its results diff --git a/lib/Parameters/BGV/BUILD b/lib/Parameters/BGV/BUILD new file mode 100644 index 000000000..cf4f46a64 --- /dev/null +++ b/lib/Parameters/BGV/BUILD @@ -0,0 +1,13 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Params", + srcs = ["Params.cpp"], + hdrs = ["Params.h"], + deps = [ + "@llvm-project//llvm:Support", + ], +) diff --git a/lib/Parameters/BGV/Params.cpp b/lib/Parameters/BGV/Params.cpp new file mode 100644 index 000000000..eb9e672ab --- /dev/null +++ b/lib/Parameters/BGV/Params.cpp @@ -0,0 +1,88 @@ +#include "lib/Parameters/BGV/Params.h" + +#include +#include +#include + +namespace mlir { +namespace heir { +namespace bgv { + +// struct for recording the maximal Q for each ring dim +// under certain security condition. +struct RLWESecurityParam { + int ringDim; + int logMaxQ; +}; + +// 128-bit classic security for uniform ternary secret distribution +// taken from the "Homomorphic Encryption Standard" Preprint +// https://ia.cr/2019/939 +// logMaxQ for 65536/131072 taken from OpenFHE +// https://github.com/openfheorg/openfhe-development/blob/7b8346f4eac27121543e36c17237b919e03ec058/src/core/lib/lattice/stdlatticeparms.cpp#L187 +static struct RLWESecurityParam HEStd_128_classic[] = { + {1024, 27}, {2048, 54}, {4096, 109}, {8192, 218}, + {16384, 438}, {32768, 881}, {65536, 1747}, {131072, 3523}}; + +// from OpenFHE +int computeDnum(int level) { + if (level > 3) { + return 3; + } + if (level > 0) { + return 2; + } + return 1; +} + +int computeRingDim(int logTotalPQ) { + for (auto ¶m : HEStd_128_classic) { + if (param.logMaxQ >= logTotalPQ) { + return param.ringDim; + } + } + assert(false && "Failed to find ring dimension, level too large"); + return 0; +} + +SchemeParam SchemeParam::getConservativeSchemeParam(int level, + int64_t plaintextModulus) { + auto logModuli = 60; // assume all 60 bit moduli + auto dnum = computeDnum(level); + std::vector logqi(level + 1, logModuli); + std::vector logpi(ceil(static_cast(logqi.size()) / dnum), + logModuli); + + auto totalQP = logModuli * (logqi.size() + logpi.size()); + + auto ringDim = computeRingDim(totalQP); + + return SchemeParam(ringDim, plaintextModulus, level, logqi, dnum, logpi); +} + +void SchemeParam::print(llvm::raw_ostream &os) const { + auto doubleToString = [](double d) { + std::stringstream stream; + stream << std::fixed << std::setprecision(2) << d; + return stream.str(); + }; + + os << "ringDim: " << ringDim << "\n"; + os << "plaintextModulus: " << plaintextModulus << "\n"; + os << "level: " << level << "\n"; + os << "logqi: "; + for (auto qi : logqi) { + os << doubleToString(qi) << " "; + } + os << "\n"; + os << "dnum: " << dnum << "\n"; + os << "logpi: "; + for (auto pi : logpi) { + os << doubleToString(pi) << " "; + } + os << "\n"; +} + +} // namespace bgv +} // namespace heir +} // namespace mlir diff --git a/lib/Parameters/BGV/Params.h b/lib/Parameters/BGV/Params.h new file mode 100644 index 000000000..2071af620 --- /dev/null +++ b/lib/Parameters/BGV/Params.h @@ -0,0 +1,98 @@ +#ifndef LIB_PARAMETERS_BGV_PARAMS_H_ +#define LIB_PARAMETERS_BGV_PARAMS_H_ + +#include +#include + +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace bgv { + +// Parameter for BGV scheme at ModuleOp level +class SchemeParam { + public: + SchemeParam(int ringDim, int64_t plaintextModulus, int level, + const std::vector &logqi, int dnum, + const std::vector &logpi) + : ringDim(ringDim), + plaintextModulus(plaintextModulus), + level(level), + logqi(logqi), + dnum(dnum), + logpi(logpi) {} + + private: + // the N in Z[X]/(X^N+1) + int ringDim; + + // the plaintext modulus for BGV + int64_t plaintextModulus; + + // the standard deviation of the error distribution + double std0 = 3.2; + + // RNS level, from 0 to L + int level; + + // logarithm of the modulus of each level + // logqi.size() == level + 1 + std::vector logqi; + + // The following part is for HYBRID key switching technique + + // number of digits + // In HYBRID, we decompose Q into `dnum` digits + // for example, when Q consists of q0, q1, q2, q3 and dnum = 2, + // we have two digits: q0q1 and q2q3 + int dnum; + // logarithm of the special modulus + std::vector logpi; + + public: + int getRingDim() const { return ringDim; } + int64_t getPlaintextModulus() const { return plaintextModulus; } + int getLevel() const { return level; } + const std::vector &getLogqi() const { return logqi; } + int getDnum() const { return dnum; } + const std::vector &getLogpi() const { return logpi; } + double getStd0() const { return std0; } + + void print(llvm::raw_ostream &os) const; + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const SchemeParam ¶m) { + param.print(os); + return os; + } + + static SchemeParam getConservativeSchemeParam(int level, + int64_t plaintextModulus); +}; + +// Parameter for each BGV ciphertext SSA value. +class LocalParam { + public: + LocalParam(const SchemeParam *schemeParam, int currentLevel, int dimension) + : schemeParam(schemeParam), + currentLevel(currentLevel), + dimension(dimension) {} + + private: + const SchemeParam *schemeParam; + int currentLevel; + int dimension; + + public: + const SchemeParam *getSchemeParam() const { return schemeParam; } + + int getCurrentLevel() const { return currentLevel; } + int getDimension() const { return dimension; } +}; + +} // namespace bgv +} // namespace heir +} // namespace mlir + +#endif // LIB_PARAMETERS_BGV_PARAMS_H_ diff --git a/lib/Transforms/ValidateNoise/BUILD b/lib/Transforms/ValidateNoise/BUILD new file mode 100644 index 000000000..dfe441228 --- /dev/null +++ b/lib/Transforms/ValidateNoise/BUILD @@ -0,0 +1,33 @@ +load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ValidateNoise", + srcs = ["ValidateNoise.cpp"], + hdrs = ["ValidateNoise.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Analysis/DimensionAnalysis", + "@heir//lib/Analysis/LevelAnalysis", + "@heir//lib/Analysis/NoiseAnalysis", + "@heir//lib/Analysis/NoiseAnalysis/BGV:NoiseByBoundCoeffModel", + "@heir//lib/Dialect/Secret/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +add_heir_transforms( + generated_target_name = "pass_inc_gen", + pass_name = "ValidateNoise", + td_file = "ValidateNoise.td", +) diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp new file mode 100644 index 000000000..ad98967c7 --- /dev/null +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -0,0 +1,169 @@ +#include "lib/Transforms/ValidateNoise/ValidateNoise.h" + +#include "lib/Analysis/DimensionAnalysis/DimensionAnalysis.h" +#include "lib/Analysis/LevelAnalysis/LevelAnalysis.h" +#include "lib/Analysis/NoiseAnalysis/BGV/NoiseByBoundCoeffModel.h" +#include "lib/Analysis/NoiseAnalysis/NoiseAnalysis.h" +#include "lib/Dialect/Secret/IR/SecretOps.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +#define DEBUG_TYPE "ValidateNoise" + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_VALIDATENOISE +#include "lib/Transforms/ValidateNoise/ValidateNoise.h.inc" + +struct ValidateNoise : impl::ValidateNoiseBase { + using ValidateNoiseBase::ValidateNoiseBase; + + // assume only one main func + // also assume max level at entry + // also assume first genericOp arg is secret + int getMaxLevel() { + int maxLevel = 0; + getOperation()->walk([&](func::FuncOp funcOp) { + funcOp->walk([&](secret::GenericOp genericOp) { + for (Value arg : genericOp.getBody()->getArguments()) { + maxLevel = getLevelFromMgmtAttr(arg); + break; + } + }); + }); + return maxLevel; + } + + template + LogicalResult validateNoiseForValue( + Value value, DataFlowSolver *solver, + const typename NoiseAnalysis::SchemeParamType &schemeParam) { + using NoiseModel = typename NoiseAnalysis::NoiseModel; + using NoiseLatticeType = typename NoiseAnalysis::LatticeType; + using LocalParamType = typename NoiseAnalysis::LocalParamType; + + auto getLocalParam = [&](Value value) { + auto level = getLevelFromMgmtAttr(value); + auto dimension = getDimensionFromMgmtAttr(value); + return LocalParamType(&schemeParam, level, dimension); + }; + + auto secretness = isSecret(value, solver); + if (!secretness) { + return success(); + } + + const auto *noiseLattice = solver->lookupState(value); + if (!noiseLattice || !noiseLattice->getValue().isInitialized()) { + return failure(); + } + + auto noiseState = noiseLattice->getValue(); + auto localParam = getLocalParam(value); + + auto budget = NoiseModel::toLogBudget(localParam, noiseState); + + LLVM_DEBUG({ + auto boundString = NoiseModel::toLogBoundString(localParam, noiseState); + auto budgetString = NoiseModel::toLogBudgetString(localParam, noiseState); + auto totalString = NoiseModel::toLogTotalString(localParam); + llvm::dbgs() << "Noise Bound: " << boundString + << " Budget: " << budgetString << " Total: " << totalString + << " for value: " << value << " " << "\n"; + }); + + if (budget < 0) { + return failure(); + } + + return success(); + } + + template + LogicalResult validate( + DataFlowSolver *solver, + const typename NoiseAnalysis::SchemeParamType &schemeParam) { + auto res = getOperation()->walk([&](secret::GenericOp genericOp) { + // check arguments + for (Value arg : genericOp.getBody()->getArguments()) { + if (failed(validateNoiseForValue(arg, solver, + schemeParam))) { + return WalkResult::interrupt(); + } + } + + // check each operation + // TODO(#1181): handle region bearing ops + return genericOp.getBody()->walk([&](Operation *op) { + for (Value result : op->getResults()) { + if (failed(validateNoiseForValue(result, solver, + schemeParam))) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + }); + if (res == WalkResult::interrupt()) { + return failure(); + } + return success(); + } + + template + void run() { + DataFlowSolver solver; + solver.load(); + solver.load(); + // NoiseAnalysis depends on SecretnessAnalysis + solver.load(); + + int maxLevel = getMaxLevel(); + + // plaintext modulus from command line option + auto schemeParam = + NoiseAnalysis::SchemeParamType::getConservativeSchemeParam( + maxLevel, plaintextModulus); + + LLVM_DEBUG(llvm::dbgs() << "Conservative Scheme Param:\n" + << schemeParam << "\n"); + + solver.load(schemeParam); + + if (failed(solver.initializeAndRun(getOperation()))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } + + if (failed(validate(&solver, schemeParam))) { + getOperation()->emitOpError() << "Noise validation failed.\n"; + signalPassFailure(); + return; + } + } + + void runOnOperation() override { + if (model == "bgv-noise-by-bound-coeff-worst-case-pk") { + run>(); + } else if (model == "bgv-noise-by-bound-coeff-average-case-pk") { + run>(); + } else if (model == "bgv-noise-by-bound-coeff-worst-case-sk") { + run>(); + } else if (model == "bgv-noise-by-bound-coeff-average-case-sk") { + run>(); + } else { + getOperation()->emitOpError() << "Unknown noise model.\n"; + signalPassFailure(); + return; + } + } +}; + +} // namespace heir +} // namespace mlir diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.h b/lib/Transforms/ValidateNoise/ValidateNoise.h new file mode 100644 index 000000000..7bc2f3613 --- /dev/null +++ b/lib/Transforms/ValidateNoise/ValidateNoise.h @@ -0,0 +1,18 @@ +#ifndef LIB_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_H_ +#define LIB_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "lib/Transforms/ValidateNoise/ValidateNoise.h.inc" + +#define GEN_PASS_REGISTRATION +#include "lib/Transforms/ValidateNoise/ValidateNoise.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // LIB_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_H_ diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.td b/lib/Transforms/ValidateNoise/ValidateNoise.td new file mode 100644 index 000000000..8a5a087fa --- /dev/null +++ b/lib/Transforms/ValidateNoise/ValidateNoise.td @@ -0,0 +1,45 @@ +#ifndef LIB_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_TD_ +#define LIB_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_TD_ + +include "mlir/Pass/PassBase.td" + +def ValidateNoise : Pass<"validate-noise"> { + let summary = "Validate the HE circuit against a given noise model"; + let description = [{ + This pass validates the noise of the HE circuit against a given noise model. + + Currently the pass works for BGV scheme, and there are two noise models + available: "bgv-noise-by-bound-coeff-average-case{-pk,-sk}" and + "bgv-noise-by-bound-coeff-worst-case{-pk,-sk}". + + The two models are taken from KPZ21, and they work by bounding + the coefficient embedding of the ciphertexts. The difference + of the two models is expansion factor used for multiplication + of the coefficients, the first being `2\sqrt{N}` and the second + being `N`. The `-pk`/`-sk` suffixes assume the input ciphertexts are + encrypted using the public/secret key. + + This pass is experimental. The result should be observed using + --debug-only=ValidateNoise. + + This pass relies on the presence of the `mgmt` dialect ops to model + relinearize/modreduce, and it relies on `mgmt.mgmt` attribute to determine + the ciphertext level/dimension. These ops and attributes can be added by + a pass like `--secret-insert-mgmt-` and `--annotate-mgmt`. + + Example + ```bash + # with commandline --debug-only=ValidateNoise + Noise Bound: 29.27 Budget: 149.73 Total: 179.00 for value: of type 'tensor<8xi16>' at index: 0 + Noise Bound: 29.27 Budget: 149.73 Total: 179.00 for value: of type 'tensor<8xi16>' at index: 1 + ``` + }]; + let options = [ + Option<"model", "model", "std::string", + /*default=*/"", "Noise model to validate against.">, + Option<"plaintextModulus", "plaintext-modulus", "int64_t", + /*default=*/"65537", "Plaintext modulus.">, + ]; +} + +#endif // LIB_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_TD_ diff --git a/tests/Transforms/validate-noise/BUILD b/tests/Transforms/validate-noise/BUILD new file mode 100644 index 000000000..c571e6fc6 --- /dev/null +++ b/tests/Transforms/validate-noise/BUILD @@ -0,0 +1,10 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package(default_applicable_licenses = ["@heir//:license"]) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/Transforms/validate-noise/validate_noise_fail.mlir b/tests/Transforms/validate-noise/validate_noise_fail.mlir new file mode 100644 index 000000000..7c8fa1f6e --- /dev/null +++ b/tests/Transforms/validate-noise/validate_noise_fail.mlir @@ -0,0 +1,25 @@ +// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --validate-noise="model=bgv-noise-by-bound-coeff-worst-case-pk plaintext-modulus=4295294977" --verify-diagnostics %s + +// This is only for testing whether validate-noise would fail, but +// not for testing the noise model. +// This failing example is handcrafted to fail the noise validation +// using the following conditions: +// + plaintext modulus is large +// + modulus is 60bit +// + consecutive multiplications +// + using worst-case noise model +// + When plaintext modulus is large, modulus of 60 bit is not enough +// Note that if any condition changed the test may fail and changes +// to this file are expected + +// expected-error@below {{'builtin.module' op Noise validation failed.}} +module { + func.func @dot_product(%arg0: i16 {secret.secret}) -> i16 { + %0 = arith.muli %arg0, %arg0 : i16 + %1 = arith.muli %0, %0 : i16 + %2 = arith.muli %1, %1 : i16 + %3 = arith.muli %2, %2 : i16 + %4 = arith.muli %3, %3 : i16 + return %4 : i16 + } +} diff --git a/tests/Transforms/validate-noise/validate_noise_pass.mlir b/tests/Transforms/validate-noise/validate_noise_pass.mlir new file mode 100644 index 000000000..134b322a7 --- /dev/null +++ b/tests/Transforms/validate-noise/validate_noise_pass.mlir @@ -0,0 +1,15 @@ +// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-bgv --validate-noise=model=bgv-noise-by-bound-coeff-average-case-pk %s | FileCheck %s + +// CHECK-LABEL: @dot_product +func.func @dot_product(%arg0: tensor<8xi16> {secret.secret}, %arg1: tensor<8xi16> {secret.secret}) -> i16 { + %c0 = arith.constant 0 : index + %c0_si16 = arith.constant 0 : i16 + %0 = affine.for %arg2 = 0 to 8 iter_args(%iter = %c0_si16) -> (i16) { + %1 = tensor.extract %arg0[%arg2] : tensor<8xi16> + %2 = tensor.extract %arg1[%arg2] : tensor<8xi16> + %3 = arith.muli %1, %2 : i16 + %4 = arith.addi %iter, %3 : i16 + affine.yield %4 : i16 + } + return %0 : i16 +} diff --git a/tools/BUILD b/tools/BUILD index 7f931bbe0..04093b3b2 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -114,6 +114,7 @@ cc_binary( "@heir//lib/Transforms/StraightLineVectorizer", "@heir//lib/Transforms/TensorToScalars", "@heir//lib/Transforms/UnusedMemRef", + "@heir//lib/Transforms/ValidateNoise", "@heir//lib/Utils/Tablegen:AsmInterfaces", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 03102ccf6..980181b08 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -72,6 +72,7 @@ #include "lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.h" #include "lib/Transforms/TensorToScalars/TensorToScalars.h" #include "lib/Transforms/UnusedMemRef/UnusedMemRef.h" +#include "lib/Transforms/ValidateNoise/ValidateNoise.h" #include "lib/Utils/Tablegen/AsmInterfaces.h" #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project @@ -278,6 +279,7 @@ int main(int argc, char **argv) { registerOperationBalancerPasses(); registerStraightLineVectorizerPasses(); registerUnusedMemRefPasses(); + registerValidateNoisePasses(); registerOptimizeRelinearizationPasses(); registerLayoutPropagationPasses(); registerLinalgCanonicalizationsPasses();