import math
import random

import numpy as np
import torch


def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
# Helpers that treat a sequence (list/tuple) of tensors as one flat vector.
def divide(tensors, scalar):
    return [xx / scalar for xx in tensors]


def multi(tensors, scalar):
    return [xx * scalar for xx in tensors]


def sub(tensor_one, tensor_two):
    return [xx - yy for xx, yy in zip(tensor_one, tensor_two)]


def add(tensor_one, tensor_two):
    return [xx + yy for xx, yy in zip(tensor_one, tensor_two)]


def norm(tensors):
    # Euclidean norm of the concatenation of all tensors.
    return math.sqrt(sum(torch.sum(tensor ** 2).item() for tensor in tensors))


def concat(tensors_one, tensors_two):
    return list(tensors_one) + list(tensors_two)


def dot(tensors_one, tensors_two):
    # Inner product of two tensor sequences; the result stays a tensor so it
    # can participate in autograd when evaluated under torch.enable_grad().
    ret = tensors_one[0].new_zeros((1,), requires_grad=True)
    for t1, t2 in zip(tensors_one, tensors_two):
        ret = ret + torch.sum(t1 * t2)
    return ret
def images2vectors(images):
    # Flatten a batch of 28x28 grayscale images into 784-dimensional vectors.
    return images.view(images.size(0), 784)


def vectors2images(vectors):
    # Reshape a batch of 784-dimensional vectors back into 1x28x28 images.
    return vectors.view(vectors.size(0), 1, 28, 28)
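# Quick sanity check (illustrative only, assumes MNIST-sized inputs): the two
# helpers above are inverses of each other.
#     batch = torch.randn(16, 1, 28, 28)
#     assert torch.equal(vectors2images(images2vectors(batch)), batch)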
@torch.no_grad()
def confidence(discriminator, data, generator=None):
    """Mean discriminator output on `data` (or on generated samples when a generator is given)."""
    if generator is not None:
        data = generator(data)
    if discriminator.__class__.__name__ == "Discriminator":
        # This discriminator is assumed to output probabilities directly.
        return discriminator(data).mean()
    else:
        # Otherwise the output is treated as a logit and squashed to a probability.
        return discriminator(data).sigmoid().mean()
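# Usage sketch (assumption, not from the original file): during GAN training one
# might track how confidently the discriminator scores real versus fake batches.
#     real_conf = confidence(discriminator, real_images)
#     fake_conf = confidence(discriminator, noise_batch, generator=generator)
# Here `real_images`, `noise_batch`, `discriminator`, and `generator` are
# hypothetical objects from the training loop.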
@torch.no_grad()
def conjugate_gradient(_hvp, b, maxiter=None, tol=1e-30, lam=0.0, use_cache=0, negcur=False, eps=-1e-7):
    """
    Minimize 0.5 x^T H^T H x - b^T H x, where H is symmetric,
    i.e. solve the normal equations with Q = H^T H.

    Args:
        _hvp (function): Hessian-vector product; takes a sequence of tensors and
            returns a sequence of tensors with the same shapes.
        b (sequence of tensors): right-hand side b.
        maxiter (int): maximum number of iterations (None means run until tol is reached).
        tol (float): stop once the residual norm drops below this value.
        lam (float): regularization constant to avoid singularity of the Hessian;
            it can be positive, zero, or negative.
        use_cache, negcur, eps: currently unused (the negative-curvature check below is commented out).
    """
    def hvp(inputs):
        # Regularized Hessian-vector product: (H + lam * I) v.
        with torch.enable_grad():
            outputs = _hvp(inputs)
            outputs = [xx + lam * yy for xx, yy in zip(outputs, inputs)]
        return outputs

    with torch.enable_grad():
        Hb = hvp(b)
    # zero initialization: x = 0, so the initial gradient (residual) is -H b
    xxs = [hb.new_zeros(hb.size()) for hb in Hb]
    ggs = [-hb.clone().detach() for hb in Hb]
    dds = [-hb.clone().detach() for hb in Hb]
    i = 0
    while True:
        i += 1
        # if negcur:
        #     # print("i: ", i, "running negcur")
        #     if dot(dds, hvp(dds)) >= eps:
        #         if i == 1:
        #             # print("GDA")
        #             return b
        #         else:
        #             break
        with torch.enable_grad():
            Qdds = hvp(hvp(dds))
            # Qdds = hvp(dds)
        # print(dot(ggs, ggs))
        # print(norm(ggs))
        # if dot(ggs, ggs) < tol:
        if norm(ggs) < tol:
            break
        # one step of steepest descent along the current conjugate direction
        alpha = -dot(dds, ggs) / dot(dds, Qdds)
        xxs = [xx + alpha * dd for xx, dd in zip(xxs, dds)]
        # update gradient (residual)
        ggs = [gg + alpha * Qdd for gg, Qdd in zip(ggs, Qdds)]
        # compute the next conjugate direction
        beta = dot(ggs, Qdds) / dot(dds, Qdds)
        dds = [gg - beta * dd for gg, dd in zip(ggs, dds)]
        if maxiter is not None and i >= maxiter:
            break
    # print("# of conjugate steps {:d}".format(i))
    return xxs
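

# Illustrative sketch, not part of the original utilities: one way to build an
# `_hvp` callable for conjugate_gradient() from a scalar loss and a parameter
# list, using double backpropagation. `loss` and `params` are hypothetical
# inputs supplied by the caller.
def make_hvp(loss, params):
    # First-order gradients, with the graph kept so we can differentiate again.
    grads = torch.autograd.grad(loss, params, create_graph=True)

    def _hvp(vectors):
        # H v = d/dparams ( <grad(loss), v> ); conjugate_gradient() wraps this
        # call in torch.enable_grad(), so the second backward pass is tracked.
        prod = sum(torch.sum(g * v) for g, v in zip(grads, vectors))
        return list(torch.autograd.grad(prod, params, retain_graph=True))

    return _hvp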
def test_conjugate_gradient():
    """Solve A x = grads with the conjugate-gradient routine above."""
    print('testing conjugate gradient:')

    def hvp(lst_tensors):
        A = torch.tensor([[2, 1], [1, 3]], dtype=torch.float, device=device)
        return [A.mm(tensor) for tensor in lst_tensors]

    grads = [torch.tensor([[3], [4]], dtype=torch.float, device=device)]
    ret = conjugate_gradient(hvp, grads, maxiter=1)
    print(ret)
    ret = conjugate_gradient(hvp, grads, maxiter=2)
    print(ret)
    ret = conjugate_gradient(hvp, grads, maxiter=2, lam=0.01)
    print(ret)
    # expect ret = [[1], [1]] after two iterations


if __name__ == "__main__":
    # fall back to CPU when no GPU is available so the test still runs
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    test_conjugate_gradient()