Skip to content

Commit 52cdf4b

Browse files
committed
Cleanup docs and some functions in cplx
1 parent e326465 commit 52cdf4b

File tree

1 file changed

+39
-55
lines changed

1 file changed

+39
-55
lines changed

qucumber/utils/cplx.py

+39-55
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def scalar_mult(x, y, out=None):
9595
"""
9696
y = y.to(x)
9797
if out is None:
98-
out = torch.zeros(2, *((x[0] * y[0]).shape)).to(x)
98+
out = torch.zeros(2, *((real(x) * real(y)).shape)).to(x)
9999
else:
100100
if out is x or out is y:
101101
raise RuntimeError("Can't overwrite an argument!")
@@ -142,17 +142,20 @@ def inner_prod(x, y):
142142
:returns: The inner product, :math:`\\langle x\\vert y\\rangle`.
143143
:rtype: torch.Tensor
144144
"""
145-
z = torch.zeros(2, dtype=x.dtype, device=x.device)
146-
147-
if len(list(x.size())) == 2 and len(list(y.size())) == 2:
148-
z[0] = torch.dot(x[0], y[0]) + torch.dot(x[1], y[1])
149-
z[1] = torch.dot(x[0], y[1]) + torch.dot(-x[1], y[0])
150-
151-
if len(list(x.size())) == 1 and len(list(y.size())) == 1:
152-
z[0] = (x[0] * y[0]) + (x[1] * y[1])
153-
z[1] = (x[0] * y[1]) + (-x[1] * y[0])
145+
y = y.to(x)
154146

155-
return z
147+
if x.dim() == 2 and y.dim() == 2:
148+
return make_complex(
149+
torch.dot(real(x), real(y)) + torch.dot(imag(x), imag(y)),
150+
torch.dot(real(x), imag(y)) - torch.dot(imag(x), real(y)),
151+
)
152+
elif x.dim() == 1 and y.dim() == 1:
153+
return make_complex(
154+
(real(x) * real(y)) + (imag(x) * imag(y)),
155+
(real(x) * imag(y)) - (imag(x) * real(y)),
156+
)
157+
else:
158+
raise ValueError("Unsupported input shapes!")
156159

157160

158161
def outer_prod(x, y):
@@ -171,7 +174,7 @@ def outer_prod(x, y):
171174
:math:`\\vert x \\rangle\\langle y\\vert`.
172175
:rtype: torch.Tensor
173176
"""
174-
if len(list(x.size())) != 2 or len(list(y.size())) != 2:
177+
if x.dim() != 2 or y.dim() != 2:
175178
raise ValueError("An input is not of the right dimension.")
176179

177180
z = torch.zeros(2, x.size()[1], y.size()[1], dtype=x.dtype, device=x.device)
@@ -269,7 +272,7 @@ def elementwise_division(x, y):
269272

270273

271274
def absolute_value(x):
272-
"""Computes the complex absolute value elementwise.
275+
"""Returns the complex absolute value elementwise.
273276
274277
:param x: A complex tensor.
275278
:type x: torch.Tensor
@@ -279,12 +282,11 @@ def absolute_value(x):
279282
"""
280283
x_star = x.clone()
281284
x_star[1] *= -1
282-
return elementwise_mult(x, x_star)[0].sqrt_()
285+
return real(elementwise_mult(x, x_star)).sqrt_()
283286

284287

285288
def kronecker_prod(x, y):
286-
"""A function that returns the tensor / Kronecker product of 2 complex
287-
tensors, x and y.
289+
"""Returns the tensor / Kronecker product of 2 complex matrices, x and y.
288290
289291
:param x: A complex matrix.
290292
:type x: torch.Tensor
@@ -297,40 +299,16 @@ def kronecker_prod(x, y):
297299
:returns: The Kronecker product of x and y, :math:`x \\otimes y`.
298300
:rtype: torch.Tensor
299301
"""
300-
if len(list(x.size())) != 3 or len(list(y.size())) != 3:
301-
raise ValueError("An input is not of the right dimension.")
302+
if not (x.dim() == y.dim() == 3):
303+
raise ValueError("Inputs must be complex matrices!")
302304

303-
z = torch.zeros(
304-
2,
305-
x.size()[1] * y.size()[1],
306-
x.size()[2] * y.size()[2],
307-
dtype=x.dtype,
308-
device=x.device,
305+
return einsum("ab,cd->acbd", x, y).reshape(
306+
2, x.shape[1] * y.shape[1], x.shape[2] * y.shape[2]
309307
)
310308

311-
row_count = 0
312-
313-
for i in range(x.size()[1]):
314-
for k in range(y.size()[1]):
315-
column_count = 0
316-
for j in range(x.size()[2]):
317-
for l in range(y.size()[2]):
318-
319-
z[0][row_count][column_count] = (x[0][i][j] * y[0][k][l]) - (
320-
x[1][i][j] * y[1][k][l]
321-
)
322-
z[1][row_count][column_count] = (x[0][i][j] * y[1][k][l]) + (
323-
x[1][i][j] * y[0][k][l]
324-
)
325-
326-
column_count += 1
327-
row_count += 1
328-
329-
return z
330-
331309

332310
def sigmoid(x, y):
333-
r"""Computes the sigmoid function of a complex number
311+
r"""Returns the sigmoid function of a complex number. Acts elementwise.
334312
335313
:param x: The real part of the complex number
336314
:type x: torch.Tensor
@@ -348,32 +326,38 @@ def sigmoid(x, y):
348326

349327

350328
def scalar_divide(x, y):
351-
"""A function that computes the division of x by y.
329+
"""Divides `x` by `y`.
330+
If `x` and `y` have the same shape, then acts elementwise.
331+
If `y` is a complex scalar, then performs a scalar division.
352332
353-
:param x: The numerator (a complex scalar, vector or matrix).
333+
:param x: The numerator (a complex tensor).
354334
:type x: torch.Tensor
355-
:param y: The denominator (a complex scalar).
335+
:param y: The denominator (a complex tensor).
356336
:type y: torch.Tensor
357337
358338
:returns: x / y
359339
:rtype: torch.Tensor
360340
"""
361-
y_star = conjugate(y)
362-
numerator = scalar_mult(x, y_star)
363-
denominator = real(scalar_mult(y, y_star))
364-
365-
return numerator / denominator
341+
return scalar_mult(x, inverse(y))
366342

367343

368344
def inverse(z):
345+
"""Returns the multiplicative inverse of `z`. Acts elementwise.
346+
347+
:param z: The complex tensor.
348+
:type z: torch.Tensor
349+
350+
:returns: 1 / z
351+
:rtype: torch.Tensor
352+
"""
369353
z_star = conjugate(z)
370354
denominator = real(scalar_mult(z, z_star))
371355

372356
return z_star / denominator
373357

374358

375359
def norm_sqr(x):
376-
"""A function that returns the squared norm of the argument.
360+
"""Returns the squared norm of the argument.
377361
378362
:param x: A complex scalar.
379363
:type x: torch.Tensor
@@ -385,7 +369,7 @@ def norm_sqr(x):
385369

386370

387371
def norm(x):
388-
"""A function that returns the norm of the argument.
372+
"""Returns the norm of the argument.
389373
390374
:param x: A complex scalar.
391375
:type x: torch.Tensor

0 commit comments

Comments
 (0)