This repository has been archived by the owner on May 25, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathminimise_nll_2d.py
62 lines (46 loc) · 2.11 KB
/
minimise_nll_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
MINIMISATION USING THE 2D QUASI-NEWTON ALGORITHM
01/12/18
@author: SOPHIE MARTIN
"""
import define_functions as f
import matplotlib.pyplot as plt
import numpy as np
import minimiser
from matplotlib.pyplot import cm
from mpl_toolkits.mplot3d import Axes3D
def main():
    """Fit tau and alpha by minimising the 2D NLL, report errors and plot.

    50 tau values and 50 alpha values are handed to
    ``minimiser.find_initial_vectorx`` to obtain a starting point, which is
    then passed to the quasi-Newton minimiser. The square roots of the two
    diagonal entries of the matrix returned by
    ``minimiser.find_covariance_error`` are printed as the errors, and the
    scanned NLL values are drawn as a 3D surface with the located minimum
    marked on it.

    Returns:
        The ``(minimum, min_list, iterations)`` triple exactly as returned
        by ``minimiser.minimise_quasi_newton``.
    """
    decay = f.DecayFunction()

    # Search window for the scan. A minimum is assumed to lie inside it, and
    # only a local minimum within this range will be found.
    tau_grid = np.linspace(0.2, 0.5, 50)
    # alpha cannot be > 1, hence the upper bound of the window.
    alpha_grid = np.linspace(0.9, 1, 50)

    # The scan yields the initial guess and the NLL values used for the plot.
    tau_start, alpha_start, nll_scan = minimiser.find_initial_vectorx(
        tau_grid, alpha_grid, decay.get_2d_nll_values)
    tau_mesh, alpha_mesh = np.meshgrid(tau_grid, alpha_grid)
    nll_surface = nll_scan.reshape(tau_mesh.shape)

    # Quasi-Newton minimisation started from the scanned initial guess.
    best, history, n_iter = minimiser.minimise_quasi_newton(
        tau_start, alpha_start, decay.find_2d_nll_value,
        1e-5, 1e-5, maxiter=500)

    # Errors are taken from the diagonal of the returned error matrix.
    covariance = minimiser.find_covariance_error(
        best[0], best[1], decay.find_2d_nll_value, 1e-6)
    tau_error = np.sqrt(covariance[0, 0])
    alpha_error = np.sqrt(covariance[1, 1])
    print('Errors obtained: ', tau_error, alpha_error)

    # NLL surface over the (tau, alpha) window.
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.plot_surface(tau_mesh, alpha_mesh, nll_surface, cmap=cm.coolwarm)
    axes.set_xlabel('tau', fontsize=15)
    axes.set_ylabel('alpha', fontsize=15)
    axes.set_zlabel('NLL value')
    axes.set_title('Plotting the NLL over different alpha and tau')

    # Mark the minimum found by the minimiser on top of the surface.
    best_tau = best[0, 0]
    best_alpha = best[1, 0]
    marker_style = dict(
        markerfacecolor='yellow',
        markeredgecolor='yellow',
        marker='o',
        markersize=3,
        alpha=1,
    )
    axes.plot(
        [best_tau],
        [best_alpha],
        [decay.find_2d_nll_value(best_tau, best_alpha)],
        **marker_style
    )
    plt.show()

    return best, history, n_iter
# Script entry point: run the fit only when executed directly, and bind the
# results to module-level names (presumably so they can be inspected in an
# interactive session after the run -- confirm with the author's workflow).
if __name__ == "__main__":
    minimum, min_list, iterations = main()