From 9b3f1bf5ba9acaa519d93e20df4d7b83e1adbfb9 Mon Sep 17 00:00:00 2001 From: YigitElma Date: Wed, 27 Nov 2024 01:22:04 -0500 Subject: [PATCH] initial commit --- desc/objectives/utils.py | 16 +++++++++++- desc/optimize/_constraint_wrappers.py | 36 ++++++++++++++++++++++++++- desc/perturbations.py | 6 +++-- desc/vmec.py | 2 +- tests/test_linear_objectives.py | 8 +++--- 5 files changed, 59 insertions(+), 9 deletions(-) diff --git a/desc/objectives/utils.py b/desc/objectives/utils.py index d02e70bc7..ab4e0b6b5 100644 --- a/desc/objectives/utils.py +++ b/desc/objectives/utils.py @@ -122,6 +122,8 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa A = A_augmented[:, :-1] b = np.atleast_1d(A_augmented[:, -1].squeeze()) + A_nondegenerate = A.copy() + # will store the global index of the unfixed rows, idx indices_row = np.arange(A.shape[0]) indices_idx = np.arange(A.shape[1]) @@ -244,7 +246,19 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa "or be due to floating point error.", ) - return xp, A, b, Z, D, unfixed_idx, project, recover + return ( + xp, + A, + b, + Z, + D, + unfixed_idx, + project, + recover, + A_inv, + A_nondegenerate, + row_idx_to_delete, + ) class _Project(IOAble): diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index dc2e3033c..27f046469 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -108,6 +108,9 @@ def build(self, use_jit=None, verbose=1): self._unfixed_idx, self._project, self._recover, + self._Ainv, + self._A_full_nondegenerate, + self._degenerate_idx, # maybe we need those for b_new ) = factorize_linear_constraints( self._objective, self._constraint, @@ -164,6 +167,28 @@ def unpack_state(self, x, per_objective=True): x = self.recover(x) return self._objective.unpack_state(x, per_objective) + def update_constraint_target(self, eq_new): + """Update the target of the constraint.""" + for con in self._constraint.objectives: + if hasattr(con, "update_target"): + con.update_target(eq_new) + + x0 = jnp.zeros(self._constraint.dim_x) + b_new = -self._constraint.compute_scaled_error(x0) + b_new = np.delete(b_new, self._degenerate_idx) + xp_new = jnp.zeros_like(self._xp) + fixed_idx = np.setdiff1d(np.arange(self._xp.size), self._unfixed_idx) + xp_new[fixed_idx] = b_new[fixed_idx] + xp_new[self._unfixed_idx] = self._Ainv @ ( + b_new - self._A_full_nondegenerate[:, fixed_idx] @ xp_new[fixed_idx] + ) + from desc.objectives.utils import _Project, _Recover + + self._project = _Project(self._Z, self._D, xp_new, self._unfixed_idx) + self._recover = _Recover( + self._Z, self._D, xp_new, self._unfixed_idx, self._objective.dim_x + ) + def compute_unscaled(self, x_reduced, constants=None): """Compute the unscaled form of the objective function. @@ -533,7 +558,7 @@ def _set_eq_state_vector(self): self._args.remove(arg) linear_constraint = ObjectiveFunction(self._linear_constraints) linear_constraint.build() - _, _, _, self._Z, self._D, self._unfixed_idx, _, _ = ( + (_, _, _, self._Z, self._D, self._unfixed_idx, *_) = ( factorize_linear_constraints(self._constraint, linear_constraint) ) @@ -592,6 +617,12 @@ def build(self, use_jit=None, verbose=1): # noqa: C901 for constraint in self._linear_constraints: constraint.build(use_jit=use_jit, verbose=verbose) + self._eq_solve_objective = LinearConstraintProjection( + self._constraint, + ObjectiveFunction(self._linear_constraints), + ) + self._eq_solve_objective.build() + errorif( self._constraint.things != [eq], ValueError, @@ -759,6 +790,9 @@ def _update_equilibrium(self, x, store=False): x_dict = x_list[self._eq_idx] x_dict_old = x_list_old[self._eq_idx] deltas = {str(key): x_dict[key] - x_dict_old[key] for key in x_dict} + # Add some logic to perturb and solve to take single + # LinearConstraintProjection! + self._eq_solve_objective.update_constraint_target(self._eq) self._eq = self._eq.perturb( objective=self._constraint, constraints=self._linear_constraints, diff --git a/desc/perturbations.py b/desc/perturbations.py index 5c35e8f6d..7f024ca40 100644 --- a/desc/perturbations.py +++ b/desc/perturbations.py @@ -186,7 +186,7 @@ def perturb( # noqa: C901 if verbose > 0: print("Factorizing linear constraints") timer.start("linear constraint factorize") - xp, _, _, Z, D, unfixed_idx, project, recover = factorize_linear_constraints( + xp, _, _, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints( objective, constraint ) timer.stop("linear constraint factorize") @@ -750,7 +750,9 @@ def optimal_perturb( # noqa: C901 con.update_target(eq_new) constraint = ObjectiveFunction(constraints) constraint.build(verbose=verbose) - _, _, _, _, _, _, _, recover = factorize_linear_constraints(objective_f, constraint) + _, _, _, _, _, _, _, recover, *_ = factorize_linear_constraints( + objective_f, constraint + ) # update other attributes dx_reduced = dx1_reduced + dx2_reduced diff --git a/desc/vmec.py b/desc/vmec.py index 14c4cafff..a7c8d642d 100644 --- a/desc/vmec.py +++ b/desc/vmec.py @@ -192,7 +192,7 @@ def load( constraints = maybe_add_self_consistency(eq, constraints) objective = ObjectiveFunction(constraints) objective.build(verbose=0) - _, _, _, _, _, _, project, recover = factorize_linear_constraints( + _, _, _, _, _, _, project, recover, *_ = factorize_linear_constraints( objective, objective ) args = objective.unpack_state(recover(project(objective.x(eq))), False)[0] diff --git a/tests/test_linear_objectives.py b/tests/test_linear_objectives.py index 35df7f993..6ceca5a1d 100644 --- a/tests/test_linear_objectives.py +++ b/tests/test_linear_objectives.py @@ -445,7 +445,7 @@ def test_correct_indexing_passed_modes(): constraint = ObjectiveFunction(constraints, use_jit=False) constraint.build() - xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints( + xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints( objective, constraint ) @@ -508,7 +508,7 @@ def test_correct_indexing_passed_modes_and_passed_target(): constraint = ObjectiveFunction(constraints, use_jit=False) constraint.build() - xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints( + xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints( objective, constraint ) @@ -568,7 +568,7 @@ def test_correct_indexing_passed_modes_axis(): constraint = ObjectiveFunction(constraints, use_jit=False) constraint.build() - xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints( + xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints( objective, constraint ) @@ -697,7 +697,7 @@ def test_correct_indexing_passed_modes_and_passed_target_axis(): constraint = ObjectiveFunction(constraints, use_jit=False) constraint.build() - xp, A, b, Z, D, unfixed_idx, project, recover = factorize_linear_constraints( + xp, A, b, Z, D, unfixed_idx, project, recover, *_ = factorize_linear_constraints( objective, constraint )