Skip to content

Commit

Permalink
Merge pull request #241 from nonhermitian/update-benchmark
Browse files Browse the repository at this point in the history
Allow changing benchmark distance
  • Loading branch information
nonhermitian authored Oct 8, 2024
2 parents 57b636a + 49dfa10 commit bc2b7a6
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 5 deletions.
12 changes: 9 additions & 3 deletions benchmarks/iterative.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
import json
import time
import argparse

import mthree

import logging
logging.basicConfig(level=logging.INFO)

def main():
def main(distance=3):
with open("data/eagle_large_counts.json") as json_file:
counts = json.load(json_file)

mit = mthree.M3Mitigation()
mit.cals_from_file("data/eagle_large_cals.json")

st = time.perf_counter()
quasi = mit.apply_correction(counts, range(127), distance=3)
_ = mit.apply_correction(counts, range(127), distance=distance)
fin = time.perf_counter()
print(fin - st)


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument('--distance', nargs='?', const=3, type=int)
args = parser.parse_args()
main(args.distance)
19 changes: 19 additions & 0 deletions mthree/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# that they have been altered from the originals.
# pylint: disable=no-name-in-module, invalid-name
"""Iterative solver routines"""
import logging
import time
import numpy as np
import scipy.sparse.linalg as spla

Expand All @@ -19,6 +21,8 @@
from mthree.utils import counts_to_vector, vector_to_quasiprobs
from mthree.exceptions import M3Error

logger = logging.getLogger(__name__)


def iterative_solver(
mitigator,
Expand Down Expand Up @@ -50,14 +54,20 @@ def iterative_solver(
M3Error: Solver did not converge.
"""
cals = mitigator._form_cals(qubits)
st = time.perf_counter()
M = M3MatVec(dict(counts), cals, distance)
fin = time.perf_counter()
logger.info(f"MatVec build time is {fin-st}")
L = spla.LinearOperator(
(M.num_elems, M.num_elems),
matvec=M.matvec,
rmatvec=M.rmatvec,
dtype=np.float32,
)
st = time.perf_counter()
diags = M.get_diagonal()
fin = time.perf_counter()
logger.info(f"Diagonal build time: {fin-st}")

def precond_matvec(x):
out = x / diags
Expand All @@ -66,8 +76,12 @@ def precond_matvec(x):
P = spla.LinearOperator(
(M.num_elems, M.num_elems), precond_matvec, dtype=np.float32
)
st = time.perf_counter()
vec = counts_to_vector(M.sorted_counts)
fin = time.perf_counter()
logger.info(f"Counts to vector time: {fin-st}")

st = time.perf_counter()
out, error = spla.gmres(
L,
vec,
Expand All @@ -78,14 +92,19 @@ def precond_matvec(x):
callback=callback,
callback_type="legacy",
)
fin = time.perf_counter()
logger.info(f"Iterative solver time: {fin-st}")
if error:
raise M3Error("GMRES did not converge: {}".format(error))

gamma = None
if return_mitigation_overhead:
gamma = ainv_onenorm_est_iter(M, tol=tol, max_iter=max_iter)

st = time.perf_counter()
quasi = vector_to_quasiprobs(out, M.sorted_counts)
fin = time.perf_counter()
logger.info(f"Vector to quasi time: {fin-st}")
if details:
return quasi, M.get_col_norms(), gamma
return quasi, gamma
5 changes: 5 additions & 0 deletions mthree/matvec.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# cython: c_string_type=unicode, c_string_encoding=UTF-8
import logging
cimport cython
from cython.parallel cimport prange
import numpy as np
Expand All @@ -21,6 +22,7 @@ from libcpp.map cimport map
from libcpp.string cimport string
from cython.operator cimport dereference, postincrement


cdef extern from "src/distance.h" nogil:
unsigned int hamming_terms(unsigned int num_bits,
unsigned int distance,
Expand Down Expand Up @@ -66,6 +68,7 @@ cdef extern from "src/matvec.h" nogil:
int num_terms,
bool MAX_DIST)

logger = logging.getLogger(__name__)

cdef class M3MatVec():
cdef unsigned char * bitstrings
Expand Down Expand Up @@ -96,6 +99,8 @@ cdef class M3MatVec():
if not self.MAX_DIST:
self.num_terms = <int>hamming_terms(self.num_bits, self.distance, self.num_elems)

logger.info(f"Number of Hamming terms: {self.num_terms}")

self.bitstrings = <unsigned char *>malloc(self.num_bits*self.num_elems*sizeof(unsigned char))
self.col_norms = <float *>malloc(self.num_elems*sizeof(float))

Expand Down
3 changes: 2 additions & 1 deletion mthree/mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,10 @@ def callback(_):
tol,
max_iter,
0,
None,
callback,
return_mitigation_overhead,
)
logger.info(f"Number of GMRES iterations: {iter_count[0]}")
mit_counts.shots = shots
if gamma is not None:
mit_counts.mitigation_overhead = gamma * gamma
Expand Down
2 changes: 1 addition & 1 deletion mthree/src/matvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void matvec(const float * __restrict x,
terms += 1;
if (terms == num_terms)
{
break;
break; /* Break out of col for-loop*/
}
}
}
Expand Down

0 comments on commit bc2b7a6

Please sign in to comment.