-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
239 lines (180 loc) · 7.45 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
import torch
import matplotlib.pyplot as plt
import numpy as np
from typing import Callable
from FrEIA.utils import force_to
import torch.distributions as D
import os
import tqdm
#Training of an INN
def train(
    INN,
    p_data: Callable,
    device: str,
    lr: float,
    milestones: list,
    gamma: float,
    batch_size: int,
    n_batches: int,
    experiment_name: str,
    save_freq: int,
):
    """
    Train a normalizing flow by minimizing the negative log-likelihood of data samples
    under a 2D standard normal latent distribution.

    parameters:
        INN: Normalizing flow to train. Called as ``z, jac = INN(x)``; ``jac`` enters the
            objective as ``log p_0(z) + jac``, i.e. as the per-sample log-determinant term.
        p_data: Target distribution. Must provide ``sample(N=...)`` returning a tensor of
            samples (e.g. a GMM instance).
        device: Device the training samples and the latent distribution are moved to.
        lr: Learning rate of the Adam optimizer.
        milestones: Milestones for learning rate decay. The scheduler is stepped once per
            batch, so milestones are counted in iterations.
        gamma: Multiplicative factor for learning rate decay.
        batch_size: Batch size.
        n_batches: Number of training iterations.
        experiment_name: Name of the training run; results go to ``./<experiment_name>``.
        save_freq: The state dict is saved every ``save_freq`` iterations.

    The output folder must not exist yet (``os.mkdir`` raises ``FileExistsError``).
    Writes ``loss.txt`` (NLL per iteration) and ``jac.txt`` (mean and std of ``jac`` per
    iteration) into that folder.
    """
    INN.train()

    # Create a folder for the training results; os.mkdir refuses to reuse an existing folder,
    # which protects the results of earlier runs from being overwritten.
    folder = "./" + experiment_name
    os.mkdir(folder)

    # Latent distribution of the model (standard normal, hard-coded to 2 dimensions)
    p_0 = force_to(D.MultivariateNormal(torch.zeros(2), torch.eye(2)), device)

    # Initialize the optimizer and the lr scheduler
    optimizer = torch.optim.Adam(INN.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=gamma)

    # Per-iteration records: NLL, and mean / std of the log-Jacobian term
    loss_storage = torch.zeros(n_batches)
    jacobian_storage = torch.zeros([n_batches, 2])

    for i in tqdm.tqdm(range(n_batches)):
        # Get training data
        x = p_data.sample(N=batch_size).to(device)

        # Negative log-likelihood via the change-of-variables formula
        z, jac = INN(x)
        nll = -(p_0.log_prob(z) + jac).mean()

        # Optimize
        optimizer.zero_grad()
        nll.backward()
        optimizer.step()
        lr_scheduler.step()

        # Store plain Python floats only: copying a tensor that is part of the autograd graph
        # into the storage would attach the storage to the graph and keep every iteration's
        # graph nodes alive for the whole run.
        loss_storage[i] = nll.item()
        jacobian_storage[i, 0] = jac.mean().item()
        jacobian_storage[i, 1] = jac.std().item()

        # Periodically checkpoint the model
        if (i + 1) % save_freq == 0:
            torch.save(INN.state_dict(), os.path.join(folder, f"state-dict_iteration-{i+1}.pt"))

    # Save the recorded data
    np.savetxt(os.path.join(folder, "loss.txt"), loss_storage.numpy())
    np.savetxt(os.path.join(folder, "jac.txt"), jacobian_storage.numpy())
#Create coordinate grid
def get_grid_n_dim(res_list: list, lim_list: list):
    '''
    Build a regular grid of points in d = len(res_list) dimensions.

    parameters:
        res_list: Number of grid points along each dimension. Every entry must be at least 2,
            because the spacing is read off the first two points of each axis.
        lim_list: One (lower, upper) pair per dimension giving the limits of the grid.
    returns:
        grid_points: Tensor of shape (N, d), N = prod(res_list), holding one grid point per row.
        spacings_tensor: Tensor of shape (d,) with the distance between neighbouring points
            in each dimension.
        coordinate_grids: Tuple of d coordinate tensors as produced by
            torch.meshgrid(..., indexing="xy"); for d = 2 each has shape (res_list[1], res_list[0]).
    '''
    n_dims = len(res_list)
    # Evenly spaced 1D coordinates along each axis, including both limits
    ticks = [torch.linspace(lim_list[k][0], lim_list[k][1], res_list[k]) for k in range(n_dims)]
    # linspace is uniform, so the gap between the first two points is the spacing of the whole axis
    spacings_tensor = torch.stack([axis_ticks[1] - axis_ticks[0] for axis_ticks in ticks])
    # "xy" indexing puts the first dimension along the columns of every coordinate grid
    coordinate_grids = torch.meshgrid(ticks, indexing="xy")
    # Flatten each coordinate grid and pair the results column-wise -> one point per row
    grid_points = torch.stack([grid.reshape(-1) for grid in coordinate_grids], dim=-1)
    return grid_points, spacings_tensor, coordinate_grids
#Evaluate pdf on a grid
def eval_pdf_on_grid_2D(
    pdf: Callable,
    x_lims: tuple = (-10.0, 10.0),
    y_lims: tuple = (-10.0, 10.0),
    x_res: int = 200,
    y_res: int = 200,
):
    """
    Evaluate a probability density on a regular 2D grid.

    parameters:
        pdf: Probability density function; called once with all grid points as a tensor of
            shape (x_res * y_res, 2) and must return x_res * y_res values.
        x_lims: (lower, upper) limits of the evaluated region in x direction (list or tuple)
        y_lims: (lower, upper) limits of the evaluated region in y direction (list or tuple)
        x_res: Number of grid points in x direction
        y_res: Number of grid points in y direction
    returns:
        pdf_grid: Grid of pdf values, shape (y_res, x_res) (rows follow y, columns follow x)
        x_grid: Grid of x coordinates, shape (y_res, x_res)
        y_grid: Grid of y coordinates, shape (y_res, x_res)
    """
    # Defaults are immutable tuples so no list object is shared between calls;
    # callers may still pass lists.
    grid_points, _, coordinate_grids = get_grid_n_dim(res_list=[x_res, y_res], lim_list=[x_lims, y_lims])
    # The grid points are ordered row by row (y outer, x inner), matching this reshape
    pdf_grid = pdf(grid_points).reshape(y_res, x_res)
    x_grid = coordinate_grids[0]
    y_grid = coordinate_grids[1]
    return pdf_grid, x_grid, y_grid
#Visualize the pdf
def plot_pdf_2D(
    pdf_grid: torch.Tensor,
    x_grid: torch.Tensor,
    y_grid: torch.Tensor,
    ax: plt.Axes,
    fs: int = 20,
    title: str = "",
):
    """
    Draw a density given on a 2D grid as a pseudocolor plot on ``ax``.

    parameters:
        pdf_grid: Grid of pdf values (torch tensor or array-like)
        x_grid: Grid of x coordinates (torch tensor or array-like)
        y_grid: Grid of y coordinates (torch tensor or array-like)
        ax: Axes for plotting
        fs: Fontsize used for labels, tick labels and title
        title: Title of the plot
    """

    def _to_numpy(values):
        # Matplotlib cannot convert CUDA tensors or tensors that require grad (e.g. densities
        # evaluated through a flow), so turn tensors into plain CPU arrays first.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    ax.pcolormesh(_to_numpy(x_grid), _to_numpy(y_grid), _to_numpy(pdf_grid))
    ax.set_xlabel("x", fontsize=fs)
    ax.set_ylabel("y", fontsize=fs)
    ax.tick_params(axis='x', labelsize=fs)
    ax.tick_params(axis='y', labelsize=fs)
    ax.set_title(title, fontsize=fs)
    # Lay out the figure that owns `ax`; plt.tight_layout() would act on whichever
    # figure happens to be current, which need not be this one.
    ax.figure.tight_layout()
#Data distribution
class GMM():
    """Gaussian mixture model with M full-covariance Gaussian modes in d dimensions."""

    def __init__(self, means: torch.Tensor, covs: torch.Tensor, weights: torch.Tensor = None) -> None:
        """
        parameters:
            means: Tensor of shape [M,d] containing the locations of the gaussian modes
            covs: Tensor of shape [M,d,d] containing the covariance matrices of the gaussian modes
            weights: Tensor of shape [M] containing the non-negative weights of the gaussian modes,
                which must sum to 1 (up to a small floating-point tolerance).
                Uniform weights are used if not specified.
        raises:
            ValueError: If the weights do not have one entry per mode, are negative,
                or do not sum to 1.
        """
        # Dimensionality of the data set and number of modes
        self.d = len(means[0])
        self.M = len(means)

        if weights is None:
            self.weights = torch.ones(self.M) / self.M
        else:
            self.weights = torch.as_tensor(weights)

        # Validate the weights
        if tuple(self.weights.shape) != (self.M,):
            raise ValueError(f"Expected {self.M} weights, got shape {tuple(self.weights.shape)}")
        if bool((self.weights < 0).any()):
            raise ValueError("Mixture weights must be non-negative")
        total = float(self.weights.sum())
        # Exact equality would reject valid float weights such as [0.1, 0.2, 0.7]
        if abs(total - 1.0) > 1e-5:
            raise ValueError(f"Mixture weights must sum to 1, got {total}")

        # Initialize the normal modes
        self.mode_list = [
            D.MultivariateNormal(loc=means[i], covariance_matrix=covs[i]) for i in range(self.M)
        ]

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the pdf of the model.
        parameters:
            x: Tensor of shape [N,d] containing the evaluation points
        returns:
            p: Tensor of shape [N] containing the pdf value for the evaluation points
        """
        p = torch.zeros(len(x))
        for weight, mode in zip(self.weights, self.mode_list):
            p += mode.log_prob(x).exp() * weight
        return p

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the log pdf of the model.
        parameters:
            x: Tensor of shape [N,d] containing the evaluation points
        returns:
            log_p: Tensor of shape [N] containing the log pdf value for the evaluation points
        """
        # log p(x) = logsumexp_k [log w_k + log N(x | mu_k, Sigma_k)]. Unlike log(pdf(x)), this
        # stays finite far away from all modes, where exp(log N) underflows to 0.
        component_log_probs = torch.stack([mode.log_prob(x) for mode in self.mode_list], dim=0)  # [M, N]
        log_weights = self.weights.to(component_log_probs).log().unsqueeze(-1)  # [M, 1]
        return torch.logsumexp(component_log_probs + log_weights, dim=0)

    def sample(self, N: int) -> torch.Tensor:
        """
        Generate samples following the distribution
        parameters:
            N: Number of samples
        return:
            s: Tensor of shape [N,d] containing the generated samples, grouped by mode
               (all samples of mode 0 first, then mode 1, ...)
        """
        # np.random.choice checks sum(p) == 1 to ~1.5e-8 in float64; float32 weights such as
        # torch.ones(3) / 3 miss that tolerance, so renormalize in float64 first.
        probs = self.weights.detach().cpu().double().numpy()
        probs = probs / probs.sum()
        mode_idx = np.random.choice(a=self.M, size=(N,), p=probs)
        # Number of draws per mode; modes that were never drawn get 0 and are skipped
        counts = np.bincount(mode_idx, minlength=self.M)
        chunks = [self.mode_list[k].sample([int(n_k)]) for k, n_k in enumerate(counts) if n_k > 0]
        if not chunks:
            return torch.zeros([0, self.d])
        return torch.cat(chunks, dim=0)