Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce NoiseAnalysis Framework #1343

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
Expand Down Expand Up @@ -106,6 +107,26 @@ int getDimension(Value value, DataFlowSolver *solver) {
return lattice->getValue().getDimension();
}

int getDimensionFromMgmtAttr(Value value) {
Attribute attr;
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = blockArg.getOwner()->getParentOp();
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
if (genericOp) {
attr = genericOp.getArgAttr(blockArg.getArgNumber(),
mgmt::MgmtDialect::kArgMgmtAttrName);
}
} else {
auto *parentOp = value.getDefiningOp();
attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName);
}
if (!mlir::isa<mgmt::MgmtAttr>(attr)) {
assert(false && "MgmtAttr not found");
}
auto mgmtAttr = mlir::cast<mgmt::MgmtAttr>(attr);
return mgmtAttr.getDimension();
}

void annotateDimension(Operation *top, DataFlowSolver *solver) {
auto getIntegerAttr = [&](int dimension) {
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), dimension);
Expand Down
2 changes: 2 additions & 0 deletions lib/Analysis/DimensionAnalysis/DimensionAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class DimensionAnalysis
// initialized
DimensionState::DimensionType getDimension(Value value, DataFlowSolver *solver);

DimensionState::DimensionType getDimensionFromMgmtAttr(Value value);

void annotateDimension(Operation *top, DataFlowSolver *solver);

} // namespace heir
Expand Down
26 changes: 22 additions & 4 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,17 @@

#include "lib/Analysis/SecretnessAnalysis/SecretnessAnalysis.h"
#include "lib/Analysis/Utils.h"
#include "lib/Dialect/Mgmt/IR/MgmtAttributes.h"
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -142,5 +140,25 @@ void annotateLevel(Operation *top, DataFlowSolver *solver) {
});
}

LevelState::LevelType getLevelFromMgmtAttr(Value value) {
Attribute attr;
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
auto *parentOp = blockArg.getOwner()->getParentOp();
auto genericOp = dyn_cast<secret::GenericOp>(parentOp);
if (genericOp) {
attr = genericOp.getArgAttr(blockArg.getArgNumber(),
mgmt::MgmtDialect::kArgMgmtAttrName);
}
} else {
auto *parentOp = value.getDefiningOp();
attr = parentOp->getAttr(mgmt::MgmtDialect::kArgMgmtAttrName);
}
if (!mlir::isa<mgmt::MgmtAttr>(attr)) {
assert(false && "MgmtAttr not found");
}
auto mgmtAttr = mlir::cast<mgmt::MgmtAttr>(attr);
return mgmtAttr.getLevel();
}

} // namespace heir
} // namespace mlir
2 changes: 2 additions & 0 deletions lib/Analysis/LevelAnalysis/LevelAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class LevelAnalysis
}
};

LevelState::LevelType getLevelFromMgmtAttr(Value value);

void annotateLevel(Operation *top, DataFlowSolver *solver);

} // namespace heir
Expand Down
40 changes: 40 additions & 0 deletions lib/Analysis/NoiseAnalysis/BGV/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "NoiseByBoundCoeffModel",
srcs = [
"NoiseByBoundCoeffModel.cpp",
"NoiseByBoundCoeffModelAnalysis.cpp",
],
hdrs = [
"NoiseByBoundCoeffModel.h",
],
deps = [
":Noise",
"@heir//lib/Analysis/DimensionAnalysis",
"@heir//lib/Analysis/LevelAnalysis",
"@heir//lib/Analysis/NoiseAnalysis",
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Parameters/BGV:Params",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "Noise",
srcs = ["Noise.cpp"],
hdrs = [
"Noise.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
30 changes: 30 additions & 0 deletions lib/Analysis/NoiseAnalysis/BGV/Noise.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "lib/Analysis/NoiseAnalysis/BGV/Noise.h"

#include <cmath>

#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace bgv {

std::string NoiseState::toString() const {
switch (noiseType) {
case (NoiseType::UNINITIALIZED):
return "NoiseState(uninitialized)";
case (NoiseType::SET):
return "NoiseState(" + std::to_string(log(getValue()) / log(2)) + ") ";
}
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const NoiseState &noise) {
return os << noise.toString();
}

Diagnostic &operator<<(Diagnostic &diagnostic, const NoiseState &noise) {
return diagnostic << noise.toString();
}

} // namespace bgv
} // namespace heir
} // namespace mlir
88 changes: 88 additions & 0 deletions lib/Analysis/NoiseAnalysis/BGV/Noise.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#ifndef INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_
#define INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_

#include <optional>

#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace bgv {

// This class could be shared among all noise models that tracks the noise by a
// single value. Noise model could have different interpretation of the value.
// In BGV world, most noise model just use a single value, either as bound or
// as variance.
class NoiseState {
public:
enum NoiseType {
// A min value for the lattice, discarable when joined with anything else.
UNINITIALIZED,
// A known value for the lattice, when noise can be inferred.
SET,
};

static NoiseState uninitialized() {
return NoiseState(NoiseType::UNINITIALIZED, std::nullopt);
}
static NoiseState of(double value) {
return NoiseState(NoiseType::SET, value);
}

/// Create an integer value range lattice value.
/// The default constructor must be equivalent to the "entry state" of the
/// lattice, i.e., an uninitialized noise.
NoiseState(NoiseType noiseType = NoiseType::UNINITIALIZED,
std::optional<double> value = std::nullopt)
: noiseType(noiseType), value(value) {}

bool isKnown() const { return noiseType == NoiseType::SET; }

bool isInitialized() const { return noiseType != NoiseType::UNINITIALIZED; }

const double &getValue() const {
assert(isKnown());
return *value;
}

bool operator==(const NoiseState &rhs) const {
return noiseType == rhs.noiseType && value == rhs.value;
}

static NoiseState join(const NoiseState &lhs, const NoiseState &rhs) {
// Uninitialized noises correspond to values that are not secret,
// which may be the inputs to an encryption operation.
if (lhs.noiseType == NoiseType::UNINITIALIZED) {
return rhs;
}
if (rhs.noiseType == NoiseType::UNINITIALIZED) {
return lhs;
}

assert(lhs.noiseType == NoiseType::SET && rhs.noiseType == NoiseType::SET);
return NoiseState::of(std::max(lhs.getValue(), rhs.getValue()));
}

void print(llvm::raw_ostream &os) const { os << value; }

std::string toString() const;

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const NoiseState &noise);

friend Diagnostic &operator<<(Diagnostic &diagnostic,
const NoiseState &noise);

private:
NoiseType noiseType;
// notice that when level becomes large (e.g. 17), the max Q could be like
// 3523 bits and could not be represented in double.
std::optional<double> value;
};

} // namespace bgv
} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_NOISEANALYSIS_BGV_NOISE_H_
Loading
Loading