Skip to content

Commit

Permalink
add l1 regulator
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jan 13, 2025
1 parent 6be21b4 commit 0caf060
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
30 changes: 29 additions & 1 deletion lib/gpt/ml/regulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,32 @@ def __call__(self, a):
return (self.lam / 2) * sum([g.norm2(a[i]) for i in self.indices])

def gradient(self, a, da):
return [g(self.lam * x) for x in da]
return [g((self.lam if a.index(x) in self.indices else 0.0) * x) for x in da]


class L1(differentiable_functional):
def __init__(self, lam, indices):
self.lam = lam
self.indices = indices

def __call__(self, a):
a = g.util.to_list(a)
r = 0.0
for i in self.indices:
x = g(
g.component.sqrt(g.component.abs(g.component.real(a[i])))
+ 1j * g.component.sqrt(g.component.abs(g.component.imag(a[i])))
)
r += g.sum(
g.trace(g.adj(x) * x)
)
return r * self.lam

def gradient(self, a, da):
dabs = g.component.drelu(-1)
return [
g(
(self.lam if a.index(x) in self.indices else 0.0) *
(dabs(g.component.real(x)) + 1j * dabs(g.component.imag(x)))
) for x in da
]
4 changes: 2 additions & 2 deletions tests/ml/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
training_input = [rng.cnormal(g.lattice(grid, ot_i)) for i in range(n_training)]
training_output = [rng.cnormal(g.lattice(grid, ot_i)) for i in range(n_training)]

c = n.cost() + g.ml.regulator.L2(0.1, range(len(W)))
c = n.cost() + g.ml.regulator.L2(0.1, [1,2])
g.message("Cost:", c(W + training_input + training_output))

c.assert_gradient_error(
Expand Down Expand Up @@ -110,7 +110,7 @@
n_training = 3
training_output = [rng.normal(g.lattice(grid, ot_i)) for i in range(n_training)]
training_input = [rng.normal(g.lattice(grid, ot_i)) for i in range(n_training)]
c = n.cost()
c = n.cost() + g.ml.regulator.L1(0.1, [0])

c.assert_gradient_error(rng, W + training_input + training_output, W, 1e-3, 1e-8)

Expand Down

0 comments on commit 0caf060

Please sign in to comment.