diff --git a/examples/matrix_exponentials/.gitignore b/examples/matrix_exponentials/.gitignore new file mode 100644 index 0000000..f2a723b --- /dev/null +++ b/examples/matrix_exponentials/.gitignore @@ -0,0 +1 @@ +*.pkl \ No newline at end of file diff --git a/examples/matrix_exponentials/matrix_generation.py b/examples/matrix_exponentials/matrix_generation.py new file mode 100644 index 0000000..05d1ce8 --- /dev/null +++ b/examples/matrix_exponentials/matrix_generation.py @@ -0,0 +1,13 @@ +from jax import random, numpy as jnp +from scipy.stats import ortho_group + + +def wishart(d: int, key: random.PRNGKey) -> jnp.ndarray: + n = 2 * d # degrees of freedom + G = random.normal(key, shape=(d, n)) + A_wishart = (G @ G.T) / n + return A_wishart + + +def orthogonal(d: int, _) -> jnp.ndarray: + return ortho_group.rvs(d) diff --git a/examples/matrix_exponentials/orthogonal_abs.pdf b/examples/matrix_exponentials/orthogonal_abs.pdf new file mode 100644 index 0000000..6640e02 Binary files /dev/null and b/examples/matrix_exponentials/orthogonal_abs.pdf differ diff --git a/examples/matrix_exponentials/orthogonal_rel.pdf b/examples/matrix_exponentials/orthogonal_rel.pdf new file mode 100644 index 0000000..729286d Binary files /dev/null and b/examples/matrix_exponentials/orthogonal_rel.pdf differ diff --git a/examples/matrix_exponentials/plot.py b/examples/matrix_exponentials/plot.py new file mode 100644 index 0000000..2f7b94a --- /dev/null +++ b/examples/matrix_exponentials/plot.py @@ -0,0 +1,138 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt +import argparse + + +parser = argparse.ArgumentParser() +parser.add_argument("--save_dir", type=str) +args = parser.parse_args() + + +matrix_type = args.save_dir.split("/")[-1].split("_")[1].split(".")[0] + + +# use latex for plots +plt.rc("text", usetex=True) +# set font +plt.rc("font", family="serif") +# set font size +plt.rcParams.update({"font.size": 10}) + +colors = [ + plt.cm.viridis(0.2), + plt.cm.viridis(0.4), + plt.cm.viridis(0.6), + plt.cm.viridis(0.8), +] + + +results = pickle.load(open(args.save_dir, "rb")) +dt = results["dt"] +NT = results["err_abs"].shape[-1] +D = results["D"] +e0_abs = 8.0 if matrix_type == "wishart" else 19.0 +ylabel_abs = ( + r"$|| \bar{C} - \exp(-A)||_F$" + if matrix_type == "wishart" + else r"$|| \bar{C} - \exp(-M)||_F$" +) +e0_rel = 0.9 +ylabel_rel = ( + r"$\frac{|| \bar{C} - \exp(-A)||_F}{||\exp(-A)||_F}$" + if matrix_type == "wishart" + else r"$\frac{|| \bar{C} - \exp(-M)||_F}{||\exp(-M)||_F}$" +) +fig_label = "(A)" if matrix_type == "wishart" else "(B)" + + +def plot(err, ylabel, e0, save_path, d=False, d_squared=False, fig_label=None): + T = np.arange(NT) * dt + err_mean = err.mean(axis=0) + + # find time where error crosses threshold + TC = np.zeros(len(D)) + for i in range(len(D)): + TC[i] = np.min(T[10:][err_mean[i, 10:] < e0]) + + plt.figure(figsize=(7, 4.5)) + + if fig_label is not None: + plt.gcf().text(0.02, 0.93, fig_label, fontsize=22) + + for i in range(len(D)): + plt.plot(T, err_mean[i], color=colors[i]) + + # Add error bars + for i in range(len(D)): + plt.fill_between( + T, + err_mean[i] - err[:, i].std(axis=0), + err_mean[i] + err[:, i].std(axis=0), + color=colors[i], + alpha=0.3, + zorder=0, + ) + + plt.loglog() + plt.legend(["d = " + str(D[i]) for i in range(len(D))], loc="upper right") + plt.xlabel(r"Time ($\mu$s)", fontsize=18) + plt.ylabel(ylabel, fontsize=18) + + # show threshold as horizontal line + plt.axhline(e0, color="k", linestyle="--") + # show crossing times as vertical lines + for i in range(len(D)): + plt.axvline(TC[i], color=colors[i], linestyle="--") + + plt.xlim(30, T[-1]) + + # inset plot showing crossing time as a function of dimension + ax = plt.axes([0.17, 0.22, 0.3, 0.35]) + ax.tick_params(axis="y", direction="in", pad=-22) + ax.tick_params(axis="x", direction="in", pad=-15) + + for i in range(len(D)): + ax.scatter(D[i], TC[i], color=colors[i], zorder=10) + + ts = np.array([10, 2000]) + + if d: + plt.plot(ts, 100 * ts, color="black", linestyle="--") + plt.text(600, 8e4, s=r"$t_C = d$", rotation=25) + + if d_squared: + plt.plot(ts, 0.3 * ts**2, color="black", linestyle="--") + plt.text(550, 1.7e5, s=r"$t_C = d^2$", rotation=25) + + plt.plot(D, TC, color="black", zorder=0) + plt.xlim(20, 1500) + + plt.loglog() + plt.xlabel(r"$d$", fontsize=15) + plt.ylabel(r"$t_C$", fontsize=15) + plt.minorticks_off() + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.show() + + +plot( + results["err_abs"], + ylabel_abs, + e0_abs, + f"examples/matrix_exponentials/{matrix_type}_abs.pdf", + d_squared=True, + fig_label=fig_label, +) + + +plot( + results["err_rel"], + ylabel_rel, + e0_rel, + f"examples/matrix_exponentials/{matrix_type}_rel.pdf", + d=True, + fig_label=fig_label, +) diff --git a/examples/matrix_exponentials/run.py b/examples/matrix_exponentials/run.py new file mode 100644 index 0000000..88ad020 --- /dev/null +++ b/examples/matrix_exponentials/run.py @@ -0,0 +1,114 @@ +from jax import random, jit, config, numpy as jnp +from jax.scipy.linalg import expm +from jax.lax import scan +import numpy as np +import argparse +from tqdm import tqdm +import thermox +import pickle + +from examples.matrix_exponentials import matrix_generation + +# Set the precision of the computation +config.update("jax_enable_x64", True) + +# Set seed for orthogonal matrix generation +np.random.seed(42) + +# Load n_repeats, matrix_type and alpha from the command line +parser = argparse.ArgumentParser() +parser.add_argument("--n_repeats", type=int, default=1) +parser.add_argument("--matrix_type", type=str, default="wishart") +parser.add_argument("--alpha", type=float, default=0.0) +args = parser.parse_args() +get_matrix = getattr(matrix_generation, args.matrix_type) +alpha = args.alpha + +# Jit for speed (avoid recompilation) +sample = jit(thermox.sample) + +# Hyperparameters shared across all experiments +NT = 10000 +dt = 12 +ts = jnp.arange(NT) * dt +N_burn = 0 +keys = random.split(random.PRNGKey(42), args.n_repeats) +gamma = 1 +beta = 1 +D = [64, 128, 256, 512] + + +# Function to compute array of autocovariance errors from samples +@jit +def samps_to_autocovs_errs(samps, true_exp): + def body_func(prev_mat, n): + new_mat = prev_mat * n / (n + 1) + jnp.outer(samps[n], samps[n - 1]) / (n + 1) + err = jnp.linalg.norm(new_mat * jnp.exp(alpha) - true_exp) + return new_mat, err + + return scan( + body_func, + jnp.zeros((samps.shape[1], samps.shape[1])), + jnp.arange(1, samps.shape[0]), + )[1] + + +# Initialize arrays to store errors +err_abs = np.zeros((args.n_repeats, len(D), NT)) +err_rel = np.zeros_like(err_abs) + +# Loop over repeats and dimensions +for repeat in tqdm(range(args.n_repeats)): + key = keys[repeat] + for i in range(len(D)): + d = D[i] + print(f"Repeat {repeat}/{args.n_repeats}, \t D = {d}") + + A = get_matrix(d, key) + exact_exp_min_A = expm(-A) + + # Shift and scale A and compute symmetrized B + A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt + B = A_shifted + A_shifted.T + + # Print eigenvalues + A_shifted_lambda_min = jnp.min(jnp.linalg.eig(A_shifted / gamma)[0].real) + print("A Eig min: ", A_shifted_lambda_min) + + D_lambda_min = jnp.min(jnp.linalg.eig(B / (gamma * beta))[0].real) + print("D Eig min: ", D_lambda_min) + + # Initialize at zeros + x0 = np.zeros(d) + + # Run the sampler + X = sample( + key, + ts, + x0, + A_shifted / gamma, + np.zeros(d), + B / (gamma * beta), + ) + + # Compute absolute error + err_abs = samps_to_autocovs_errs(X, exact_exp_min_A) + err_abs[repeat, i, 1:] = err_abs + + # Compute relative error + err_rel[repeat, i, 1:] = err_abs / jnp.linalg.norm(exact_exp_min_A) + + # Save results (overwrites after each repeat) + with open( + f"examples/matrix_exponentials/results_{args.matrix_type}.pkl", "wb" + ) as f: + pickle.dump( + { + "D": D, + "dt": dt, + "alpha": alpha, + "err_abs": err_abs, + "err_rel": err_rel, + }, + f, + ) diff --git a/examples/matrix_exponentials/wishart_abs.pdf b/examples/matrix_exponentials/wishart_abs.pdf new file mode 100644 index 0000000..4390e06 Binary files /dev/null and b/examples/matrix_exponentials/wishart_abs.pdf differ diff --git a/examples/matrix_exponentials/wishart_rel.pdf b/examples/matrix_exponentials/wishart_rel.pdf new file mode 100644 index 0000000..f661fac Binary files /dev/null and b/examples/matrix_exponentials/wishart_rel.pdf differ diff --git a/thermox/utils.py b/thermox/utils.py index ffb4c08..c08e8a6 100644 --- a/thermox/utils.py +++ b/thermox/utils.py @@ -28,10 +28,6 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix: """ A_eigvals, A_eigvecs = eig(A + 0.0j) - - A_eigvals = A_eigvals.real - A_eigvecs = A_eigvecs.real - A_eigvecs_inv = jnp.linalg.inv(A_eigvecs) symA = 0.5 * (A + A.T) @@ -39,7 +35,7 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix: return ProcessedDriftMatrix( A, - A_eigvals.real, + A_eigvals, A_eigvecs, A_eigvecs_inv, symA_eigvals,