
Commit f1afb5d

Committed Dec 29, 2019
Rename and add default behaviour for auto_unsqueeze_arg decorator
1 parent: 52cdf4b

10 files changed: +34 −62 lines changed
 

‎qucumber/nn_states/complex_wavefunction.py

+1 −5

@@ -176,11 +176,7 @@ def rotated_gradient(self, basis, sample):
         )
         inv_Upsi = cplx.inverse(Upsi)
 
-        vr = v.reshape(-1, v.shape[-1])
-        raw_grads = [
-            self.am_grads(vr).reshape(2, *v.shape[:-1], -1),
-            self.ph_grads(vr).reshape(2, *v.shape[:-1], -1),
-        ]
+        raw_grads = [self.am_grads(v), self.ph_grads(v)]
 
         rotated_grad = [cplx.einsum("s...,s...g->...g", Upsi_v, g) for g in raw_grads]
         grad = [

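The simplification above passes `v` straight to `self.am_grads` and `self.ph_grads` instead of flattening it first, presumably because those methods now handle batched input themselves; the downstream contraction already copes with arbitrary leading dimensions through its ellipsis subscripts. `cplx.einsum` is QuCumber's complex-tensor wrapper, but the subscript pattern `"s...,s...g->...g"` behaves like plain `torch.einsum` on real tensors: sum over the sample index `s`, broadcast the remaining batch dimensions, keep the trailing gradient dimension. A minimal sketch with made-up shapes:

```python
import torch

# Hypothetical shapes: 4 samples (s), a leading batch dim of 3, gradient size 7.
Upsi_v = torch.randn(4, 3)      # matches subscripts "s..."
g = torch.randn(4, 3, 7)        # matches subscripts "s...g"

# Sum over s; the broadcast batch dims and the gradient dim survive.
out = torch.einsum("s...,s...g->...g", Upsi_v, g)
print(out.shape)                # torch.Size([3, 7])
```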
‎qucumber/nn_states/neural_state.py

+4 −1

@@ -177,6 +177,9 @@ def generate_hilbert_space(self, size=None, device=None):
 
     def normalization(self, space):
         r"""Compute the normalization constant of the state.
+        In the case of a pure state, this is the norm of the unnormalized wavefunction.
+        In the case of a mixed state, this is the trace of the unnormalized density
+        matrix.
 
         .. math::
 
@@ -190,7 +193,7 @@ def normalization(self, space):
         return self.rbm_am.partition(space)
 
     def compute_normalization(self, space):
-        """Alias for `normalization`"""
+        """Alias for :func:`normalization<qucumber.nn_states.NeuralStateBase.normalization>`"""
         return self.normalization(space)
 
     def save(self, location, metadata=None):

‎qucumber/nn_states/positive_wavefunction.py

+2 −2

@@ -17,7 +17,7 @@
 
 from qucumber import _warn_on_missing_gpu
 from qucumber.rbm import BinaryRBM
-from qucumber.utils import cplx, auto_unsqueeze_arg
+from qucumber.utils import cplx, auto_unsqueeze_args
 from .wavefunction import WaveFunctionBase
 
 
@@ -91,7 +91,7 @@ def amplitude(self, v):
         """
         return super().amplitude(v)
 
-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def phase(self, v):
         r"""Compute the phase of a given vector/matrix of visible states.

‎qucumber/nn_states/wavefunction.py

−14

@@ -86,20 +86,6 @@ def importance_sampling_numerator(self, v, vp):
     def importance_sampling_denominator(self, vp):
         return self.psi(vp)
 
-    def normalization(self, space):
-        r"""Compute the norm of the wavefunction.
-
-        .. math::
-
-            Z_{\bm{\lambda}}=
-            \sum_{\bm{\sigma}} p_{\bm{\lambda}}(\bm{\sigma})
-
-        :param space: A rank 2 tensor of the entire visible space.
-        :type space: torch.Tensor
-
-        """
-        return super().normalization(space)
-
 
 # make module path show up properly in sphinx docs
 WaveFunctionBase.__module__ = "qucumber.nn_states"

‎qucumber/rbm/binary_rbm.py

+5 −4

@@ -20,7 +20,7 @@
 from torch.nn.utils import parameters_to_vector
 
 from qucumber import _warn_on_missing_gpu
-from qucumber.utils import auto_unsqueeze_arg
+from qucumber.utils import auto_unsqueeze_args
 
 
 class BinaryRBM(nn.Module):
@@ -71,6 +71,7 @@ def initialize_parameters(self, zero_weights=False):
            requires_grad=False,
         )
 
+    @auto_unsqueeze_args()
     def effective_energy(self, v):
         r"""The effective energies of the given visible states.
 
@@ -88,7 +89,7 @@ def effective_energy(self, v):
         :returns: The effective energies of the given visible states.
         :rtype: torch.Tensor
         """
-        v = (v.unsqueeze(0) if v.dim() < 2 else v).to(self.weights)
+        v = v.to(self.weights)
         visible_bias_term = torch.matmul(v, self.visible_bias)
         hid_bias_term = F.softplus(F.linear(v, self.weights, self.hidden_bias)).sum(-1)
 
@@ -124,7 +125,7 @@ def effective_energy_gradient(self, v, reduce=True):
         vec = [W_grad.view(*v.shape[:-1], -1), vb_grad, hb_grad]
         return torch.cat(vec, dim=-1)
 
-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def prob_v_given_h(self, h, out=None):
         """Given a hidden unit configuration, compute the probability
         vector of the visible units being on.
@@ -145,7 +146,7 @@ def prob_v_given_h(self, h, out=None):
            .clamp_(min=0, max=1)
         )
 
-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def prob_h_given_v(self, v, out=None):
         """Given a visible unit configuration, compute the probability
         vector of the hidden units being on.

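With `@auto_unsqueeze_args()` applied, `effective_energy` no longer unsqueezes `v` itself and can assume a batched input, computing one energy per row. The two terms visible in the hunk are reproduced in the standalone sketch below; the return statement and sign convention sit outside the hunk, so the final line is only an assumed combination of the two terms:

```python
import torch
from torch.nn import functional as F


def effective_energy_terms(v, weights, visible_bias, hidden_bias):
    # v: (batch, num_visible); weights: (num_hidden, num_visible)
    v = v.to(weights)
    visible_bias_term = torch.matmul(v, visible_bias)
    hid_bias_term = F.softplus(F.linear(v, weights, hidden_bias)).sum(-1)
    return visible_bias_term + hid_bias_term  # assumed combination / sign


v = torch.bernoulli(torch.full((4, 10), 0.5, dtype=torch.double))  # batch of 4 states
W = torch.randn(7, 10, dtype=torch.double)
vb = torch.zeros(10, dtype=torch.double)
hb = torch.zeros(7, dtype=torch.double)
print(effective_energy_terms(v, W, vb, hb).shape)  # torch.Size([4])
```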
‎qucumber/rbm/purification_rbm.py

+7 −6

@@ -19,7 +19,7 @@
 from torch.nn import functional as F
 from torch.nn.utils import parameters_to_vector
 
-from qucumber.utils import cplx, auto_unsqueeze_arg
+from qucumber.utils import cplx, auto_unsqueeze_args
 from qucumber import _warn_on_missing_gpu
 
 
@@ -127,6 +127,7 @@ def initialize_parameters(self, zero_weights=False):
            requires_grad=False,
         )
 
+    @auto_unsqueeze_args()
     def effective_energy(self, v, a=None):
         r"""Computes the equivalent of the "effective energy" for the RBM. If
         `a` is `None`, will analytically trace out the auxiliary units.
@@ -139,7 +140,7 @@ def effective_energy(self, v, a=None):
         :returns: The "effective energy" of the RBM. Shape (b,) or (1,).
         :rtype: torch.Tensor
         """
-        v = (v.unsqueeze(0) if v.dim() < 2 else v).to(self.weights_W)
+        v = v.to(self.weights_W)
 
         vis_term = torch.matmul(v, self.visible_bias) + F.softplus(
            F.linear(v, self.weights_W, self.hidden_bias)
@@ -191,7 +192,7 @@ def effective_energy_gradient(self, v, reduce=True):
         vec = [W_grad, U_grad, vb_grad, hb_grad, ab_grad]
         return torch.cat(vec, dim=-1)
 
-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def prob_h_given_v(self, v, out=None):
         r"""Given a visible unit configuration, compute the probability
         vector of the hidden units being on
@@ -212,7 +213,7 @@ def prob_h_given_v(self, v, out=None):
            .clamp_(min=0, max=1)
         )
 
-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def prob_a_given_v(self, v, out=None):
         r"""Given a visible unit configuration, compute the probability
         vector of the auxiliary units being on
@@ -233,7 +234,7 @@ def prob_a_given_v(self, v, out=None):
            .clamp_(min=0, max=1)
         )
 
-    @auto_unsqueeze_arg(1, 2)
+    @auto_unsqueeze_args(1, 2)
     def prob_v_given_ha(self, h, a, out=None):
         r"""Given a hidden and auxiliary unit configuration, compute
         the probability vector of the hidden units being on
@@ -335,7 +336,7 @@ def gibbs_steps(self, k, initial_state, overwrite=False):
 
         return v
 
-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def mixing_term(self, v):
         r"""Describes the extent of mixing in the system,
         :math:`V_\theta = \frac{1}{2}U_\theta \bm{\sigma} + d_\theta`

‎qucumber/utils/__init__.py

+5 −2

@@ -17,9 +17,12 @@
 from functools import wraps
 
 
-class auto_unsqueeze_arg:
+class auto_unsqueeze_args:
     def __init__(self, *arg_indices):
-        self.arg_indices = arg_indices
+        self.arg_indices = list(arg_indices)
+
+        if len(self.arg_indices) == 0:
+            self.arg_indices.append(1)
 
     def __call__(self, f):
         @wraps(f)

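Only `__init__` of the renamed decorator appears in this hunk; its `__call__` body is unchanged and not shown. The sketch below is a self-contained reconstruction of the intended behaviour, inferred from the removed `v.unsqueeze(0) if v.dim() < 2 else v` pattern and from the call sites in this commit: `@auto_unsqueeze_args()` now defaults to argument 1 (the first argument after `self`), while explicit indices such as `@auto_unsqueeze_args(1, 2)` still work. The wrapper itself is an assumption, not the library's actual implementation:

```python
from functools import wraps

import torch


class auto_unsqueeze_args:
    def __init__(self, *arg_indices):
        self.arg_indices = list(arg_indices)

        # new default behaviour: unsqueeze argument 1 if no indices are given
        if len(self.arg_indices) == 0:
            self.arg_indices.append(1)

    def __call__(self, f):
        # NOTE: the real __call__ is not part of this hunk; this wrapper is a
        # plausible stand-in that promotes 1D tensor arguments to a batch of one.
        @wraps(f)
        def wrapper(*args, **kwargs):
            args = list(args)
            for i in self.arg_indices:
                if isinstance(args[i], torch.Tensor) and args[i].dim() < 2:
                    args[i] = args[i].unsqueeze(0)
            return f(*args, **kwargs)

        return wrapper


class Toy:
    @auto_unsqueeze_args()  # equivalent to auto_unsqueeze_args(1)
    def batch_shape(self, v):
        return v.shape


print(Toy().batch_shape(torch.zeros(5)))     # torch.Size([1, 5])
print(Toy().batch_shape(torch.zeros(3, 5)))  # torch.Size([3, 5])
```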
‎qucumber/utils/data.py

+1 −1

@@ -77,7 +77,7 @@ def load_data_DM(
     :type bases_path: str
 
     :returns: A list of all input parameters, with the real and imaginary parts
-        combined into one (PyTorch-hack) complex matrix.
+        of the target density matrix (if provided) combined into one complex matrix.
     :rtype: list
     """
     data = []

‎tests/test_models_misc.py

+3 −3

@@ -150,11 +150,11 @@ def test_positive_wavefunction_psi():
     nn_state = PositiveWaveFunction(10, gpu=False)
 
     vis_state = torch.ones(10).to(dtype=torch.double)
-    actual_psi = nn_state.psi(vis_state)[1].to(vis_state)
-    expected_psi = torch.zeros(1).to(vis_state)
+    actual_psi_im = cplx.imag(nn_state.psi(vis_state)).to(vis_state)
+    expected_psi_im = torch.zeros(1).squeeze().to(vis_state)
 
     msg = "PositiveWaveFunction is giving a non-zero imaginary part!"
-    assert torch.equal(actual_psi, expected_psi), msg
+    assert torch.equal(actual_psi_im, expected_psi_im), msg
 
 
 def test_density_matrix_hermiticity():

‎tests/test_observables.py

+6 −24

@@ -18,8 +18,9 @@
 import pytest
 import torch
 
-import qucumber.observables as observables
-from qucumber.nn_states import WaveFunctionBase
+from qucumber import observables
+from qucumber.nn_states import WaveFunctionBase, NeuralStateBase
+from qucumber.utils import cplx, auto_unsqueeze_args
 
 
 class MockWaveFunction(WaveFunctionBase):
@@ -66,35 +67,16 @@ def device(self, new_val):
     def networks(self):
         return ["rbm_am"]
 
+    @auto_unsqueeze_args()
     def phase(self, v):
-        if v.dim() == 1:
-            v = v.unsqueeze(0)
-            unsqueezed = True
-        else:
-            unsqueezed = False
-
-        phase = torch.zeros(v.shape[0])
-
-        if unsqueezed:
-            return phase.squeeze_(0)
-        else:
-            return phase
+        return torch.zeros(v.shape[0])
 
     def amplitude(self, v):
         return torch.ones(v.size(0)) / torch.sqrt(torch.tensor(float(self.nqubits)))
 
     def psi(self, v):
         # vector/tensor of shape (len(v),)
-        amplitude = self.amplitude(v)
-
-        # complex vector; shape: (2, len(v))
-        psi = torch.zeros((2,) + amplitude.shape).to(
-            dtype=torch.double, device=self.device
-        )
-        psi[0] = amplitude
-
-        # squeeze down to complex scalar if there was only one visible state
-        return psi.squeeze()
+        return cplx.make_complex(self.amplitude(v))
 
 
 @pytest.fixture(scope="module")

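The deleted body of `MockWaveFunction.psi` documents the "PyTorch-hack" complex convention that `cplx.make_complex` now encapsulates: a complex tensor is stored with shape `(2, ...)`, real part at index 0 and imaginary part at index 1, which is also what `cplx.imag` reads in the updated `test_models_misc.py`. A rough stand-in for the two helpers, for illustration only (the real `cplx` module may differ in details such as dtype and device handling):

```python
import torch


def make_complex(real, imag=None):
    # Stack real and imaginary parts along a new leading dimension of size 2.
    imag = torch.zeros_like(real) if imag is None else imag
    return torch.stack([real, imag], dim=0)


def imag(z):
    return z[1]


amplitude = torch.ones(3, dtype=torch.double) / torch.sqrt(torch.tensor(3.0))
psi = make_complex(amplitude)  # shape (2, 3); imaginary part all zeros
print(imag(psi))               # tensor([0., 0., 0.], dtype=torch.float64)
```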