Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use single LinearConstraintProjection in ProximalProjection #1409

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
A = A_augmented[:, :-1]
b = np.atleast_1d(A_augmented[:, -1].squeeze())

A_nondegenerate = A.copy()

# will store the global index of the unfixed rows, idx
indices_row = np.arange(A.shape[0])
indices_idx = np.arange(A.shape[1])
Expand Down Expand Up @@ -244,7 +246,19 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
"or be due to floating point error.",
)

return xp, A, b, Z, D, unfixed_idx, project, recover
return (
xp,
A,
b,
Z,
D,
unfixed_idx,
project,
recover,
A_inv,
A_nondegenerate,
row_idx_to_delete,
)


class _Project(IOAble):
Expand Down
36 changes: 35 additions & 1 deletion desc/optimize/_constraint_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def build(self, use_jit=None, verbose=1):
self._unfixed_idx,
self._project,
self._recover,
self._Ainv,
self._A_full_nondegenerate,
self._degenerate_idx, # maybe we need those for b_new
) = factorize_linear_constraints(
self._objective,
self._constraint,
Expand Down Expand Up @@ -164,6 +167,28 @@ def unpack_state(self, x, per_objective=True):
x = self.recover(x)
return self._objective.unpack_state(x, per_objective)

def update_constraint_target(self, eq_new):
"""Update the target of the constraint."""
for con in self._constraint.objectives:
if hasattr(con, "update_target"):
con.update_target(eq_new)

x0 = jnp.zeros(self._constraint.dim_x)
b_new = -self._constraint.compute_scaled_error(x0)
b_new = np.delete(b_new, self._degenerate_idx)
xp_new = jnp.zeros_like(self._xp)
fixed_idx = np.setdiff1d(np.arange(self._xp.size), self._unfixed_idx)
xp_new[fixed_idx] = b_new[fixed_idx]
xp_new[self._unfixed_idx] = self._Ainv @ (
b_new - self._A_full_nondegenerate[:, fixed_idx] @ xp_new[fixed_idx]
)
from desc.objectives.utils import _Project, _Recover

self._project = _Project(self._Z, self._D, xp_new, self._unfixed_idx)
self._recover = _Recover(
self._Z, self._D, xp_new, self._unfixed_idx, self._objective.dim_x
)

def compute_unscaled(self, x_reduced, constants=None):
"""Compute the unscaled form of the objective function.

Expand Down Expand Up @@ -533,7 +558,7 @@ def _set_eq_state_vector(self):
self._args.remove(arg)
linear_constraint = ObjectiveFunction(self._linear_constraints)
linear_constraint.build()
_, _, _, self._Z, self._D, self._unfixed_idx, _, _ = (
(_, _, _, self._Z, self._D, self._unfixed_idx, *_) = (
factorize_linear_constraints(self._constraint, linear_constraint)
)

Expand Down Expand Up @@ -592,6 +617,12 @@ def build(self, use_jit=None, verbose=1): # noqa: C901
for constraint in self._linear_constraints:
constraint.build(use_jit=use_jit, verbose=verbose)

self._eq_solve_objective = LinearConstraintProjection(
self._constraint,
ObjectiveFunction(self._linear_constraints),
)
self._eq_solve_objective.build()

errorif(
self._constraint.things != [eq],
ValueError,
Expand Down Expand Up @@ -759,6 +790,9 @@ def _update_equilibrium(self, x, store=False):
x_dict = x_list[self._eq_idx]
x_dict_old = x_list_old[self._eq_idx]
deltas = {str(key): x_dict[key] - x_dict_old[key] for key in x_dict}
# Add some logic to perturb and solve to take single
# LinearConstraintProjection!
self._eq_solve_objective.update_constraint_target(self._eq)
self._eq = self._eq.perturb(
objective=self._constraint,
constraints=self._linear_constraints,
Expand Down
6 changes: 4 additions & 2 deletions desc/perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def perturb( # noqa: C901
if verbose > 0:
print("Factorizing linear constraints")
timer.start("linear constraint factorize")
xp, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, _, _, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)
timer.stop("linear constraint factorize")
Expand Down Expand Up @@ -750,7 +750,9 @@ def optimal_perturb( # noqa: C901
con.update_target(eq_new)
constraint = ObjectiveFunction(constraints)
constraint.build(verbose=verbose)
_, _, _, _, _, _, _, recover = factorize_linear_constraints(objective_f, constraint)
_, _, _, _, _, _, _, recover, *_ = factorize_linear_constraints(
objective_f, constraint
)

# update other attributes
dx_reduced = dx1_reduced + dx2_reduced
Expand Down
2 changes: 1 addition & 1 deletion desc/vmec.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def load(
constraints = maybe_add_self_consistency(eq, constraints)
objective = ObjectiveFunction(constraints)
objective.build(verbose=0)
_, _, _, _, _, _, project, recover = factorize_linear_constraints(
_, _, _, _, _, _, project, recover, *_ = factorize_linear_constraints(
objective, objective
)
args = objective.unpack_state(recover(project(objective.x(eq))), False)[0]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def test_correct_indexing_passed_modes():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down Expand Up @@ -508,7 +508,7 @@ def test_correct_indexing_passed_modes_and_passed_target():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down Expand Up @@ -568,7 +568,7 @@ def test_correct_indexing_passed_modes_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down Expand Up @@ -697,7 +697,7 @@ def test_correct_indexing_passed_modes_and_passed_target_axis():
constraint = ObjectiveFunction(constraints, use_jit=False)
constraint.build()

xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints(
xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints(
objective, constraint
)

Expand Down
Loading