Commit: resnet
lehner committed Jan 11, 2025
1 parent 5003e00 commit 311b0de
Showing 3 changed files with 71 additions and 7 deletions.
2 changes: 1 addition & 1 deletion lib/gpt/ml/layer/__init__.py
@@ -29,6 +29,6 @@
from gpt.ml.layer.group import group
from gpt.ml.layer.parallel import parallel
from gpt.ml.layer.sequence import sequence
from gpt.ml.layer.residual import residual
import gpt.ml.layer.block
import gpt.ml.layer.parallel_transport_pooling

60 changes: 60 additions & 0 deletions lib/gpt/ml/layer/residual.py
@@ -0,0 +1,60 @@
#
# GPT - Grid Python Toolkit
# Copyright (C) 2020-22 Christoph Lehner ([email protected], https://github.com/lehner/gpt)
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import gpt as g
from gpt.ml.layer.group import group


# add two fields, or two lists of fields element by element
def add(a, b):
    if isinstance(a, list):
        return [g(x + y) for x, y in zip(a, b)]
    else:
        return g(a + b)


class residual(group):
    def __init__(self, *layers):
        super().__init__(layers)

    def __call__(self, weights, input_layer):
        # apply the wrapped layers in sequence, then add the skip connection
        current = input_layer
        for i in range(len(self.layers)):
            current = self.forward(i, weights, current)

        return add(current, input_layer)

    # For out = layer2(w2, layer1(w1, in)) + in, return left_i partial_i out,
    # i.e., the projected gradients with respect to all weights and the input
    def projected_gradient_adj(self, weights, input_layer, left):
        r = [None for x in weights]
        layer_value = [input_layer]
        # forward propagation: store the input of each layer
        for i in range(len(self.layers) - 1):
            layer_value.append(self.forward(i, weights, layer_value[-1]))
        # backward propagation: pull left back through the layers in reverse order
        current_left = left
        for i in reversed(range(len(self.layers))):
            gr = self.dforward_adj(i, weights, layer_value[i], current_left)
            current_left = gr[-1]
            i0, i1 = self.weights_index[i]
            for j in range(i0, i1):
                if r[j] is None:
                    r[j] = gr[j - i0]
                else:
                    r[j] += gr[j - i0]
        # the identity skip connection adds left directly to the input gradient
        return r + [add(current_left, left)]
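
The new layer implements a standard residual (ResNet-style) block: the wrapped layers are applied in order, the block's input is added back onto their output, and in the adjoint pass the incoming vector left is propagated both through the layer chain and directly along the identity skip, which is why the returned input gradient is add(current_left, left). A minimal pure-Python sketch of this arithmetic, using scalar toy layers f_i(w_i, x) = w_i * x instead of the gpt lattice types (the toy_* names are illustrative, not part of the gpt API):

# Toy sketch of the residual forward/adjoint arithmetic above, with
# scalar "layers" f_i(w_i, x) = w_i * x (illustrative only, not gpt code)

def toy_forward(weights, x):
    h = weights[0] * x          # layer 1
    y = weights[1] * h          # layer 2
    return y + x                # add the identity skip, as in residual.__call__

def toy_projected_gradient_adj(weights, x, left):
    h = weights[0] * x                # forward pass, keep the intermediate value
    g_w2 = left * h                   # left * d(out)/d(w2), since d(out)/d(w2) = h
    left_h = left * weights[1]        # pull left back through layer 2
    g_w1 = left_h * x                 # left * d(out)/d(w1) = left * w2 * x
    left_x = left_h * weights[0]      # pull back through layer 1
    # the identity skip adds left directly to the input gradient,
    # mirroring add(current_left, left) in residual.projected_gradient_adj
    return [g_w1, g_w2, left_x + left]

weights = [2.0, 3.0]
print(toy_forward(weights, 1.5))                      # 3.0 * 2.0 * 1.5 + 1.5 = 10.5
print(toy_projected_gradient_adj(weights, 1.5, 1.0))  # [4.5, 3.0, 7.0]

The gpt implementation above performs the same bookkeeping with lattice fields, routing the per-layer weight gradients through group's weights_index and dforward_adj.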
16 changes: 10 additions & 6 deletions tests/ml/sequence.py
@@ -77,18 +77,22 @@
 n = g.ml.model.sequence(
     g.ml.layer.parallel_transport_convolution(grid, U, paths, ot_i, ot_w, 1, 3),
     g.ml.layer.parallel_transport_convolution(grid, U, paths, ot_i, ot_w, 3, 3),
-    g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)),
-    g.ml.layer.parallel_transport(grid, U, paths, ot_i),
-    g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths), 3),
+    g.ml.layer.residual(
+        g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)),
+        g.ml.layer.parallel_transport(grid, U, paths, ot_i),
+        g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths), 3),
+    ),
     g.ml.layer.parallel_transport_convolution(grid, U, paths, ot_i, ot_w, 3, 1),
 )

 n_prime = g.ml.model.sequence(
     g.ml.layer.parallel_transport_convolution(grid, U_prime, paths, ot_i, ot_w, 1, 3),
     g.ml.layer.parallel_transport_convolution(grid, U_prime, paths, ot_i, ot_w, 3, 3),
-    g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)),
-    g.ml.layer.parallel_transport(grid, U_prime, paths, ot_i),
-    g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths), 3),
+    g.ml.layer.residual(
+        g.ml.layer.linear(grid, ot_i, ot_w, 3, 1 + len(paths)),
+        g.ml.layer.parallel_transport(grid, U_prime, paths, ot_i),
+        g.ml.layer.linear(grid, ot_i, ot_w, 1 + len(paths), 3),
+    ),
     g.ml.layer.parallel_transport_convolution(grid, U_prime, paths, ot_i, ot_w, 3, 1),
 )
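
Wrapping these three layers in g.ml.layer.residual is a drop-in change: the wrapped chain maps the 3 internal feature fields to 1 + len(paths) fields and back to 3, so its output has the same structure as its input and the element-wise addition of the skip connection is well defined. Apart from these two constructor calls, the test is unchanged.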
