Skip to content

Commit

Permalink
Making the code compatible with numpy 2.0 and unyt 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
miekkasarki committed Jul 26, 2024
1 parent b4e4b7f commit f180148
Show file tree
Hide file tree
Showing 11 changed files with 32 additions and 30 deletions.
2 changes: 1 addition & 1 deletion a5py/ascot5io/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ def ppappe2ekinpitch(dist, mass, ekin_edges=10, pitch_edges=10):
dist.abscissa_edges("ppar")[0]**2,
dist.abscissa_edges("pperp")[-1]**2)
ekinmax = physlib.energy_momentum(mass, np.sqrt(p2max)).to("eV")
ekin_edges = np.linspace(0, ekinmax, ekin_edges)
ekin_edges = np.linspace(0*unyt.eV, ekinmax, ekin_edges)
if isinstance(pitch_edges, int):
pitch_edges = np.linspace(-1, 1, pitch_edges)

Expand Down
6 changes: 4 additions & 2 deletions a5py/ascot5io/orbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,10 @@ def _val(q, mask=None):
"""
with self as h5:
if q in h5:
q = fileapi.read_data(h5, q)
return q if mask is None else q[mask]
qnt = fileapi.read_data(h5, q)
if q == "weight":
qnt *= unyt.unyt_quantity.from_string("particles/s")
return qnt if mask is None else qnt[mask]
return None

# Sort using the fact that inistate.get return values ordered by ID
Expand Down
6 changes: 3 additions & 3 deletions a5py/ascotpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def integrate_mc(rho, theta, phimin, phimax):
np.append(phimin, phimax),
indexing="ij" )
rc, zc = self.input_rhotheta2rz(
rhoc.ravel(), thc.ravel()*unyt.rad, phic.ravel()*unyt.rad, t)
rhoc.ravel(), thc.ravel(), phic.ravel(), t)
bbox = np.ones((4,)) * unyt.m
bbox[0] = np.nanmin(rc)
bbox[1] = np.nanmax(rc)
Expand Down Expand Up @@ -1363,8 +1363,8 @@ def input_eval_orbitresonance(

nmrk = mrk['n']
passing = np.zeros((n1, n2))
torfreq = np.zeros((n1, n2)) + np.NaN
polfreq = np.zeros((n1, n2)) + np.NaN
torfreq = np.zeros((n1, n2)) + np.nan
polfreq = np.zeros((n1, n2)) + np.nan
for i in range(nmrk):
ixi = int(i / n1)
irho = i % n1
Expand Down
6 changes: 3 additions & 3 deletions a5py/ascotpy/libascot.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,8 @@ def input_rhotheta2rz(self, rho, theta, phi, time, maxiter=100, tol=1e-5):
self._requireinit("bfield")
rho = np.asarray(rho).ravel().astype(dtype="f8")
Neval = rho.size
r = np.NaN * np.zeros((Neval,), dtype="f8") * unyt.m
z = np.NaN * np.zeros((Neval,), dtype="f8") * unyt.m
r = np.nan * np.zeros((Neval,), dtype="f8") * unyt.m
z = np.nan * np.zeros((Neval,), dtype="f8") * unyt.m

if theta.size == 1:
theta = theta * np.ones(rho.shape).astype(dtype="f8")
Expand Down Expand Up @@ -857,7 +857,7 @@ def input_findpsi0(self, psi1):
psi0 = self._eval_bfield(
ax["axisr"], 0.0*unyt.rad, ax["axisz"], 0.0*unyt.s, evalrho=True)

psi = np.NaN * np.zeros((1,), dtype="f8") * unyt.Wb
psi = np.nan * np.zeros((1,), dtype="f8") * unyt.Wb
rz = np.zeros((2,), dtype="f8") * unyt.m
rz[0] = ax["axisr"]
rz[1] = ax["axisz"]
Expand Down
2 changes: 2 additions & 0 deletions a5py/routines/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,8 @@ def hist2d(x, y, xbins=None, ybins=None, weights=None, xlog="linear",
norm = None
if logscale: norm = mpl.colors.LogNorm()

if weights is not None:
weights = weights.v # Cannot have units in weights yet in 2D histogram
h,_,_,m = axes.hist2d(x, y, bins=[xbins, ybins], weights=weights, norm=norm)

cbar = plt.colorbar(m, ax=axes, cax=cax)
Expand Down
8 changes: 3 additions & 5 deletions a5py/routines/runmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def getstate(self, *qnt, mode="gc", state="ini", ids=None, endcond=None):

# Go through each unique end cond and mark that end cond valid or
# not. This can then be used to make udix as boolean mask array.
uecs, uidx = np.unique(self._endstate.get("endcond"),
uecs, uidx = np.unique(self._endstate.get("endcond")[0],
return_inverse=True)
mask = np.zeros(uecs.shape, dtype=bool)
for i, uec in enumerate(uecs):
Expand All @@ -161,7 +161,7 @@ def getstate(self, *qnt, mode="gc", state="ini", ids=None, endcond=None):
if ids is not None:
idx = np.logical_and(idx, np.in1d(self._inistate.get("ids"), ids))

for i in range(len(data)):
for i in range(len(qnt)):
data[i] = data[i][idx]
if "mu" in qnt:
data[qnt.index("mu")].convert_to_units("eV/T")
Expand Down Expand Up @@ -1109,11 +1109,10 @@ def parsearg(arg, mode, endcond):
# Sort data so that when the stacked histogram is plotted, the stack
# with most markers is at the bottom.
idx = np.argsort([len(i) for i in xcs])[::-1]
xcs = [xcs[i].v for i in idx]
xcs = [xcs[i] for i in idx]
ecs = [endconds[i][1] + " : %.2e" % endconds[i][0] for i in idx]
weights = [weights[i] for i in idx]
if not weight: weights = None

a5plt.hist1d(x=xcs, xbins=xbins, weights=weights, xlog=xlog,
logscale=logscale, xlabel=x, axes=axes, legend=ecs)

Expand All @@ -1124,7 +1123,6 @@ def parsearg(arg, mode, endcond):
weights = self.getstate("weight", state="ini", endcond=endcond,
ids=ids)
if not weight: weights = None

a5plt.hist2d(xc, yc, xbins=xbins, ybins=ybins, weights=weights,
xlog=xlog, ylog=ylog, logscale=logscale, xlabel=x,
ylabel=y, axesequal=axesequal, axes=axes, cax=cax)
Expand Down
24 changes: 12 additions & 12 deletions a5py/testascot/unittests.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,10 @@ def test_resultnode(self):
qid2 = a5.data.efield[grp2].get_qid()
with h5py.File(self.testfilename, "a") as h5:
group = fileapi.add_group(h5, "results", "run", desc="RUN1")
group.attrs["qid_efield"] = np.string_(qid1)
group.attrs["qid_efield"] = np.bytes_(qid1)
time.sleep(1.0)
group = fileapi.add_group(h5, "results", "run", desc="RUN2")
group.attrs["qid_efield"] = np.string_(qid2)
group.attrs["qid_efield"] = np.bytes_(qid2)

a5 = Ascot(self.testfilename)
run1 = a5.data["RUN1"].get_qid()
Expand Down Expand Up @@ -276,7 +276,7 @@ def test_resultnode(self):
# Create a new run and test the remove all results method
with h5py.File(self.testfilename, "a") as h5:
group = fileapi.add_group(h5, "results", "run", desc="RUN1")
group.attrs["qid_efield"] = np.string_(qid1)
group.attrs["qid_efield"] = np.bytes_(qid1)

a5.data.destroy(repack=False)
self.assertFalse("results" in a5.data,
Expand Down Expand Up @@ -822,19 +822,19 @@ def test_moments(self):
mrkdist = {}
mrkdist["density"] = weight * dt
mrkdist["chargedensity"] = weight * charge * dt
mrkdist["energydensity"] = weight * ( energy * dt ).to("J")
mrkdist["pressure"] = weight * (mass * vnorm**2 * dt).to("J")/3
mrkdist["toroidalcurrent"] = weight * ( charge * vphi * dt ).to("A*m")
mrkdist["parallelcurrent"] = weight * ( charge * vpar * dt ).to("A*m")
mrkdist["powerdep"] = weight * dEtot_d.to("W")
mrkdist["electronpowerdep"] = weight * dEele_d.to("W")
mrkdist["ionpowerdep"] = weight * dEion_d.to("W")
mrkdist["energydensity"] = weight * (energy * dt).to("J*s")
mrkdist["pressure"] = weight * (mass * vnorm**2*dt).to("J*s")/3
mrkdist["toroidalcurrent"] = weight * ( charge * vphi * dt ).to("A*s*m")
mrkdist["parallelcurrent"] = weight * ( charge * vpar * dt ).to("A*s*m")
mrkdist["powerdep"] = weight * dEtot_d.to("J")
mrkdist["electronpowerdep"] = weight * dEele_d.to("J")
mrkdist["ionpowerdep"] = weight * dEion_d.to("J")
mrkdist["jxbtorque"] = weight * (-charge * dpsi/unyt.s).to("N*m")
mrkdist["colltorque"] = weight * (r*dppar*(bphi/bnorm)*dt).to("J")
mrkdist["colltorque"] = weight * (r*dppar*(bphi/bnorm)*dt).to("J*s")
mrkdist["canmomtorque"] = weight * -charge * dPphi

print(weight[0]*tf)
print(((ef-ei)*weight[0]/unyt.s).to("W"))
print(((ef-ei)*weight[0]).to("W"))
for o in ordinates:
a1 = np.sum(rhomom.ordinate(o) * rhomom.volume)
a2 = np.sum(mom.ordinate(o) * mom.volume)
Expand Down
2 changes: 1 addition & 1 deletion environment-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies:
- h5py
- mpi4py
- xmlschema
- unyt==2.9.5
- unyt
- wurlitzer
- matplotlib
- pyvista
Expand Down
2 changes: 1 addition & 1 deletion environment-mpi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- h5py=*=*openmpi*
- mpi4py
- xmlschema
- unyt==2.9.5
- unyt
- wurlitzer
- matplotlib
- pyvista
Expand Down
2 changes: 1 addition & 1 deletion environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies:
- scipy
- h5py
- xmlschema
- unyt==2.9.5
- unyt
- wurlitzer
- matplotlib
- pyvista
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"numpy",
"scipy",
"h5py",
"unyt==2.9.5",
"unyt",
"wurlitzer",
"xmlschema",
]
Expand Down

0 comments on commit f180148

Please sign in to comment.