Skip to content

Commit

Permalink
extended ama facility
Browse files Browse the repository at this point in the history
  • Loading branch information
lehner committed Jan 3, 2025
1 parent daf7397 commit 55677fc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 46 deletions.
62 changes: 35 additions & 27 deletions lib/gpt/qcd/sparse_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, array, cache_size, cache_line_size):
self.source_domain = flav0.source_domain.copy()
self.sink_domain = flav0.sink_domain.copy()

for fac, flav in array[1:]:
for tag, flav in array[1:]:
self.sink_domain.restrict(flav.sink_domain)
self.source_domain.restrict(flav.source_domain)

Expand Down Expand Up @@ -134,23 +134,21 @@ def get_propagator_full(self, i):
keys = [f"{self.coordinates[il].tolist()}" for il in ilist]
paths = [f"/{key}/propagator" for key in keys]
g.default.push_verbose("io", False)
data = {flav: g.load(flav.filename, paths=paths) for fac, flav in self.array}
data = {flav: g.load(flav.filename, paths=paths) for tag, flav in self.array}
g.default.pop_verbose()

# process cache
prp = []
for il, key in zip(ilist, keys):
prp_il = None
for fac, flav in self.array:
prop = g.convert(data[flav][1 + il][key]["propagator"], g.double)
prp_il = {}
for tag, flav in self.array:
prp_il[tag] = g.convert(data[flav][1 + il][key]["propagator"], g.double)
if self.sink_domain.sdomain is not flav.sink_domain.sdomain:
prop = g(
self.sink_domain.sdomain.project * flav.sink_domain.sdomain.promote * prop
prp_il[tag] = g(
self.sink_domain.sdomain.project
* flav.sink_domain.sdomain.promote
* prp_il[tag]
)
if prp_il is None:
prp_il = g(fac * prop)
else:
prp_il += fac * prop
prp.append(prp_il)

self.cache.append((cache_line_idx, prp))
Expand All @@ -159,12 +157,13 @@ def get_propagator_full(self, i):

def __getitem__(self, args):

if isinstance(args, tuple):
i, without = args
assert isinstance(args, tuple)
if len(args) == 3:
tag, i, without = args
else:
i, without = args, None
tag, i, without = *args, None

prp = self.get_propagator_full(i)
prp = self.get_propagator_full(i)[tag]

# sparsen sink if requested
if without is not None:
Expand All @@ -174,9 +173,9 @@ def __getitem__(self, args):

return prp

def __call__(self, i_sink, i_src):
def __call__(self, tag, i_sink, i_src):
# should also work if I give a list of i_sink
prp = self.get_propagator_full(i_src)
prp = self.get_propagator_full(i_src)[tag]
return prp[self.ec[i_sink]]

def source_mask(self):
Expand Down Expand Up @@ -266,25 +265,37 @@ def flavor(roots, *cache_param):
if ttag in prop:
return prop[ttag]

weights = [(1.0, [])]
ret = []
for tag, prec in tag_prec:
low_file = f"{tag}/full/low"
sloppy_file = f"{tag}/full/sloppy"
exact_file = f"{tag}/full/exact"

if len(prec) == 1:
for w in weights:
w[1].append(prec)
elif len(prec) == 3:
weights_a = []
weights_b = []
for w in weights:
weights_a.append((w[0], w[1] + [prec[0]]))
weights_b.append((-w[0], w[1] + [prec[2]]))
weights = weights_a + weights_b

if prec == "s":
prp = flavor_multi([(1.0, get_quark(sloppy_file))], *cache_param)
prp = flavor_multi([("s", get_quark(sloppy_file))], *cache_param)
elif prec == "e":
prp = flavor_multi([(1.0, get_quark(exact_file))], *cache_param)
prp = flavor_multi([("e", get_quark(exact_file))], *cache_param)
elif prec == "l":
prp = flavor_multi([(1.0, get_quark(low_file))], *cache_param)
prp = flavor_multi([("l", get_quark(low_file))], *cache_param)
elif prec == "ems":
prp = flavor_multi(
[(1.0, get_quark(exact_file)), (-1.0, get_quark(sloppy_file))], *cache_param
[("e", get_quark(exact_file)), ("s", get_quark(sloppy_file))], *cache_param
)
elif prec == "sml":
prp = flavor_multi(
[(1.0, get_quark(sloppy_file)), (-1.0, get_quark(low_file))], *cache_param
[("s", get_quark(sloppy_file)), ("l", get_quark(low_file))], *cache_param
)
else:
raise Exception(f"Unknown precision: {prec}")
Expand All @@ -307,8 +318,5 @@ def flavor(roots, *cache_param):
flav.sink_domain = common_sink_domain
flav.sink_domain_update(max_ec_size)

if len(ret) == 1:
ret = ret[0]

prop[ttag] = ret
return ret
prop[ttag] = [weights] + ret
return prop[ttag]
66 changes: 47 additions & 19 deletions tests/qcd/sparse_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,16 @@
"""
)
quark = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.{tag}", *cache_params)
weight, quark = g.qcd.sparse_propagator.flavor(
f"{work_dir}/30_combined/light.{tag}", *cache_params
)
if len(tag) == 1:
assert len(weight) == 1
assert weight[0] == (1.0, [tag])
elif len(tag) == 3:
assert len(weight) == 2
assert weight[0] == (1.0, [tag[0]])
assert weight[1] == (-1.0, [tag[2]])

coordinates = quark.sink_domain.coordinates
nsrc = quark.source_domain.sampled_sites
Expand All @@ -47,10 +56,11 @@
rng = g.random("test")
to_sample = [[rng.uniform_int(min=0, max=nsrc - 1) for i in range(3)] for j in range(20)]
sampled = []
tag0 = weight[0][1][0]
for s in g.qcd.sparse_propagator.cache_optimized_sampler([quark], to_sample):
q0 = quark[s[0]]
q1 = quark[s[1]]
q2 = quark[s[2]]
q0 = quark[tag0, s[0]]
q1 = quark[tag0, s[1]]
q2 = quark[tag0, s[2]]
sampled.append(s)

# check that sampled is only a permutation of to_sample
Expand All @@ -66,10 +76,10 @@

g.message(f"Self-consistency check for src={i_src}")
# get src-to-src propagator as numpy matrices
src_src = quark(np.arange(nsrc), i_src)
src_src = quark(tag0, np.arange(nsrc), i_src)

# get embedding in full lattice to test
full = quark.sink_domain.sdomain.promote(quark[i_src])
full = quark.sink_domain.sdomain.promote(quark[tag0, i_src])

# test that src_src is consistent with quark[i_src] which embeds all into the sparse domain
for i_snk in range(nsrc):
Expand All @@ -79,7 +89,7 @@
# now remove elements
mask = quark.source_mask()
remove = [x for x in [1, 8, 11] if x < nsrc]
sparsened = g(mask * quark[i_src, remove])
sparsened = g(mask * quark[tag0, i_src, remove])

slice = quark.sink_domain.sdomain.slice(sparsened, 3)
for i_snk in range(nsrc):
Expand All @@ -92,9 +102,9 @@


# extra tests between tags
quark_sml = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.sml", *cache_params)
quark_s = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.s", *cache_params)
quark_l = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.l", *cache_params)
_, quark_sml = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.sml", *cache_params)
_, quark_s = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.s", *cache_params)
_, quark_l = g.qcd.sparse_propagator.flavor(f"{work_dir}/30_combined/light.l", *cache_params)
qs_ec = quark_s.sink_domain.sdomain.unique_embedded_coordinates(quark_l.sink_domain.coordinates)
ql_ec = quark_l.sink_domain.sdomain.unique_embedded_coordinates(quark_l.sink_domain.coordinates)
qsml_ec = quark_sml.sink_domain.sdomain.unique_embedded_coordinates(quark_l.sink_domain.coordinates)
Expand All @@ -104,26 +114,44 @@
assert len(quark_sml.sink_domain.coordinates) == quark_sml.sink_domain.sampled_sites

for i in range(quark_sml.source_domain.sampled_sites):
qs_i = quark_s[i]
ql_i = quark_l[i]
qsml_i = quark_sml[i]
qs_i = quark_s["s", i]
ql_i = quark_l["l", i]
qsml_s_i = quark_sml["s", i]
qsml_l_i = quark_sml["l", i]
for j in range(quark_l.sink_domain.sampled_sites):
eps = qs_i[qs_ec[j]] - ql_i[ql_ec[j]]
# for test data same solver is used for sloppy and low
eps = qs_i[qs_ec[j]] - qsml_s_i[qsml_ec[j]]
assert g.norm2(eps) == 0.0

eps = ql_i[ql_ec[j]] - qsml_l_i[qsml_ec[j]]
assert g.norm2(eps) == 0.0
assert g.norm2(qsml_i[qsml_ec[j]]) == 0.0

qsml_i = g(qsml_s_i - qsml_l_i)
x = quark_sml.sink_domain.sdomain.slice(qsml_i, 3)
for y in x:
assert g.norm2(y) < 1e-13

# test weights for sml-sml
weights, quark_sml_a, quark_sml_b = g.qcd.sparse_propagator.flavor(
[f"{work_dir}/30_combined/light.sml", f"{work_dir}/30_combined/light.sml"], *cache_params
)
assert weights == [(1.0, ["s", "s"]), (-1.0, ["l", "s"]), (-1.0, ["s", "l"]), (1.0, ["l", "l"])]
for i in range(quark_sml_a.source_domain.sampled_sites):
eps = g.norm2(quark_sml_a["s", i] - quark_sml_b["l", i])
assert eps < 1e-13
eps = g.norm2(quark_sml_a["s", i] - quark_sml_a["l", i])
assert eps < 1e-13
eps = g.norm2(quark_sml_a["l", i] - quark_sml_a["l", i])
assert eps < 1e-13

# now test sink-conformal mapping;
# sloppy has more sink sites than low,
# by mapping them jointly we restrict sloppy
# to low's sink domain automatically such that
# they can be used in a conformal manner
quark_s, quark_l = g.qcd.sparse_propagator.flavor(
weights, quark_s, quark_l = g.qcd.sparse_propagator.flavor(
[f"{work_dir}/30_combined/light.s", f"{work_dir}/30_combined/light.l"], *cache_params
)

assert np.array_equal(
quark_s.sink_domain.sdomain.kernel.local_coordinates,
quark_l.sink_domain.sdomain.kernel.local_coordinates,
Expand All @@ -133,6 +161,6 @@
assert len(quark_s.ec) == len(quark_l.ec)

for i in range(quark_sml.source_domain.sampled_sites):
qs_i = quark_s[i]
ql_i = quark_l[i]
qs_i = quark_s["s", i]
ql_i = quark_l["l", i]
assert g.norm2(qs_i - ql_i) < 1e-13

0 comments on commit 55677fc

Please sign in to comment.