Skip to content

Commit

Permalink
Merge pull request #87 from ddudt/dev
Browse files Browse the repository at this point in the history
IO Updates
  • Loading branch information
f0uriest authored Jun 2, 2021
2 parents 033cc3a + 36ccce2 commit 1798c1e
Show file tree
Hide file tree
Showing 38 changed files with 360 additions and 597 deletions.
21 changes: 1 addition & 20 deletions desc/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import mpmath
from abc import ABC, abstractmethod
from math import factorial
from desc.utils import sign, flatten_list, equals
from desc.utils import sign, flatten_list
from desc.io import IOAble

__all__ = [
Expand All @@ -19,25 +19,6 @@ class Basis(IOAble, ABC):

_io_attrs_ = ["_L", "_M", "_N", "_NFP", "_modes", "_sym", "_spectral_indexing"]

def __eq__(self, other):
"""Overloads the == operator
Parameters
----------
other : Basis
another Basis object to compare to
Returns
-------
bool
True if other is a Basis with the same attributes as self
False otherwise
"""
if self.__class__ != other.__class__:
return False
return equals(self.__dict__, other.__dict__)

def _enforce_symmetry(self):
"""Enforces stellarator symmetry"""

Expand Down
30 changes: 1 addition & 29 deletions desc/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from desc.optimize.constraint import LinearEqualityConstraint
from desc.grid import LinearGrid
from desc.transform import Transform
from desc.utils import unpack_state, equals
from desc.utils import unpack_state


__all__ = [
Expand Down Expand Up @@ -51,34 +51,6 @@ def __init__(
def recover_from_constraints(self, y, Rb_lmn=None, Zb_lmn=None):
"""Recover full state vector that satifies linear constraints."""

# note: we can't override __eq__ here because that breaks the hashing that jax uses
# when jitting functions
def eq(self, other):
"""Test for equivalence between conditions.
Parameters
----------
other : BoundaryCondition
another BoundaryCondition object to compare to
Returns
-------
bool
True if other is a BoundaryCondition with the same attributes as self
False otherwise
"""
if self.__class__ != other.__class__:
return False
ignore_keys = []
dict1 = {
key: val for key, val in self.__dict__.items() if key not in ignore_keys
}
dict2 = {
key: val for key, val in other.__dict__.items() if key not in ignore_keys
}
return equals(dict1, dict2)

@property
@abstractmethod
def name(self):
Expand Down
12 changes: 2 additions & 10 deletions desc/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from desc.objective_funs import get_objective_function
from desc.basis import (
PowerSeries,
FourierSeries,
DoubleFourierSeries,
ZernikePolynomial,
FourierZernikeBasis,
Expand Down Expand Up @@ -78,13 +77,6 @@ class _Configuration(IOAble, ABC):
"_profiles",
]

_object_lib_ = {
"PowerSeries": PowerSeries,
"FourierSeries": FourierSeries,
"DoubleFourierSeries": DoubleFourierSeries,
"FourierZernikeBasis": FourierZernikeBasis,
}

def __init__(self, inputs):
"""Initialize a Configuration.
Expand Down Expand Up @@ -136,8 +128,8 @@ def __init__(self, inputs):
self._R_sym = "cos"
self._Z_sym = "sin"
else:
self._R_sym = None
self._Z_sym = None
self._R_sym = False
self._Z_sym = False

# create bases
self._set_basis()
Expand Down
41 changes: 5 additions & 36 deletions desc/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,8 @@
from desc.utils import Timer, isalmostequal
from desc.configuration import _Configuration, format_boundary, format_profiles
from desc.io import IOAble
from desc.boundary_conditions import (
get_boundary_condition,
BoundaryCondition,
LCFSConstraint,
PoincareConstraint,
)
from desc.objective_funs import (
get_objective_function,
ObjectiveFunction,
ForceErrorNodes,
ForceErrorGalerkin,
EnergyVolIntegral,
)
from desc.boundary_conditions import get_boundary_condition, BoundaryCondition
from desc.objective_funs import get_objective_function, ObjectiveFunction
from desc.optimize import Optimizer
from desc.grid import Grid, LinearGrid, ConcentricGrid, QuadratureGrid
from desc.transform import Transform
Expand Down Expand Up @@ -80,24 +69,6 @@ class Equilibrium(_Configuration, IOAble):
"optimizer_results",
"_optimizer",
]
_object_lib_ = _Configuration._object_lib_
_object_lib_.update(
{
"_Configuration": _Configuration,
"Grid": Grid,
"LinearGrid": LinearGrid,
"ConcentricGrid": ConcentricGrid,
"QuadratureGrid": QuadratureGrid,
"Transform": Transform,
"Optimizer": Optimizer,
"ForceErrorNodes": ForceErrorNodes,
"ForceErrorGalerkin": ForceErrorGalerkin,
"EnergyVolIntegral": EnergyVolIntegral,
"BoundaryCondition": BoundaryCondition,
"LCFSConstraint": LCFSConstraint,
"PoincareConstraint": PoincareConstraint,
}
)

def __init__(self, inputs):

Expand Down Expand Up @@ -472,9 +443,9 @@ def optimizer(self):
def optimizer(self, optimizer):
if optimizer is None:
self._optimizer = optimizer
elif isinstance(optimizer, Optimizer) and optimizer == self.optimizer:
elif isinstance(optimizer, Optimizer) and optimizer.eq(self.optimizer):
return
elif isinstance(optimizer, Optimizer) and optimizer != self.optimizer:
elif isinstance(optimizer, Optimizer) and not optimizer.eq(self.optimizer):
self._optimizer = optimizer
elif optimizer in Optimizer._all_methods:
self._optimizer = Optimizer(optimizer)
Expand Down Expand Up @@ -729,9 +700,7 @@ class EquilibriaFamily(IOAble, MutableSequence):
"""

_io_attrs_ = ["equilibria"]
_object_lib_ = Equilibrium._object_lib_
_object_lib_.update({"Equilibrium": Equilibrium})
_io_attrs_ = ["_equilibria"]

def __init__(self, inputs):
# did we get 1 set of inputs or several?
Expand Down
24 changes: 3 additions & 21 deletions desc/grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from termcolor import colored
from desc.utils import equals

from desc.io import IOAble
from scipy import special

Expand Down Expand Up @@ -52,25 +52,6 @@ def __init__(self, nodes, sort=True):
self._sort_nodes()
self._find_axis()

def __eq__(self, other):
"""Overloads the == operator
Parameters
----------
other : Grid
another Grid object to compare to
Returns
-------
bool
True if other is a Grid with the same attributes as self
False otherwise
"""
if self.__class__ != other.__class__:
return False
return equals(self.__dict__, other.__dict__)

def _enforce_symmetry(self):
"""Enforces stellarator symmetry"""
if self.sym: # remove nodes with theta > pi
Expand Down Expand Up @@ -649,7 +630,8 @@ def ocs(L):
2 * np.pi / (2 * M + np.ceil((M / L) * (5 - 4 * iring)).astype(int))
)
theta = np.arange(0, 2 * np.pi, dtheta)
theta = (theta + dtheta / 3) % (2 * np.pi)
if self.sym:
theta = (theta + dtheta / 3) % (2 * np.pi)
for tk in theta:
r.append(rho[-iring])
t.append(tk)
Expand Down
4 changes: 2 additions & 2 deletions desc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .input_reader import InputReader
from .equilibrium_io import IOAble
from .equilibrium_io import IOAble, load
from .pickle_io import PickleReader, PickleWriter
from .hdf5_io import hdf5Reader, hdf5Writer
from .ascii_io import read_ascii, write_ascii

__all__ = ["InputReader"]
__all__ = ["InputReader", "load"]
Loading

0 comments on commit 1798c1e

Please sign in to comment.