Skip to content

Commit

Permalink
Merge pull request #213 from hiddenSymmetries/fw/smalltweaks
Browse files Browse the repository at this point in the history
Small performance tweaks
  • Loading branch information
mbkumar authored Apr 21, 2022
2 parents 4f63543 + e51d8d6 commit ec4162b
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 33 deletions.
32 changes: 21 additions & 11 deletions src/simsopt/_core/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def __missing__(self, key):

def copy_numpy_dict(d):
res = OptimizableDefaultDict({})
for k in d:
res[k] = d[k].copy()
for k, v in d.items():
res[k] = v.copy()
return res


Expand Down Expand Up @@ -114,31 +114,41 @@ def __add__(self, other):
y = other.data
z = copy_numpy_dict(x)
for k in y:
z[k] += y[k]

if k in z:
z[k] += y[k]
else:
z[k] = y[k].copy()
return Derivative(z)

def __sub__(self, other):
x = self.data
y = other.data
z = copy_numpy_dict(x)
for k in y:
z[k] -= y[k]

for k, yk in y.items():
if k in z:
z[k] -= yk
else:
z[k] = -yk
return Derivative(z)

def __iadd__(self, other):
x = self.data
y = other.data
for k in y:
x[k] += y[k]
for k, yk in y.items():
if k in x:
x[k] += yk
else:
x[k] = yk.copy()
return self

def __isub__(self, other):
x = self.data
y = other.data
for k in y:
x[k] -= y[k]
for k, yk in y.items():
if k in x:
x[k] -= yk
else:
x[k] = -yk
return self

def __mul__(self, other):
Expand Down
5 changes: 3 additions & 2 deletions src/simsopt/_core/graph_optimizable.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@ def __init__(self,
# instances of same class
self._id = ImmutableId(next(self.__class__._ids))
self.name = self.__class__.__name__ + str(self._id.id)
hash_str = hashlib.sha256(self.name.encode('utf-8')).hexdigest()
self.hash = int(hash_str, 16) % 10**32 # 32 digit int as hash
self._children = set() # This gets populated when the object is passed
# as argument to another Optimizable object
self.return_fns = WeakKeyDefaultDict(list) # Store return fn's required by each child
Expand Down Expand Up @@ -583,8 +585,7 @@ def __str__(self):
return self.name

def __hash__(self) -> int:
hash_str = hashlib.sha256(self.name.encode('utf-8')).hexdigest()
return int(hash_str, 16) % 10**32 # 32 digit int as hash
return self.hash

def __eq__(self, other: Optimizable) -> bool:
"""
Expand Down
8 changes: 6 additions & 2 deletions src/simsopt/geo/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,8 +662,12 @@ def gamma_impl(self, gamma, quadpoints):
"""

self.curve.gamma_impl(gamma, quadpoints)
gamma[:] = gamma @ self.rotmat
if len(quadpoints) == len(self.curve.quadpoints) \
and np.sum((quadpoints-self.curve.quadpoints)**2) < 1e-15:
gamma[:] = self.curve.gamma() @ self.rotmat
else:
self.curve.gamma_impl(gamma, quadpoints)
gamma[:] = gamma @ self.rotmat

def gammadash_impl(self, gammadash):
r"""
Expand Down
14 changes: 4 additions & 10 deletions src/simsopt/objectives/fluxobjective.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from simsopt._core.graph_optimizable import Optimizable
from .._core.derivative import derivative_dec
import numpy as np
import simsoptpp as sopp


class SquaredFlux(Optimizable):
Expand Down Expand Up @@ -32,17 +33,10 @@ def __init__(self, surface, field, target=None):
Optimizable.__init__(self, x0=np.asarray([]), depends_on=[field])

def J(self):
xyz = self.surface.gamma()
n = self.surface.normal()
absn = np.linalg.norm(n, axis=2)
unitn = n * (1./absn)[:, :, None]
Bcoil = self.field.B().reshape(xyz.shape)
Bcoil_n = np.sum(Bcoil*unitn, axis=2)
if self.target is not None:
B_n = (Bcoil_n - self.target)
else:
B_n = Bcoil_n
return 0.5 * np.mean(B_n**2 * absn)
Bcoil = self.field.B().reshape(n.shape)
Btarget = self.target if self.target is not None else []
return sopp.integral_BdotN(Bcoil, Btarget, n)

@derivative_dec
def dJ(self):
Expand Down
23 changes: 23 additions & 0 deletions src/simsoptpp/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ PYBIND11_MODULE(simsoptpp, m) {
return C;
});

m.def("integral_BdotN", [](PyArray& Bcoil, PyArray& Btarget, PyArray& n) {
int nphi = Bcoil.shape(0);
int ntheta = Bcoil.shape(1);
double *Bcoil_ptr = &(Bcoil(0, 0, 0));
double *Btarget_ptr = NULL;
if(Btarget.size() == Bcoil.size())
Btarget_ptr = &(Btarget(0, 0, 0));
double *n_ptr = &(n(0, 0, 0));
double res = 0;
#pragma omp parallel for reduction(+:res)
for(int i=0; i<nphi*ntheta; i++){
double normN = std::sqrt(n_ptr[3*i+0]*n_ptr[3*i+0] + n_ptr[3*i+1]*n_ptr[3*i+1] + n_ptr[3*i+2]*n_ptr[3*i+2]);
double Nx = n_ptr[3*i+0]/normN;
double Ny = n_ptr[3*i+1]/normN;
double Nz = n_ptr[3*i+2]/normN;
double BcoildotN = Bcoil_ptr[3*i+0]*Nx + Bcoil_ptr[3*i+1]*Ny + Bcoil_ptr[3*i+2]*Nz;
if(Btarget_ptr != NULL)
BcoildotN -= Btarget_ptr[3*i];
res += (BcoildotN * BcoildotN) * normN;
}
return 0.5 * res / (nphi*ntheta);
});

#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
Expand Down
42 changes: 34 additions & 8 deletions tests/core/test_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,23 +273,49 @@ def test_sub_mul(self):
assert np.allclose(dj1m2(opt1), -1*dj1(opt1))
assert np.allclose(dj1m2(opt2), -dj2(opt2))

def test_iadd_isub_imul(self):
def test_iadd(self):
opt1 = Opt(n=3)
opt2 = Opt(n=2)

dj1 = opt1.dfoo_vjp(np.ones(3))
dj1_ = opt1.dfoo_vjp(np.ones(3))
dj2 = opt2.dfoo_vjp(np.ones(2))
dj2_ = opt2.dfoo_vjp(np.ones(2))

dj1 += dj1_
dj1 += dj2
assert np.allclose(dj1(opt2), dj2(opt2))
dj1 += dj1
assert np.allclose(dj1(opt1), 2*dj1_(opt1))
dj1 -= 3*dj2
assert np.allclose(dj1(opt2), -1*dj2(opt2))
dj1 *= 1.5
assert np.allclose(dj1(opt2), -1.5*dj2(opt2))
assert np.allclose(dj1(opt1), 3*dj1_(opt1))
assert np.allclose(dj1(opt2), dj2_(opt2))

def test_isub(self):
opt1 = Opt(n=3)
opt2 = Opt(n=2)

dj1 = opt1.dfoo_vjp(np.ones(3))
dj1_ = opt1.dfoo_vjp(np.ones(3))
dj2 = opt2.dfoo_vjp(np.ones(2))
dj2_ = opt2.dfoo_vjp(np.ones(2))

dj1 -= 2*dj1_
dj1 -= dj2
assert np.allclose(dj1(opt1), (-1)*dj1_(opt1))
assert np.allclose(dj1(opt2), -dj2_(opt2))

def test_imul(self):
opt1 = Opt(n=3)
opt2 = Opt(n=2)

dj1 = opt1.dfoo_vjp(np.ones(3))
dj2 = opt2.dfoo_vjp(np.ones(2))

dj1_ = opt1.dfoo_vjp(np.ones(3))
dj2_ = opt2.dfoo_vjp(np.ones(2))

dj1 *= 2.
assert np.allclose(dj1(opt1), 2*dj1_(opt1))
dj = dj1 + 4*dj2
assert np.allclose(dj(opt1), 2*dj1_(opt1))
assert np.allclose(dj(opt2), 4*dj2_(opt2))

def test_zero_when_not_found(self):
opt1 = Opt(n=3)
Expand Down
15 changes: 15 additions & 0 deletions tests/geo/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,21 @@ def test_plot(self):
ax = curve.plot(engine=engine, ax=ax, show=False, close=close)
c.plot(engine=engine, ax=ax, close=close, plot_derivative=True, show=show)

def test_rotated_curve_gamma_impl(self):
rc = get_curve("CurveXYZFourier", True, x=100)
c = rc.curve
mat = rc.rotmat

rcg = rc.gamma()
cg = c.gamma()
quadpoints = rc.quadpoints

assert np.allclose(rcg, cg@mat)
# run gamma_impl so that the `else` in RotatedCurve.gamma_impl gets triggered
tmp = np.zeros_like(cg[:10, :])
rc.gamma_impl(tmp, quadpoints[:10])
assert np.allclose(cg[:10, :]@mat, tmp)


if __name__ == "__main__":
unittest.main()

0 comments on commit ec4162b

Please sign in to comment.