-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_time_plots.py
131 lines (104 loc) · 3.64 KB
/
create_time_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import sys
import os
from zernipax import set_device
set_device("gpu")
import numpy as np
import matplotlib.pyplot as plt
import timeit
from tqdm import tqdm
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("./ZERN/"))
from zern.zern_core import Zernike
import numpy as np
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("./zernike/"))
from zernike import RZern
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("./zernpy/src/"))
from zernpy import ZernPol
from zernipax.basis import ZernikePolynomial
from zernipax.zernike import *
from zernipax.plotting import plot_comparison
from zernipax.backend import jax
def fun_zernipax_cpu(ns, ms, r):
    """Evaluate zernipax's ``zernike_radial`` pinned to the first CPU device.

    The module calls ``set_device("gpu")`` at import time, so this wrapper
    temporarily overrides JAX's default device to obtain a CPU timing.

    Parameters
    ----------
    ns, ms : array-like
        Radial orders and azimuthal indices of the modes to evaluate.
    r : array-like
        Radial sample points, passed through unchanged.

    Returns
    -------
    The JAX array returned by ``zernike_radial(r, ns, ms, 0)``. The trailing
    ``0`` is presumably the derivative order; confirm against zernipax.
    """
    cpu_device = jax.devices("cpu")[0]
    with jax.default_device(cpu_device):
        return zernike_radial(r, ns, ms, 0)
def fun_zernipax(ns, ms, r):
    """Evaluate radial Zernike polynomials with ``zernike_radial_old_desc``.

    This runs on JAX's default device. The module calls ``set_device("gpu")``
    at import time, so that is presumably the GPU; these timings are plotted
    as "ZERNIPAX GPU".

    NOTE(review): the CPU counterpart ``fun_zernipax_cpu`` calls
    ``zernike_radial`` with ``r`` unreshaped. The GPU and CPU curves therefore
    compare different functions as well as different devices; confirm this
    is intended.

    Returns the (JAX) result of ``zernike_radial_old_desc``.
    """
    # Add a trailing axis: the 1-D ``r`` used by this script becomes a column.
    r_column = r[:, np.newaxis]
    return zernike_radial_old_desc(r_column, ns, ms, 0)
def fun_zern(ns, ms, r):
    """Evaluate radial Zernike polynomials mode by mode with the ZERN package.

    Parameters
    ----------
    ns, ms : array-like
        Radial orders and azimuthal indices, one pair per mode. They are
        expected to have equal length (both come from the same basis here).
    r : array-like
        Radial sample points forwarded to ``Zernike.R_nm_Jacobi``.

    Returns
    -------
    numpy.ndarray
        One row per mode, each row being ``R_nm_Jacobi(n, m, r)``.
    """
    # Built on every call, so its setup cost is part of what timeit measures.
    zern = Zernike(0)
    # ZERN expects plain Python ints for the mode indices.
    return np.array(
        [zern.R_nm_Jacobi(int(n), int(m), r) for n, m in zip(ns, ms)]
    )
def get_Noll(n, m):
    """Return the 1-based Noll index ``j`` of the Zernike mode ``(n, m)``.

    Noll's ordering sorts modes by radial order ``n``. Within an order, a
    mode with ``m > 0`` receives an even ``j`` and a mode with ``m < 0`` an
    odd ``j``. Examples: (0, 0) -> 1, (1, 1) -> 2, (1, -1) -> 3, (2, 0) -> 4.

    The base value ``n(n+1)/2 + |m|`` gets a +1 bump whenever the parity
    rule requires it. That rule depends on the sign of ``m`` and on ``n``
    modulo 4.
    """
    residue = n % 4
    needs_bump = (m >= 0 and residue in (2, 3)) or (m <= 0 and residue in (0, 1))
    return n * (n + 1) // 2 + abs(m) + (1 if needs_bump else 0)
def fun_zernike(ns, ms, r):
    """Evaluate radial Zernike polynomials mode by mode with the ``zernike`` package.

    Parameters
    ----------
    ns, ms : array-like
        Radial orders and azimuthal indices, one pair per mode. They are
        expected to have equal length (both come from the same basis here).
        ``ns`` must be non-empty, because ``max(ns)`` sizes the ``RZern``
        object.
    r : array-like
        Radial sample points forwarded to ``RZern.Rnm``.

    Returns
    -------
    numpy.ndarray
        One row per mode, each row being ``cart.Rnm(k, r)``. Here ``k`` is
        the mode's Noll index minus one, which presumably matches ``RZern``'s
        0-based indexing; verify against the ``zernike`` package.
    """
    # Built on every call, so its setup cost is part of what timeit measures.
    cart = RZern(int(max(ns)))
    # Cast to int (as fun_zern does) so the computed index is a plain integer.
    return np.array(
        [cart.Rnm(get_Noll(int(n), int(m)) - 1, r) for n, m in zip(ns, ms)]
    )
# Timing: for each resolution, build the ANSI-indexed cosine Zernike basis and
# time every implementation on the same (ns, ms, r) inputs.
res_radial = 100  # number of radial sample points in [0, 1]
r = np.linspace(0, 1, res_radial)
times = []
num_exec = 100  # timeit repetitions per implementation and resolution
range_res = np.arange(10, 101, 2)
for res in tqdm(range_res):
    basis = ZernikePolynomial(L=res, M=res, spectral_indexing="ansi", sym="cos")
    ns = basis.modes[:, 0]
    ms = basis.modes[:, 1]
    # Warm-up calls compile the JAX functions outside the timed region. JAX
    # dispatches asynchronously, so wait for the results. Otherwise leftover
    # warm-up work can still be running when timing starts and gets billed
    # to the first timed call.
    fun_zernipax(ns, ms, r).block_until_ready()
    fun_zernipax_cpu(ns, ms, r).block_until_ready()
    t_zern = timeit.timeit(lambda: fun_zern(ns, ms, r), number=num_exec)
    t_zernike = timeit.timeit(lambda: fun_zernike(ns, ms, r), number=num_exec)
    # Block inside the timed lambdas so the measurement covers the actual work,
    # not just the asynchronous dispatch.
    t_gpu = timeit.timeit(
        lambda: fun_zernipax(ns, ms, r).block_until_ready(), number=num_exec
    )
    t_cpu = timeit.timeit(
        lambda: fun_zernipax_cpu(ns, ms, r).block_until_ready(), number=num_exec
    )
    # Column order: ZERN, ZERNIKE, ZERNIPAX GPU, ZERNIPAX CPU.
    times.append([t_zern, t_zernike, t_gpu, t_cpu])
# timeit returns total seconds for num_exec runs; convert to ms per call.
times = np.array(times) * 1000 / num_exec
results = np.vstack((range_res, times)).T
# The header is written as a '#' comment line, which np.loadtxt skips.
np.savetxt(
    f"results_cpu_gpu_times_r{res_radial}.txt",
    results,
    header="resolution zern_ms zernike_ms zernipax_gpu_ms zernipax_cpu_ms",
)
# Plot the per-call timings three ways: linear axes, log-scaled y axis, and a
# zoom on the lowest resolutions.
_SERIES_LABELS = ("ZERN", "ZERNIKE", "ZERNIPAX GPU", "ZERNIPAX CPU")  # column order of `times`
_N_LOW_RES = 10  # number of leading resolutions shown in the zoomed plot


def _plot_times(xs, ys, filename, *, log_y=False, xticks=None):
    """Draw one line per implementation against resolution and save to *filename*.

    ``ys`` holds one column per entry of ``_SERIES_LABELS``. The figure is
    closed after saving so the large dpi=1000 canvases do not accumulate.
    """
    fig = plt.figure()
    draw = plt.semilogy if log_y else plt.plot
    for col, label in enumerate(_SERIES_LABELS):
        draw(xs, ys[:, col], label=label)
    plt.xlabel("Resolution")
    plt.ylabel("Time (ms)")
    if xticks is not None:
        plt.xticks(xticks)
    plt.title("Execution Times of Radial Zernike Polynomials")
    plt.grid()
    plt.legend()
    plt.savefig(filename, dpi=1000)
    plt.close(fig)


_plot_times(range_res, times, "all_t_compare.png")
_plot_times(range_res, times, "all_t_compare_log.png", log_y=True)
_plot_times(
    range_res[:_N_LOW_RES],
    times[:_N_LOW_RES],
    "all_t_compare_low_res.png",
    xticks=range_res[:_N_LOW_RES],
)