Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix io of MixedCoilSets #1016

Merged
merged 8 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ class CoilSet(OptimizableCollection, _Coil, MutableSequence):
"""

_io_attrs_ = _Coil._io_attrs_ + ["_coils", "_NFP", "_sym"]
_io_attrs_.remove("_current")

def __init__(self, *coils, NFP=1, sym=False, name=""):
coils = flatten_list(coils, flatten_tuple=True)
Expand Down
55 changes: 31 additions & 24 deletions desc/io/hdf5_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for reading and writing HDF5 files."""

import numbers
import pydoc
import warnings

Expand Down Expand Up @@ -294,6 +295,9 @@
"""
loc = self.resolve_where(where)

def isarray(x):
return hasattr(x, "shape") and hasattr(x, "dtype")

# save name of object class
loc.create_dataset("__class__", data=fullname(obj))
from desc import __version__
Expand All @@ -302,38 +306,41 @@
for attr in obj._io_attrs_:
try:
data = getattr(obj, attr)
if data is None:
data = "None"
compression = (
"gzip"
if isinstance(data, np.ndarray) and np.asarray(data).size > 1
else None
)
loc.create_dataset(attr, data=data, compression=compression)
except AttributeError:
warnings.warn(
"Save attribute '{}' was not saved as it does "
"not exist.".format(attr),
RuntimeWarning,
)
except TypeError:
theattr = getattr(obj, attr)
if isinstance(theattr, dict):
group = loc.create_group(attr)
self.write_dict(theattr, where=group)
elif isinstance(theattr, list):
continue
if data is None:
data = "None"
if isarray(data):
data = np.asarray(data) # convert jax arrs to np

if (
isarray(data)
or isinstance(data, numbers.Number)
or isinstance(data, str)
):
compression = "gzip" if isarray(data) and data.size > 1 else None
loc.create_dataset(attr, data=data, compression=compression)
elif isinstance(data, dict):
group = loc.create_group(attr)
self.write_dict(data, where=group)
elif isinstance(data, (list, tuple)):
group = loc.create_group(attr)
self.write_list(data, where=group)
else:
from .equilibrium_io import IOAble

if isinstance(data, IOAble):
group = loc.create_group(attr)
self.write_list(theattr, where=group)
data.save(group)
else:
try:

group = loc.create_group(attr)
sub_obj = getattr(obj, attr)
sub_obj.save(group)
except AttributeError:
warnings.warn(
"Could not save object '{}'.".format(attr), RuntimeWarning
)
raise TypeError(

Check warning on line 341 in desc/io/hdf5_io.py

View check run for this annotation

Codecov / codecov/patch

desc/io/hdf5_io.py#L341

Added line #L341 was not covered by tests
f"don't know how to save attribute {attr} of type {type(data)}"
)

def write_dict(self, thedict, where=None):
"""Write dictionary to file in group specified by where argument.
Expand Down
71 changes: 71 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
from netCDF4 import Dataset

from desc.__main__ import main
from desc.coils import (
CoilSet,
FourierPlanarCoil,
FourierRZCoil,
FourierXYZCoil,
MixedCoilSet,
)
from desc.compute import rpz2xyz_vec
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.examples import get
from desc.grid import LinearGrid
from desc.vmec import VMECIO


Expand Down Expand Up @@ -217,6 +227,67 @@ def DummyStellarator(tmpdir_factory):
return DummyStellarator_out


@pytest.fixture(scope="session")
def DummyCoilSet(tmpdir_factory):
"""Create and save a dummy coil set for testing."""
output_dir = tmpdir_factory.mktemp("result")
output_path_sym = output_dir.join("DummyCoilSet_sym.h5")
output_path_asym = output_dir.join("DummyCoilSet_asym.h5")

eq = get("precise_QH")
minor_radius = eq.compute("a")["a"]

# CoilSet with symmetry
num_coils = 3 # number of unique coils per half field period
grid = LinearGrid(rho=[0.0], M=0, zeta=2 * num_coils, NFP=eq.NFP * (eq.sym + 1))
with pytest.warns(UserWarning): # because eq.NFP != grid.NFP
data_center = eq.axis.compute("x", grid=grid, basis="xyz")
data_normal = eq.compute("e^zeta", grid=grid)
centers = data_center["x"]
normals = rpz2xyz_vec(data_normal["e^zeta"], phi=grid.nodes[:, 2])
coils = []
for k in range(1, 2 * num_coils + 1, 2):
coil = FourierPlanarCoil(
current=1e6,
center=centers[k, :],
normal=normals[k, :],
r_n=[0, minor_radius + 0.5, 0],
)
coils.append(coil)
coilset_sym = CoilSet(coils, NFP=eq.NFP, sym=eq.sym)
coilset_sym.save(output_path_sym)

# equivalent CoilSet without symmetry
coilset_asym = CoilSet.from_symmetry(coilset_sym, NFP=eq.NFP, sym=eq.sym)
coilset_asym.save(output_path_asym)

DummyCoilSet_out = {
"output_path_sym": output_path_sym,
"output_path_asym": output_path_asym,
}
return DummyCoilSet_out


@pytest.fixture(scope="session")
def DummyMixedCoilSet(tmpdir_factory):
"""Create and save a dummy mixed coil set for testing."""
output_dir = tmpdir_factory.mktemp("result")
output_path = output_dir.join("DummyMixedCoilSet.h5")

tf_coil = FourierPlanarCoil(center=[2, 0, 0], normal=[0, 1, 0], r_n=[1])
tf_coilset = CoilSet.linspaced_angular(tf_coil, n=4)
vf_coil = FourierRZCoil(R_n=3, Z_n=-1)
vf_coilset = CoilSet.linspaced_linear(
vf_coil, displacement=[0, 0, 2], n=3, endpoint=True
)
xyz_coil = FourierXYZCoil()
full_coilset = MixedCoilSet((tf_coilset, vf_coilset, xyz_coil))

full_coilset.save(output_path)
DummyMixedCoilSet_out = {"output_path": output_path}
return DummyMixedCoilSet_out


@pytest.fixture(scope="session")
def writer_test_file(tmpdir_factory):
"""Create temporary output directory."""
Expand Down
Loading