bfgs.py
"""
This module provides an implementation of the BFGS optimization algorithm, specifically
tailored for our use case. It is a modified version of the implementation found in
Joint_Supervised_Learning_for_SR/src/architectures/bfgs.py
"""
import time
import numpy as np
import sympy as sp
import torch
from scipy.optimize import minimize
import re
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Custom module dictionary mapping SymPy function names to NumPy callables.
MODULES = {
    "sin": np.sin,
    "cos": np.cos,
    "tan": np.tan,
    "asin": np.arcsin,
    "acos": np.arccos,
    "atan": np.arctan,
    "sinh": np.sinh,
    "cosh": np.cosh,
    "tanh": np.tanh,
    # Define 'coth' via np.cosh and np.sinh, as NumPy has no direct coth function.
    "coth": lambda x: np.cosh(x) / np.sinh(x),
    "sqrt": np.sqrt,
    "log": np.log,
    "exp": np.exp,
    "Abs": np.abs,
    "numpy": np,  # Include numpy itself for other functions and operations
}
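
# Illustrative sketch (not from the original project): the MODULES dictionary
# lets lambdify evaluate functions with no direct NumPy counterpart, e.g. coth:
#   f = sp.lambdify("x_1", sp.sympify("coth(x_1)"), modules=MODULES)
#   f(np.array([1.0, 2.0]))  # ~ array([1.3130, 1.0373])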

class TimedFun:
    """Wrap a callable so that it raises ValueError once `stop_after` seconds
    have elapsed, allowing long-running scipy optimizations to be aborted."""

    def __init__(self, fun, stop_after=10):
        self.fun_in = fun
        self.started = False
        self.stop_after = stop_after

    def fun(self, x, *args):
        # Record the start time on the first call; time out on later calls.
        if self.started is False:
            self.started = time.time()
        elif abs(time.time() - self.started) >= self.stop_after:
            raise ValueError("Time is over.")
        self.fun_value = self.fun_in(*x, *args)
        self.x = x
        return self.fun_value
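
# Illustrative usage (hypothetical `loss_expr`): TimedFun lets a scipy minimize
# run be aborted by raising ValueError once `stop_after` seconds have elapsed:
#   timed = TimedFun(fun=sp.lambdify([sp.Symbol("c0")], loss_expr, modules=["numpy"]), stop_after=5)
#   try:
#       minimize(timed.fun, np.zeros(1), method="BFGS")
#   except ValueError:
#       pass  # timed out; the last evaluated point is available as timed.x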

def bfgs(pred_str, X, y):
    idx_remove = True
    total_variables = ["x_1", "x_2"]

    # Check which dimensions are unused and replace them with 1 to avoid numerical
    # issues with BFGS (i.e. absent variables placed in the denominator).
    y = y.squeeze()
    X = X.clone()
    bool_dim = (X == 0).all(axis=1).squeeze()
    X[:, :, bool_dim] = 1

    # Replace every bare constant token "c" with a uniquely numbered c0, c1, ...
    candidate = re.sub(r"\bc\b", "constant", pred_str)
    expr = candidate
    for i in range(candidate.count("constant")):
        expr = expr.replace("constant", f"c{i}", 1)

    # print('Constructing BFGS loss...')
    # if cfg.bfgs.idx_remove:
    if idx_remove:
        # print('Flag idx remove ON, removing indices with high values...')
        bool_con = (X < 200).all(axis=2).squeeze()
        X = X[:, bool_con, :]
    max_y = np.max(np.abs(y.cpu().numpy()))
    # print('checking input values range...')
    # if max_y > 300:
    #     print('Attention, input values are very large. Optimization may fail due to numerical issues')

    # Build the symbolic loss: mean squared residual over all sampled points.
    diffs = []
    for i in range(X.shape[1]):
        curr_expr = expr
        # for idx, j in enumerate(cfg.total_variables):
        for idx, j in enumerate(total_variables):
            curr_expr = sp.sympify(curr_expr).subs(j, X[:, i, idx])
        diff = curr_expr - y[i]
        diffs.append(diff)
    loss = np.mean(np.square(diffs))

    # Lists to which the result of every restart is appended.
    F_loss = []
    consts_ = []
    funcs = []
    symbols = {i: sp.Symbol(f"c{i}") for i in range(candidate.count("constant"))}

    # Ensure 20 valid tries (at most 50 attempts overall).
    num = 0
    tryout = 0
    while num < 20 and tryout < 50:
        # Draw a fresh random initialization for the coefficients.
        np.random.seed(tryout)
        x0 = np.random.randn(len(symbols))
        s = list(symbols.values())
        # BFGS optimization
        fun_timed = TimedFun(fun=sp.lambdify(s, loss, modules=["numpy"]))
        if len(x0):
            try:
                minimize(
                    fun_timed.fun, x0, method="BFGS"
                )  # TODO: check the constants' interval and whether they are integers
            except Exception as e:
                print(f"Encountered in bfgs: {e}")
                tryout += 1
                continue
            consts_.append(fun_timed.x)
        else:
            consts_.append([])

        # Substitute the fitted constants back into the expression.
        final = expr
        for i in range(len(s)):
            final = sp.sympify(final).replace(s[i], fun_timed.x[i])
        funcs.append(final)

        # Evaluate the fitted expression on the full input, using the custom
        # module dictionary in lambdify.
        values = {x: X[:, :, idx].cpu() for idx, x in enumerate(total_variables)}
        y_found = sp.lambdify(",".join(total_variables), final, modules=MODULES)(
            **values
        )
        final_loss = np.mean(np.square(y_found - y.cpu()).numpy())
        F_loss.append(final_loss)
        if not np.isnan(final_loss):
            num += 1
        tryout += 1

    # np.nanargmin raises ValueError when every restart produced a NaN loss;
    # fall back to the first restart in that case.
    try:
        k_best = np.nanargmin(F_loss)
    except ValueError:
        k_best = 0

    # Guard against domain problems.
    no_domain = False
    if np.isnan(F_loss[k_best]):
        no_domain = True

    return funcs[k_best], consts_[k_best], F_loss[k_best], expr, no_domain
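

if __name__ == "__main__":
    # Minimal usage sketch with synthetic data (illustrative, not from the
    # original project). X has shape (batch, n_points, 2) holding x_1 and x_2;
    # every free constant in the skeleton string is written as a bare "c".
    X = torch.rand(1, 50, 2) * 4 + 1
    y = 2.0 * torch.sin(X[:, :, 0]) + 0.5
    func, consts, loss, expr, no_domain = bfgs("c*sin(x_1)+c", X, y)
    print(f"recovered: {func}  loss: {loss}  domain issue: {no_domain}")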