more flexible transformed diff functionals
lehner committed Jul 10, 2024
1 parent cfcd1e2 commit f2a71ab
Showing 1 changed file with 60 additions and 14 deletions.

lib/gpt/core/group/differentiable_functional.py
@@ -70,9 +70,11 @@ def approximate_gradient(
             (cc / epsilon)
             * self(
                 [
-                    g(g.group.compose((dd * epsilon) * weights[dfields.index(f)], f))
-                    if f in dfields
-                    else f
+                    (
+                        g(g.group.compose((dd * epsilon) * weights[dfields.index(f)], f))
+                        if f in dfields
+                        else f
+                    )
                     for f in fields
                 ]
             )
@@ -108,12 +110,24 @@ def assert_gradient_error(self, rng, fields, dfields, epsilon_approx, epsilon_assert):
             g.message(f"Error: cartesian defect: {eps} > {epsilon_assert}")
             assert False
 
-    def transformed(self, t):
-        return transformed(self, t)
+    def transformed(self, t, indices=None):
+        return transformed(self, t, indices)
 
     def __add__(self, other):
         return added(self, other)
 
+    def __radd__(self, other):
+        # called if not isinstance(other, differentiable_functional)
+        # needed to make sum([ f1, f2, ... ]) work
+        assert other == 0
+        return self
+
+    def __mul__(self, other):
+        return scaled(other, self)
+
+    def __rmul__(self, other):
+        return scaled(other, self)
+
 
 class added(differentiable_functional):
     def __init__(self, a, b):
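The operators added above let functionals be combined with plain Python arithmetic. A usage sketch, assuming s1, s2, s3 are existing differentiable_functional instances (illustrative names, not part of this commit):

    total = sum([s1, s2, s3])  # sum starts from 0, which __radd__ absorbs
    combo = 0.5 * s1 + s2      # __rmul__ wraps s1 in scaled(0.5, s1)

Both evaluate and differentiate term by term, e.g. total(fields) equals s1(fields) + s2(fields) + s3(fields).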
@@ -132,24 +146,56 @@ def gradient(self, fields, dfields):
         return [g(x + y) for x, y in zip(a_grad, b_grad)]
 
 
+class scaled(differentiable_functional):
+    def __init__(self, s, f):
+        self.s = s
+        self.f = f
+
+    def __call__(self, fields):
+        return self.s * self.f(fields)
+
+    def gradient(self, fields, dfields):
+        grad = self.f.gradient(fields, dfields)
+        return [g(self.s * x) for x in grad]
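scaled encodes linearity directly: (s * f)(fields) = s * f(fields), and since the derivative of a scalar multiple is the scalar multiple of the derivative, gradient just rescales each component by s; no Jacobian is needed.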


 class transformed(differentiable_functional):
-    def __init__(self, f, t):
+    def __init__(self, f, t, indices):
         self.f = f
         self.t = t
+        self.indices = indices
 
     def __call__(self, fields):
-        return self.f(self.t(fields))
+        indices = self.indices if self.indices is not None else range(len(fields))
+        fields_indices = [fields[i] for i in indices]
+        fields_transformed = self.t(fields_indices)
+        fields_prime = [None if i in indices else fields[i] for i in range(len(fields))]
+        for i, j in zip(range(len(indices)), indices):
+            fields_prime[j] = fields_transformed[i]
+        return self.f(fields_prime)
 
     def gradient(self, fields, dfields):
-        indices = [fields.index(d) for d in dfields]
-
-        fields_prime = self.t(fields)
+        # save indices w.r.t. which we want the gradients
+        derivative_indices = [fields.index(d) for d in dfields]
 
+        # do the forward pass
+        indices = self.indices if self.indices is not None else range(len(fields))
+        fields_indices = [fields[i] for i in indices]
+        fields_transformed = self.t(fields_indices)
+        fields_prime = [None if i in indices else fields[i] for i in range(len(fields))]
+        for i, j in zip(range(len(indices)), indices):
+            fields_prime[j] = fields_transformed[i]
 
-        transformed_fields = [i for i in range(len(fields)) if fields_prime[i] is not fields[i]]
+        # start the backwards pass with a calculation of the gradient with the transformed fields
+        gradient_prime = self.f.gradient(fields_prime, fields_prime)
 
-        # for now only accept gradients with respect to all transformed fields
-        assert indices == transformed_fields
+        # now apply the jacobian to the transformed gradients
+        gradient_transformed = self.t.jacobian(
+            fields_indices, fields_transformed, [gradient_prime[i] for i in indices]
+        )
 
-        gradient_prime = self.f.gradient(fields_prime, [fields_prime[i] for i in indices])
+        for i, j in zip(range(len(indices)), indices):
+            gradient_prime[j] = gradient_transformed[i]
 
-        return self.t.jacobian(fields, fields_prime, gradient_prime)
+        return [gradient_prime[i] for i in derivative_indices]
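To see the forward and backward passes concretely, here is a minimal self-contained toy that mirrors the logic of transformed above, with plain floats standing in for lattice fields; quadratic and triple are illustrative stand-ins, not gpt API:

    class quadratic:
        # f(x) = sum_i x_i^2, so df/dx_i = 2 x_i
        def __call__(self, fields):
            return sum(x * x for x in fields)

        def gradient(self, fields, dfields):
            return [2.0 * d for d in dfields]

    class triple:
        # t(x) = 3 x on the selected fields
        def __call__(self, fields):
            return [3.0 * x for x in fields]

        def jacobian(self, fields, fields_transformed, gradients):
            # chain rule: grad_x f(t(x)) = 3 * (grad f)(t(x))
            return [3.0 * gr for gr in gradients]

    f, t, indices = quadratic(), triple(), [0]
    fields = [1.0, 2.0]

    # forward pass, as in transformed.__call__: only fields[0] is transformed
    fields_transformed = t([fields[i] for i in indices])
    fields_prime = list(fields)
    for i, j in zip(range(len(indices)), indices):
        fields_prime[j] = fields_transformed[i]
    assert f(fields_prime) == 3.0 ** 2 + 2.0 ** 2

    # backward pass, as in transformed.gradient: gradient at the transformed
    # point, then pulled back through the jacobian of t
    gradient_prime = f.gradient(fields_prime, fields_prime)  # [6.0, 4.0]
    pulled_back = t.jacobian(
        [fields[i] for i in indices],
        fields_transformed,
        [gradient_prime[i] for i in indices],
    )
    assert pulled_back == [18.0]  # d/dx (3 x)^2 = 18 x at x = 1

With indices=[0], fields[1] enters f untouched, and the returned gradient list is restricted to whatever dfields requested, exactly as in the code above.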
