forked from openvinotoolkit/openvino_contrib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcomplex_mul.py
25 lines (21 loc) · 918 Bytes
/
complex_mul.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
class ComplexMul(torch.autograd.Function):
    """Multiply complex tensors stored as trailing (real, imaginary) pairs.

    Each operand keeps its real part at index 0 and its imaginary part at
    index 1 of the last dimension; only those two entries are read. The
    leading dimensions combine with ordinary element-wise broadcasting. The
    result uses the same layout: a last dimension of size 2 holding
    ``(real, imag)``.

    Call it as ``ComplexMul.apply(input_tensor, other_tensor)``. In eager mode
    ``forward`` computes the product. During ONNX export ``symbolic`` instead
    emits a single custom ``ComplexMultiplication`` node.
    """

    @staticmethod
    def symbolic(g, input_tensor, other_tensor, is_conj=True):
        """Emit the custom ``ComplexMultiplication`` node for ONNX export.

        Args:
            g: The graph being built by the ONNX exporter.
            input_tensor: Graph value for the first operand.
            other_tensor: Graph value for the second operand.
            is_conj: Stored on the node as the integer attribute ``is_conj``
                (the ``_i`` suffix in ``is_conj_i`` marks an int attribute
                for ``g.op``).

        Returns:
            The output value of the emitted node.
        """
        # NOTE(review): ``apply`` forwards only the two tensors, so this always
        # exports is_conj=1. Meanwhile ``forward`` computes the plain product
        # a * b without conjugation. Confirm that the backend's
        # ComplexMultiplication op with is_conj=1 matches ``forward``.
        return g.op(
            "ComplexMultiplication",
            input_tensor,
            other_tensor,
            is_conj_i=int(is_conj),
        )

    @staticmethod
    def forward(ctx, input_tensor, other_tensor):
        """Compute ``input_tensor * other_tensor`` as complex numbers.

        Args:
            ctx: Autograd context. It is unused: nothing is saved, and this
                class defines no ``backward``.
            input_tensor: First operand, with ``(real, imag)`` in the last
                dimension.
            other_tensor: Second operand, same layout, broadcastable against
                ``input_tensor``.

        Returns:
            A tensor whose last dimension (size 2) holds the real and
            imaginary parts of the product.
        """
        a_re, a_im = input_tensor[..., 0], input_tensor[..., 1]
        b_re, b_im = other_tensor[..., 0], other_tensor[..., 1]
        # (a_re + i*a_im) * (b_re + i*b_im)
        real_part = a_re * b_re - a_im * b_im
        imaginary_part = a_re * b_im + a_im * b_re
        # Re-pack the two parts along a new trailing axis of size 2.
        return torch.stack([real_part, imaginary_part], dim=-1)