 from torch.nn import functional as F
 from torch.nn.utils import parameters_to_vector

-from qucumber.utils import cplx, auto_unsqueeze_arg
+from qucumber.utils import cplx, auto_unsqueeze_args
 from qucumber import _warn_on_missing_gpu


@@ -127,6 +127,7 @@ def initialize_parameters(self, zero_weights=False):
             requires_grad=False,
         )

+    @auto_unsqueeze_args()
     def effective_energy(self, v, a=None):
         r"""Computes the equivalent of the "effective energy" for the RBM. If
         `a` is `None`, will analytically trace out the auxiliary units.
@@ -139,7 +140,7 @@ def effective_energy(self, v, a=None):
         :returns: The "effective energy" of the RBM. Shape (b,) or (1,).
         :rtype: torch.Tensor
         """
-        v = (v.unsqueeze(0) if v.dim() < 2 else v).to(self.weights_W)
+        v = v.to(self.weights_W)

         vis_term = torch.matmul(v, self.visible_bias) + F.softplus(
             F.linear(v, self.weights_W, self.hidden_bias)
@@ -191,7 +192,7 @@ def effective_energy_gradient(self, v, reduce=True):
         vec = [W_grad, U_grad, vb_grad, hb_grad, ab_grad]
         return torch.cat(vec, dim=-1)

-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def prob_h_given_v(self, v, out=None):
         r"""Given a visible unit configuration, compute the probability
         vector of the hidden units being on
@@ -212,7 +213,7 @@ def prob_h_given_v(self, v, out=None):
             .clamp_(min=0, max=1)
         )

-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def prob_a_given_v(self, v, out=None):
         r"""Given a visible unit configuration, compute the probability
         vector of the auxiliary units being on
@@ -233,7 +234,7 @@ def prob_a_given_v(self, v, out=None):
             .clamp_(min=0, max=1)
         )

-    @auto_unsqueeze_arg(1, 2)
+    @auto_unsqueeze_args(1, 2)
     def prob_v_given_ha(self, h, a, out=None):
         r"""Given a hidden and auxiliary unit configuration, compute
         the probability vector of the hidden units being on
@@ -335,7 +336,7 @@ def gibbs_steps(self, k, initial_state, overwrite=False):

         return v

-    @auto_unsqueeze_arg(1)
+    @auto_unsqueeze_args()
     def mixing_term(self, v):
         r"""Describes the extent of mixing in the system,
         :math:`V_\theta = \frac{1}{2}U_\theta \bm{\sigma} + d_\theta`
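The removed line `v = (v.unsqueeze(0) if v.dim() < 2 else v)` hints at what the renamed decorator takes over: promoting 1-D tensor arguments to a batch of one before the wrapped method runs, so the method bodies can assume a `(batch, units)` shape. Below is a minimal sketch of such a decorator, written as an assumption from this diff alone (the defaulting of `@auto_unsqueeze_args()` to argument 1 is inferred from the `@auto_unsqueeze_arg(1)` replacements above); QuCumber's actual implementation in `qucumber.utils` may differ.

```python
import functools

import torch


def auto_unsqueeze_args(*arg_indices):
    """Sketch only: unsqueeze 1-D tensor args so methods can assume a batch dim.

    `arg_indices` are the positions of the tensor arguments (0 is `self`);
    with no indices given, argument 1 is assumed, matching the
    `auto_unsqueeze_arg(1)` -> `auto_unsqueeze_args()` replacements in the diff.
    """
    arg_indices = arg_indices if arg_indices else (1,)

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            args = list(args)
            for i in arg_indices:
                if isinstance(args[i], torch.Tensor) and args[i].dim() < 2:
                    args[i] = args[i].unsqueeze(0)  # (n,) -> (1, n)
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```

With the reshaping handled by the decorator, `effective_energy` only needs the `v.to(self.weights_W)` device/dtype cast, and the `(1,)` output shape documented for single configurations comes from the batch dimension the decorator adds.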