Skip to content

Commit

Permalink
tune
Browse files — browse the repository at this point in the history
  • Loading branch information
lehner committed Jan 10, 2025
1 parent 88a421c commit e83055c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
13 changes: 9 additions & 4 deletions lib/gpt/ml/layer/parallel_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(
self.access_cache = {}

tmp = [g.lattice(data_grid, ot_input) for i in range(len(paths))]
self.transport = [g.parallel_transport(self.U, [p], [t]) for t, p in zip(tmp, paths)]
self.transport = [
g.parallel_transport(self.U, [p], [t]) for t, p in zip(tmp, paths)
]
self.itransport = None

def weights(self):
Expand All @@ -48,8 +50,8 @@ def _get_field_list(self, layer_input, ttr):

assert len(layer_input) == 1 + len(self.paths)
assert len(ttr) == len(self.paths)
ret_f = [ layer_input[0] ]

ret_f = [layer_input[0]]

for ttrl, l in zip(ttr, layer_input[1:]):
xx = list(ttrl(self.U, [l]))
Expand All @@ -74,7 +76,10 @@ def projected_gradient_adj(self, weights, layer_input, left):
t = g.timer("projected_gradient_adj")
t("field list")
if self.itransport is None:
self.itransport = [ g.parallel_transport(self.U, [p.inverse()], [l]) for p, l in zip(self.paths, left) ]
self.itransport = [
g.parallel_transport(self.U, [p.inverse()], [l])
for p, l in zip(self.paths, left)
]

t("inverse field list")
ileft = self._get_field_list(left, self.itransport)
Expand Down
12 changes: 6 additions & 6 deletions tests/ml/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,18 @@
n = g.ml.model.sequence(
g.ml.layer.parallel_transport_convolution(grid, U, paths, ot_i, ot_w, 1, 3),
g.ml.layer.parallel_transport_convolution(grid, U, paths, ot_i, ot_w, 3, 3),
g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)*0+1),
g.ml.layer.parallel_transport(grid, U, [paths[0]], ot_i), #paths
g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths)*0+1, 3),
g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)),
g.ml.layer.parallel_transport(grid, U, paths, ot_i),
g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths), 3),
g.ml.layer.parallel_transport_convolution(grid, U, paths, ot_i, ot_w, 3, 1),
)

n_prime = g.ml.model.sequence(
g.ml.layer.parallel_transport_convolution(grid, U_prime, paths, ot_i, ot_w, 1, 3),
g.ml.layer.parallel_transport_convolution(grid, U_prime, paths, ot_i, ot_w, 3, 3),
g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)*0+1),
g.ml.layer.parallel_transport(grid, U_prime, [paths[0]], ot_i), #paths
g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths)*0+1, 3),
g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)),
g.ml.layer.parallel_transport(grid, U_prime, paths, ot_i),
g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths), 3),
g.ml.layer.parallel_transport_convolution(grid, U_prime, paths, ot_i, ot_w, 3, 1),
)

Expand Down

0 comments on commit e83055c

Please sign in to comment.