Skip to content

Commit

Permalink
Add reproducible code for matrix exponential simulations (#37)
Browse files Browse the repository at this point in the history
* Add matrix exponential simulations

* Fix bug enforcing real eigen values

* Add label to plots

* Minor plot change

* Remove import

* Add figures

* Add comments to run

* Better casing
  • Loading branch information
SamDuffield authored Jul 22, 2024
1 parent eb6dde1 commit ccd6b73
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 5 deletions.
1 change: 1 addition & 0 deletions examples/matrix_exponentials/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pkl
13 changes: 13 additions & 0 deletions examples/matrix_exponentials/matrix_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from jax import random, numpy as jnp
from scipy.stats import ortho_group


def wishart(d: int, key: random.PRNGKey) -> jnp.ndarray:
n = 2 * d # degrees of freedom
G = random.normal(key, shape=(d, n))
A_wishart = (G @ G.T) / n
return A_wishart


def orthogonal(d: int, _) -> jnp.ndarray:
return ortho_group.rvs(d)
Binary file added examples/matrix_exponentials/orthogonal_abs.pdf
Binary file not shown.
Binary file added examples/matrix_exponentials/orthogonal_rel.pdf
Binary file not shown.
138 changes: 138 additions & 0 deletions examples/matrix_exponentials/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pickle
import numpy as np
import matplotlib.pyplot as plt
import argparse


parser = argparse.ArgumentParser()
parser.add_argument("--save_dir", type=str)
args = parser.parse_args()


matrix_type = args.save_dir.split("/")[-1].split("_")[1].split(".")[0]


# use latex for plots
plt.rc("text", usetex=True)
# set font
plt.rc("font", family="serif")
# set font size
plt.rcParams.update({"font.size": 10})

colors = [
plt.cm.viridis(0.2),
plt.cm.viridis(0.4),
plt.cm.viridis(0.6),
plt.cm.viridis(0.8),
]


results = pickle.load(open(args.save_dir, "rb"))
dt = results["dt"]
NT = results["err_abs"].shape[-1]
D = results["D"]
e0_abs = 8.0 if matrix_type == "wishart" else 19.0
ylabel_abs = (
r"$|| \bar{C} - \exp(-A)||_F$"
if matrix_type == "wishart"
else r"$|| \bar{C} - \exp(-M)||_F$"
)
e0_rel = 0.9
ylabel_rel = (
r"$\frac{|| \bar{C} - \exp(-A)||_F}{||\exp(-A)||_F}$"
if matrix_type == "wishart"
else r"$\frac{|| \bar{C} - \exp(-M)||_F}{||\exp(-M)||_F}$"
)
fig_label = "(A)" if matrix_type == "wishart" else "(B)"


def plot(err, ylabel, e0, save_path, d=False, d_squared=False, fig_label=None):
T = np.arange(NT) * dt
err_mean = err.mean(axis=0)

# find time where error crosses threshold
TC = np.zeros(len(D))
for i in range(len(D)):
TC[i] = np.min(T[10:][err_mean[i, 10:] < e0])

plt.figure(figsize=(7, 4.5))

if fig_label is not None:
plt.gcf().text(0.02, 0.93, fig_label, fontsize=22)

for i in range(len(D)):
plt.plot(T, err_mean[i], color=colors[i])

# Add error bars
for i in range(len(D)):
plt.fill_between(
T,
err_mean[i] - err[:, i].std(axis=0),
err_mean[i] + err[:, i].std(axis=0),
color=colors[i],
alpha=0.3,
zorder=0,
)

plt.loglog()
plt.legend(["d = " + str(D[i]) for i in range(len(D))], loc="upper right")
plt.xlabel(r"Time ($\mu$s)", fontsize=18)
plt.ylabel(ylabel, fontsize=18)

# show threshold as horizontal line
plt.axhline(e0, color="k", linestyle="--")
# show crossing times as vertical lines
for i in range(len(D)):
plt.axvline(TC[i], color=colors[i], linestyle="--")

plt.xlim(30, T[-1])

# inset plot showing crossing time as a function of dimension
ax = plt.axes([0.17, 0.22, 0.3, 0.35])
ax.tick_params(axis="y", direction="in", pad=-22)
ax.tick_params(axis="x", direction="in", pad=-15)

for i in range(len(D)):
ax.scatter(D[i], TC[i], color=colors[i], zorder=10)

ts = np.array([10, 2000])

if d:
plt.plot(ts, 100 * ts, color="black", linestyle="--")
plt.text(600, 8e4, s=r"$t_C = d$", rotation=25)

if d_squared:
plt.plot(ts, 0.3 * ts**2, color="black", linestyle="--")
plt.text(550, 1.7e5, s=r"$t_C = d^2$", rotation=25)

plt.plot(D, TC, color="black", zorder=0)
plt.xlim(20, 1500)

plt.loglog()
plt.xlabel(r"$d$", fontsize=15)
plt.ylabel(r"$t_C$", fontsize=15)
plt.minorticks_off()

plt.tight_layout()
plt.savefig(save_path, dpi=300)
plt.show()


plot(
results["err_abs"],
ylabel_abs,
e0_abs,
f"examples/matrix_exponentials/{matrix_type}_abs.pdf",
d_squared=True,
fig_label=fig_label,
)


plot(
results["err_rel"],
ylabel_rel,
e0_rel,
f"examples/matrix_exponentials/{matrix_type}_rel.pdf",
d=True,
fig_label=fig_label,
)
114 changes: 114 additions & 0 deletions examples/matrix_exponentials/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from jax import random, jit, config, numpy as jnp
from jax.scipy.linalg import expm
from jax.lax import scan
import numpy as np
import argparse
from tqdm import tqdm
import thermox
import pickle

from examples.matrix_exponentials import matrix_generation

# Set the precision of the computation
config.update("jax_enable_x64", True)

# Set seed for orthogonal matrix generation
np.random.seed(42)

# Load n_repeats, matrix_type and alpha from the command line
parser = argparse.ArgumentParser()
parser.add_argument("--n_repeats", type=int, default=1)
parser.add_argument("--matrix_type", type=str, default="wishart")
parser.add_argument("--alpha", type=float, default=0.0)
args = parser.parse_args()
get_matrix = getattr(matrix_generation, args.matrix_type)
alpha = args.alpha

# Jit for speed (avoid recompilation)
sample = jit(thermox.sample)

# Hyperparameters shared across all experiments
NT = 10000
dt = 12
ts = jnp.arange(NT) * dt
N_burn = 0
keys = random.split(random.PRNGKey(42), args.n_repeats)
gamma = 1
beta = 1
D = [64, 128, 256, 512]


# Function to compute array of autocovariance errors from samples
@jit
def samps_to_autocovs_errs(samps, true_exp):
def body_func(prev_mat, n):
new_mat = prev_mat * n / (n + 1) + jnp.outer(samps[n], samps[n - 1]) / (n + 1)
err = jnp.linalg.norm(new_mat * jnp.exp(alpha) - true_exp)
return new_mat, err

return scan(
body_func,
jnp.zeros((samps.shape[1], samps.shape[1])),
jnp.arange(1, samps.shape[0]),
)[1]


# Initialize arrays to store errors
err_abs = np.zeros((args.n_repeats, len(D), NT))
err_rel = np.zeros_like(err_abs)

# Loop over repeats and dimensions
for repeat in tqdm(range(args.n_repeats)):
key = keys[repeat]
for i in range(len(D)):
d = D[i]
print(f"Repeat {repeat}/{args.n_repeats}, \t D = {d}")

A = get_matrix(d, key)
exact_exp_min_A = expm(-A)

# Shift and scale A and compute symmetrized B
A_shifted = (A + alpha * jnp.eye(A.shape[0])) / dt
B = A_shifted + A_shifted.T

# Print eigenvalues
A_shifted_lambda_min = jnp.min(jnp.linalg.eig(A_shifted / gamma)[0].real)
print("A Eig min: ", A_shifted_lambda_min)

D_lambda_min = jnp.min(jnp.linalg.eig(B / (gamma * beta))[0].real)
print("D Eig min: ", D_lambda_min)

# Initialize at zeros
x0 = np.zeros(d)

# Run the sampler
X = sample(
key,
ts,
x0,
A_shifted / gamma,
np.zeros(d),
B / (gamma * beta),
)

# Compute absolute error
err_abs = samps_to_autocovs_errs(X, exact_exp_min_A)
err_abs[repeat, i, 1:] = err_abs

# Compute relative error
err_rel[repeat, i, 1:] = err_abs / jnp.linalg.norm(exact_exp_min_A)

# Save results (overwrites after each repeat)
with open(
f"examples/matrix_exponentials/results_{args.matrix_type}.pkl", "wb"
) as f:
pickle.dump(
{
"D": D,
"dt": dt,
"alpha": alpha,
"err_abs": err_abs,
"err_rel": err_rel,
},
f,
)
Binary file added examples/matrix_exponentials/wishart_abs.pdf
Binary file not shown.
Binary file added examples/matrix_exponentials/wishart_rel.pdf
Binary file not shown.
6 changes: 1 addition & 5 deletions thermox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,14 @@ def preprocess_drift_matrix(A: Array) -> ProcessedDriftMatrix:
"""

A_eigvals, A_eigvecs = eig(A + 0.0j)

A_eigvals = A_eigvals.real
A_eigvecs = A_eigvecs.real

A_eigvecs_inv = jnp.linalg.inv(A_eigvecs)

symA = 0.5 * (A + A.T)
symA_eigvals, symA_eigvecs = jnp.linalg.eigh(symA)

return ProcessedDriftMatrix(
A,
A_eigvals.real,
A_eigvals,
A_eigvecs,
A_eigvecs_inv,
symA_eigvals,
Expand Down

0 comments on commit ccd6b73

Please sign in to comment.