@@ -95,7 +95,7 @@ def scalar_mult(x, y, out=None):
95
95
"""
96
96
y = y .to (x )
97
97
if out is None :
98
- out = torch .zeros (2 , * ((x [ 0 ] * y [ 0 ] ).shape )).to (x )
98
+ out = torch .zeros (2 , * ((real ( x ) * real ( y ) ).shape )).to (x )
99
99
else :
100
100
if out is x or out is y :
101
101
raise RuntimeError ("Can't overwrite an argument!" )
@@ -142,17 +142,20 @@ def inner_prod(x, y):
142
142
:returns: The inner product, :math:`\\ langle x\\ vert y\\ rangle`.
143
143
:rtype: torch.Tensor
144
144
"""
145
- z = torch .zeros (2 , dtype = x .dtype , device = x .device )
146
-
147
- if len (list (x .size ())) == 2 and len (list (y .size ())) == 2 :
148
- z [0 ] = torch .dot (x [0 ], y [0 ]) + torch .dot (x [1 ], y [1 ])
149
- z [1 ] = torch .dot (x [0 ], y [1 ]) + torch .dot (- x [1 ], y [0 ])
150
-
151
- if len (list (x .size ())) == 1 and len (list (y .size ())) == 1 :
152
- z [0 ] = (x [0 ] * y [0 ]) + (x [1 ] * y [1 ])
153
- z [1 ] = (x [0 ] * y [1 ]) + (- x [1 ] * y [0 ])
145
+ y = y .to (x )
154
146
155
- return z
147
+ if x .dim () == 2 and y .dim () == 2 :
148
+ return make_complex (
149
+ torch .dot (real (x ), real (y )) + torch .dot (imag (x ), imag (y )),
150
+ torch .dot (real (x ), imag (y )) - torch .dot (imag (x ), real (y )),
151
+ )
152
+ elif x .dim () == 1 and y .dim () == 1 :
153
+ return make_complex (
154
+ (real (x ) * real (y )) + (imag (x ) * imag (y )),
155
+ (real (x ) * imag (y )) - (imag (x ) * real (y )),
156
+ )
157
+ else :
158
+ raise ValueError ("Unsupported input shapes!" )
156
159
157
160
158
161
def outer_prod (x , y ):
@@ -171,7 +174,7 @@ def outer_prod(x, y):
171
174
:math:`\\ vert x \\ rangle\\ langle y\\ vert`.
172
175
:rtype: torch.Tensor
173
176
"""
174
- if len ( list ( x . size ())) != 2 or len ( list ( y . size ()) ) != 2 :
177
+ if x . dim () != 2 or y . dim ( ) != 2 :
175
178
raise ValueError ("An input is not of the right dimension." )
176
179
177
180
z = torch .zeros (2 , x .size ()[1 ], y .size ()[1 ], dtype = x .dtype , device = x .device )
@@ -269,7 +272,7 @@ def elementwise_division(x, y):
269
272
270
273
271
274
def absolute_value (x ):
272
- """Computes the complex absolute value elementwise.
275
+ """Returns the complex absolute value elementwise.
273
276
274
277
:param x: A complex tensor.
275
278
:type x: torch.Tensor
@@ -279,12 +282,11 @@ def absolute_value(x):
279
282
"""
280
283
x_star = x .clone ()
281
284
x_star [1 ] *= - 1
282
- return elementwise_mult (x , x_star )[ 0 ] .sqrt_ ()
285
+ return real ( elementwise_mult (x , x_star )) .sqrt_ ()
283
286
284
287
285
288
def kronecker_prod (x , y ):
286
- """A function that returns the tensor / Kronecker product of 2 complex
287
- tensors, x and y.
289
+ """Returns the tensor / Kronecker product of 2 complex matrices, x and y.
288
290
289
291
:param x: A complex matrix.
290
292
:type x: torch.Tensor
@@ -297,40 +299,16 @@ def kronecker_prod(x, y):
297
299
:returns: The Kronecker product of x and y, :math:`x \\ otimes y`.
298
300
:rtype: torch.Tensor
299
301
"""
300
- if len ( list ( x . size ())) != 3 or len ( list ( y . size ())) != 3 :
301
- raise ValueError ("An input is not of the right dimension. " )
302
+ if not ( x . dim () == y . dim () == 3 ) :
303
+ raise ValueError ("Inputs must be complex matrices! " )
302
304
303
- z = torch .zeros (
304
- 2 ,
305
- x .size ()[1 ] * y .size ()[1 ],
306
- x .size ()[2 ] * y .size ()[2 ],
307
- dtype = x .dtype ,
308
- device = x .device ,
305
+ return einsum ("ab,cd->acbd" , x , y ).reshape (
306
+ 2 , x .shape [1 ] * y .shape [1 ], x .shape [2 ] * y .shape [2 ]
309
307
)
310
308
311
- row_count = 0
312
-
313
- for i in range (x .size ()[1 ]):
314
- for k in range (y .size ()[1 ]):
315
- column_count = 0
316
- for j in range (x .size ()[2 ]):
317
- for l in range (y .size ()[2 ]):
318
-
319
- z [0 ][row_count ][column_count ] = (x [0 ][i ][j ] * y [0 ][k ][l ]) - (
320
- x [1 ][i ][j ] * y [1 ][k ][l ]
321
- )
322
- z [1 ][row_count ][column_count ] = (x [0 ][i ][j ] * y [1 ][k ][l ]) + (
323
- x [1 ][i ][j ] * y [0 ][k ][l ]
324
- )
325
-
326
- column_count += 1
327
- row_count += 1
328
-
329
- return z
330
-
331
309
332
310
def sigmoid (x , y ):
333
- r"""Computes the sigmoid function of a complex number
311
+ r"""Returns the sigmoid function of a complex number. Acts elementwise.
334
312
335
313
:param x: The real part of the complex number
336
314
:type x: torch.Tensor
@@ -348,32 +326,38 @@ def sigmoid(x, y):
348
326
349
327
350
328
def scalar_divide (x , y ):
351
- """A function that computes the division of x by y.
329
+ """Divides `x` by `y`.
330
+ If `x` and `y` have the same shape, then acts elementwise.
331
+ If `y` is a complex scalar, then performs a scalar division.
352
332
353
- :param x: The numerator (a complex scalar, vector or matrix ).
333
+ :param x: The numerator (a complex tensor ).
354
334
:type x: torch.Tensor
355
- :param y: The denominator (a complex scalar ).
335
+ :param y: The denominator (a complex tensor ).
356
336
:type y: torch.Tensor
357
337
358
338
:returns: x / y
359
339
:rtype: torch.Tensor
360
340
"""
361
- y_star = conjugate (y )
362
- numerator = scalar_mult (x , y_star )
363
- denominator = real (scalar_mult (y , y_star ))
364
-
365
- return numerator / denominator
341
+ return scalar_mult (x , inverse (y ))
366
342
367
343
368
344
def inverse (z ):
345
+ """Returns the multiplicative inverse of `z`. Acts elementwise.
346
+
347
+ :param z: The complex tensor.
348
+ :type z: torch.Tensor
349
+
350
+ :returns: 1 / z
351
+ :rtype: torch.Tensor
352
+ """
369
353
z_star = conjugate (z )
370
354
denominator = real (scalar_mult (z , z_star ))
371
355
372
356
return z_star / denominator
373
357
374
358
375
359
def norm_sqr (x ):
376
- """A function that returns the squared norm of the argument.
360
+ """Returns the squared norm of the argument.
377
361
378
362
:param x: A complex scalar.
379
363
:type x: torch.Tensor
@@ -385,7 +369,7 @@ def norm_sqr(x):
385
369
386
370
387
371
def norm (x ):
388
- """A function that returns the norm of the argument.
372
+ """Returns the norm of the argument.
389
373
390
374
:param x: A complex scalar.
391
375
:type x: torch.Tensor
0 commit comments