Skip to content

Commit

Permalink
Imported existing methods from decision_tree
Browse files Browse the repository at this point in the history
  • Loading branch information
sandy9999 committed Jul 16, 2024
1 parent 8d106ed commit 8240527
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 109 deletions.
91 changes: 2 additions & 89 deletions Compiler/decision_tree_optimized.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from Compiler.types import *
from Compiler.sorting import *
from Compiler.library import *
from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne
from Compiler.decision_tree import get_type, PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne, output_decision_tree, pick, run_decision_tree, test_decision_tree
from Compiler import util, oram

from itertools import accumulate
Expand All @@ -11,18 +11,6 @@
debug_split = False
max_leaves = None

def get_type(x):
if isinstance(x, (Array, SubMultiArray)):
return x.value_type
elif isinstance(x, (tuple, list)):
x = x[0] + x[-1]
if util.is_constant(x):
return cint
else:
return type(x)
else:
return type(x)

def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
"""
Compute and return secret shared permutation that stably sorts :param keys.
Expand All @@ -36,7 +24,7 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
for x in to_sort)
res = res.transpose()
return radix_sort_permutation_from_matrix(bs, res)
return radix_sort_from_matrix(bs, res)

def ApplyPermutation(perm, x):
res = Array.create_from(x)
Expand Down Expand Up @@ -374,81 +362,6 @@ def get_tree(self, h, Label):
def DecisionTreeTraining(x, y, h, binary=False):
return TreeTrainer(x, y, h, binary=binary).train()

def output_decision_tree(layers):
""" Print decision tree output by :py:class:`TreeTrainer`. """

print_ln('full model %s', util.reveal(layers))
for i, layer in enumerate(layers[:-1]):
print_ln('level %s:', i)
for j, x in enumerate(('NID', 'AID', 'Thr')):
print_ln(' %s: %s', x, util.reveal(layer[j]))
print_ln('leaves:')
for j, x in enumerate(('NID', 'result')):
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))

def pick(bits, x):
if len(bits) == 1:
return bits[0] * x[0]
else:
try:
return x[0].dot_product(bits, x)
except:
return sum(aa * bb for aa, bb in zip(bits, x))

def run_decision_tree(layers, data):
""" Run decision tree against sample data.
:param layers: tree output by :py:class:`TreeTrainer`
:param data: sample data (:py:class:`~Compiler.types.Array`)
:returns: binary label
"""
h = len(layers) - 1
index = 1
for k, layer in enumerate(layers[:-1]):
assert len(layer) == 3
for x in layer:
assert len(x) <= 2 ** k
bits = layer[0].equal(index, k)
threshold = pick(bits, layer[2])
key_index = pick(bits, layer[1])
if key_index.is_clear:
key = data[key_index]
else:
key = pick(
oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
child = 2 * key < threshold
index += child * 2 ** k
bits = layers[h][0].equal(index, h)
return pick(bits, layers[h][1])

def test_decision_tree(name, layers, y, x, n_threads=None, time=False):
if time:
start_timer(100)
n = len(y)
x = x.transpose().reveal()
y = y.reveal()
guess = regint.Array(n)
truth = regint.Array(n)
correct = regint.Array(2)
parts = regint.Array(2)
layers = [[Array.create_from(util.reveal(x)) for x in layer]
for layer in layers]
@for_range_multithread(n_threads, 1, n)
def _(i):
guess[i] = run_decision_tree([[part[:] for part in layer]
for layer in layers], x[i]).reveal()
truth[i] = y[i].reveal()
@for_range(n)
def _(i):
parts[truth[i]] += 1
c = (guess[i].bit_xor(truth[i]).bit_not())
correct[truth[i]] += c
print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
sum(correct), n, correct[0], parts[0], correct[1], parts[1])
if time:
stop_timer(100)

class TreeClassifier:
""" Tree classification that uses
:py:class:`TreeTrainer` internally.
Expand Down
20 changes: 0 additions & 20 deletions Compiler/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,4 @@ def _():
@library.else_
def _():
reveal_sort(h, D, reverse=True)

def radix_sort_permutation_from_matrix(bs, D):
n = len(D)
for b in bs:
assert(len(b) == n)
B = types.sint.Matrix(n, 2)
h = types.Array.create_from(types.sint(types.regint.inc(n)))
@library.for_range(len(bs))
def _(i):
b = bs[i]
B.set_column(0, 1 - b.get_vector())
B.set_column(1, b.get_vector())
c = types.Array.create_from(dest_comp(B))
reveal_sort(c, h, reverse=False)
@library.if_e(i < len(bs) - 1)
def _():
reveal_sort(h, bs[i + 1], reverse=True)
@library.else_
def _():
reveal_sort(h, D, reverse=True)
return h

0 comments on commit 8240527

Please sign in to comment.