-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add reproducible code for matrix exponential simulations (#37)
* Add matrix exponential simulations * Fix bug enforcing real eigen values * Add label to plots * Minor plot change * Remove import * Add figures * Add comments to run * Better casing
- Loading branch information
1 parent
eb6dde1
commit ccd6b73
Showing
9 changed files
with
267 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.pkl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from jax import random, numpy as jnp | ||
from scipy.stats import ortho_group | ||
|
||
|
||
def wishart(d: int, key: random.PRNGKey) -> jnp.ndarray: | ||
n = 2 * d # degrees of freedom | ||
G = random.normal(key, shape=(d, n)) | ||
A_wishart = (G @ G.T) / n | ||
return A_wishart | ||
|
||
|
||
def orthogonal(d: int, _) -> jnp.ndarray: | ||
return ortho_group.rvs(d) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import pickle | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import argparse | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--save_dir", type=str) | ||
args = parser.parse_args() | ||
|
||
|
||
matrix_type = args.save_dir.split("/")[-1].split("_")[1].split(".")[0] | ||
|
||
|
||
# use latex for plots | ||
plt.rc("text", usetex=True) | ||
# set font | ||
plt.rc("font", family="serif") | ||
# set font size | ||
plt.rcParams.update({"font.size": 10}) | ||
|
||
colors = [ | ||
plt.cm.viridis(0.2), | ||
plt.cm.viridis(0.4), | ||
plt.cm.viridis(0.6), | ||
plt.cm.viridis(0.8), | ||
] | ||
|
||
|
||
results = pickle.load(open(args.save_dir, "rb")) | ||
dt = results["dt"] | ||
NT = results["err_abs"].shape[-1] | ||
D = results["D"] | ||
e0_abs = 8.0 if matrix_type == "wishart" else 19.0 | ||
ylabel_abs = ( | ||
r"$|| \bar{C} - \exp(-A)||_F$" | ||
if matrix_type == "wishart" | ||
else r"$|| \bar{C} - \exp(-M)||_F$" | ||
) | ||
e0_rel = 0.9 | ||
ylabel_rel = ( | ||
r"$\frac{|| \bar{C} - \exp(-A)||_F}{||\exp(-A)||_F}$" | ||
if matrix_type == "wishart" | ||
else r"$\frac{|| \bar{C} - \exp(-M)||_F}{||\exp(-M)||_F}$" | ||
) | ||
fig_label = "(A)" if matrix_type == "wishart" else "(B)" | ||
|
||
|
||
def plot(err, ylabel, e0, save_path, d=False, d_squared=False, fig_label=None): | ||
T = np.arange(NT) * dt | ||
err_mean = err.mean(axis=0) | ||
|
||
# find time where error crosses threshold | ||
TC = np.zeros(len(D)) | ||
for i in range(len(D)): | ||
TC[i] = np.min(T[10:][err_mean[i, 10:] < e0]) | ||
|
||
plt.figure(figsize=(7, 4.5)) | ||
|
||
if fig_label is not None: | ||
plt.gcf().text(0.02, 0.93, fig_label, fontsize=22) | ||
|
||
for i in range(len(D)): | ||
plt.plot(T, err_mean[i], color=colors[i]) | ||
|
||
# Add error bars | ||
for i in range(len(D)): | ||
plt.fill_between( | ||
T, | ||
err_mean[i] - err[:, i].std(axis=0), | ||
err_mean[i] + err[:, i].std(axis=0), | ||
color=colors[i], | ||
alpha=0.3, | ||
zorder=0, | ||
) | ||
|
||
plt.loglog() | ||
plt.legend(["d = " + str(D[i]) for i in range(len(D))], loc="upper right") | ||
plt.xlabel(r"Time ($\mu$s)", fontsize=18) | ||
plt.ylabel(ylabel, fontsize=18) | ||
|
||
# show threshold as horizontal line | ||
plt.axhline(e0, color="k", linestyle="--") | ||
# show crossing times as vertical lines | ||
for i in range(len(D)): | ||
plt.axvline(TC[i], color=colors[i], linestyle="--") | ||
|
||
plt.xlim(30, T[-1]) | ||
|
||
# inset plot showing crossing time as a function of dimension | ||
ax = plt.axes([0.17, 0.22, 0.3, 0.35]) | ||
ax.tick_params(axis="y", direction="in", pad=-22) | ||
ax.tick_params(axis="x", direction="in", pad=-15) | ||
|
||
for i in range(len(D)): | ||
ax.scatter(D[i], TC[i], color=colors[i], zorder=10) | ||
|
||
ts = np.array([10, 2000]) | ||
|
||
if d: | ||
plt.plot(ts, 100 * ts, color="black", linestyle="--") | ||
plt.text(600, 8e4, s=r"$t_C = d$", rotation=25) | ||
|
||
if d_squared: | ||
plt.plot(ts, 0.3 * ts**2, color="black", linestyle="--") | ||
plt.text(550, 1.7e5, s=r"$t_C = d^2$", rotation=25) | ||
|
||
plt.plot(D, TC, color="black", zorder=0) | ||
plt.xlim(20, 1500) | ||
|
||
plt.loglog() | ||
plt.xlabel(r"$d$", fontsize=15) | ||
plt.ylabel(r"$t_C$", fontsize=15) | ||
plt.minorticks_off() | ||
|
||
plt.tight_layout() | ||
plt.savefig(save_path, dpi=300) | ||
plt.show() | ||
|
||
|
||
plot( | ||
results["err_abs"], | ||
ylabel_abs, | ||
e0_abs, | ||
f"examples/matrix_exponentials/{matrix_type}_abs.pdf", | ||
d_squared=True, | ||
fig_label=fig_label, | ||
) | ||
|
||
|
||
plot( | ||
results["err_rel"], | ||
ylabel_rel, | ||
e0_rel, | ||
f"examples/matrix_exponentials/{matrix_type}_rel.pdf", | ||
d=True, | ||
fig_label=fig_label, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from jax import random, jit, config, numpy as jnp | ||
from jax.scipy.linalg import expm | ||
from jax.lax import scan | ||
import numpy as np | ||
import argparse | ||
from tqdm import tqdm | ||
import thermox | ||
import pickle | ||
|
||
from examples.matrix_exponentials import matrix_generation | ||
|
||
# Set the precision of the computation | ||
config.update("jax_enable_x64", True) | ||
|
||
# Set seed for orthogonal matrix generation | ||
np.random.seed(42) | ||
|
||
# Load n_repeats, matrix_type and alpha from the command line | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--n_repeats", type=int, default=1) | ||
parser.add_argument("--matrix_type", type=str, default="wishart") | ||
parser.add_argument("--alpha", type=float, default=0.0) | ||
args = parser.parse_args() | ||
get_matrix = getattr(matrix_generation, args.matrix_type) | ||
alpha = args.alpha | ||
|
||
# Jit for speed (avoid recompilation) | ||
sample = jit(thermox.sample) | ||
|
||
# Hyperparameters shared across all experiments | ||
NT = 10000 | ||
dt = 12 | ||
ts = jnp.arange(NT) * dt | ||
N_burn = 0 | ||
keys = random.split(random.PRNGKey(42), args.n_repeats) | ||
gamma = 1 | ||
beta = 1 | ||
D = [64, 128, 256, 512] | ||
|
||
|
||
# Function to compute array of autocovariance errors from samples | ||
@jit | ||
def samps_to_autocovs_errs(samps, true_exp): | ||
def body_func(prev_mat, n): | ||
new_mat = prev_mat * n / (n + 1) + jnp.outer(samps[n], samps[n - 1]) / (n + 1) | ||
err = jnp.linalg.norm(new_mat * jnp.exp(alpha) - true_exp) | ||
return new_mat, err | ||
|
||
return scan( | ||
body_func, | ||
jnp.zeros((samps.shape[1], samps.shape[1])), | ||
jnp.arange(1, samps.shape[0]), | ||
)[1] | ||
|
||
|
||
# Initialize arrays to store errors | ||
err_abs = np.zeros((args.n_repeats, len(D), NT)) | ||
err_rel = np.zeros_like(err_abs) | ||
|
||
# Loop over repeats and dimensions | ||
for repeat in tqdm(range(args.n_repeats)): | ||
key = keys[repeat] | ||
for i in range(len(D)): | ||
d = D[i] | ||
print(f"Repeat {repeat}/{args.n_repeats}, \t D = {d}") | ||
|
||
A = get_matrix(d, key) | ||
exact_exp_min_A = expm(-A) | ||
|
||
# Shift and scale A and compute symmetrized B | ||
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt | ||
B = A_shifted + A_shifted.T | ||
|
||
# Print eigenvalues | ||
A_shifted_lambda_min = jnp.min(jnp.linalg.eig(A_shifted / gamma)[0].real) | ||
print("A Eig min: ", A_shifted_lambda_min) | ||
|
||
D_lambda_min = jnp.min(jnp.linalg.eig(B / (gamma * beta))[0].real) | ||
print("D Eig min: ", D_lambda_min) | ||
|
||
# Initialize at zeros | ||
x0 = np.zeros(d) | ||
|
||
# Run the sampler | ||
X = sample( | ||
key, | ||
ts, | ||
x0, | ||
A_shifted / gamma, | ||
np.zeros(d), | ||
B / (gamma * beta), | ||
) | ||
|
||
# Compute absolute error | ||
err_abs = samps_to_autocovs_errs(X, exact_exp_min_A) | ||
err_abs[repeat, i, 1:] = err_abs | ||
|
||
# Compute relative error | ||
err_rel[repeat, i, 1:] = err_abs / jnp.linalg.norm(exact_exp_min_A) | ||
|
||
# Save results (overwrites after each repeat) | ||
with open( | ||
f"examples/matrix_exponentials/results_{args.matrix_type}.pkl", "wb" | ||
) as f: | ||
pickle.dump( | ||
{ | ||
"D": D, | ||
"dt": dt, | ||
"alpha": alpha, | ||
"err_abs": err_abs, | ||
"err_rel": err_rel, | ||
}, | ||
f, | ||
) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters