Skip to content

Commit

Permalink
Removed some redundant blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
sandy9999 committed Jul 16, 2024
1 parent 4edabbf commit 8d106ed
Showing 1 changed file with 1 addition and 86 deletions.
87 changes: 1 addition & 86 deletions Compiler/decision_tree_optimized.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from Compiler.types import *
from Compiler.sorting import *
from Compiler.library import *
from Compiler.decision_tree import PrefixSum, PrefixSumR, PrefixSum_inv, PrefixSumR_inv, SortPerm, GroupSum, GroupPrefixSum, GroupFirstOne
from Compiler import util, oram

from itertools import accumulate
Expand Down Expand Up @@ -37,29 +38,6 @@ def GetSortPerm(keys, *to_sort, n_bits=None, time=False):
res = res.transpose()
return radix_sort_permutation_from_matrix(bs, res)

def PrefixSum(x):
return x.get_vector().prefix_sum()

def PrefixSumR(x):
tmp = get_type(x).Array(len(x))
tmp.assign_vector(x)
break_point()
tmp[:] = tmp.get_reverse_vector().prefix_sum()
break_point()
return tmp.get_reverse_vector()

def PrefixSum_inv(x):
tmp = get_type(x).Array(len(x) + 1)
tmp.assign_vector(x, base=1)
tmp[0] = 0
return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x))

def PrefixSumR_inv(x):
tmp = get_type(x).Array(len(x) + 1)
tmp.assign_vector(x)
tmp[-1] = 0
return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x))

def ApplyPermutation(perm, x):
res = Array.create_from(x)
reveal_sort(perm, res, False)
Expand All @@ -70,71 +48,13 @@ def ApplyInversePermutation(perm, x):
reveal_sort(perm, res, True)
return res

class SortPerm:
def __init__(self, x):
B = sint.Matrix(len(x), 2)
B.set_column(0, 1 - x.get_vector())
B.set_column(1, x.get_vector())
self.perm = Array.create_from(dest_comp(B))
def apply(self, x):
res = Array.create_from(x)
reveal_sort(self.perm, res, False)
return res
def unapply(self, x):
res = Array.create_from(x)
reveal_sort(self.perm, res, True)
return res

def Sort(keys, *to_sort, n_bits=None, time=False):
if time:
start_timer(1)
for k in keys:
assert len(k) == len(keys[0])
n_bits = n_bits or [None] * len(keys)
bs = Matrix.create_from(
sum([k.get_vector().bit_decompose(nb)
for k, nb in reversed(list(zip(keys, n_bits)))], []))
get_vec = lambda x: x[:] if isinstance(x, Array) else x
res = Matrix.create_from(get_vec(x).v if isinstance(get_vec(x), sfix) else x
for x in to_sort)
res = res.transpose()
if time:
start_timer(11)
radix_sort_from_matrix(bs, res)
if time:
stop_timer(11)
stop_timer(1)
res = res.transpose()
return [sfix._new(get_vec(x), k=get_vec(y).k, f=get_vec(y).f)
if isinstance(get_vec(y), sfix)
else x for (x, y) in zip(res, to_sort)]

def VectMax(key, *data, debug=False):
def reducer(x, y):
b = x[0]*y[1] > y[0]*x[1]
return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
res = util.tree_reduce(reducer, zip(key, *data))
return res

def GroupSum(g, x):
assert len(g) == len(x)
p = PrefixSumR(x) * g
pi = SortPerm(g.get_vector().bit_not())
p1 = pi.apply(p)
s1 = PrefixSumR_inv(p1)
d1 = PrefixSum_inv(s1)
d = pi.unapply(d1) * g
return PrefixSum(d)

def GroupPrefixSum(g, x):
assert len(g) == len(x)
s = get_type(x).Array(len(x) + 1)
s[0] = 0
s.assign_vector(PrefixSum(x), base=1)
q = get_type(s).Array(len(x))
q.assign_vector(s.get_vector(size=len(x)) * g)
return s.get_vector(size=len(x), base=1) - GroupSum(g, q)

def Custom_GT_Fractions(x_num, x_den, y_num, y_den, n_threads=2):
b = (x_num*y_den) > (x_den*y_num)
b = Array.create_from(b).get_vector()
Expand Down Expand Up @@ -223,11 +143,6 @@ def TrainLeafNodes(h, g, y, NID, Label, debug=False):
assert len(g) == len(NID)
return FormatLayer(h, g, NID, Label, debug=debug)

def GroupFirstOne(g, b):
assert len(g) == len(b)
s = GroupPrefixSum(g, b)
return s * b == 1

class TreeTrainer:
def GetInversePermutation(self, perm, n_threads=2):
res = Array.create_from(self.identity_permutation)
Expand Down

0 comments on commit 8d106ed

Please sign in to comment.