-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
47 lines (37 loc) · 1.42 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# A very simple script for checking mzi_onn_sim
import torch
import mzi_onn_sim
# Problem sizes shared by the smoke tests below.
num_features = 2   # width of each input vector (number of optical channels)
num_samples = 3    # batch size used by the test inputs
num_variations = 4 # number of parameter variations for the zeroth-order test
class funcPS(torch.autograd.Function):
    """Autograd bridge to the mzi_onn_sim phase-shifter kernels.

    Forward delegates to ``mzi_onn_sim.bp.forwardPS``; backward delegates to
    ``mzi_onn_sim.bp.backwardPS`` and reduces the per-sample parameter
    gradients over the batch dimension.
    """

    @staticmethod
    def forward(ctx, input, params):
        # Stash what the backward pass needs, then run the kernel.
        ctx.save_for_backward(input, params)
        return mzi_onn_sim.bp.forwardPS(input, params)

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_params = ctx.saved_tensors
        d_input, d_params = mzi_onn_sim.bp.backwardPS(
            grad_output, saved_input, saved_params)
        # backwardPS returns per-sample parameter gradients; sum over the
        # batch axis to match the shape of ``params``.
        return d_input, d_params.sum(dim=0)
def test_bp(device):
    """Smoke-test the back-propagation kernels on the given torch device.

    Builds a random complex batch, runs it through ``funcPS``, forms a
    real-valued squared-error loss against a random target, back-propagates,
    and prints the loss and the parameter gradient for manual inspection.
    """
    # Renamed from ``input`` so the builtin is not shadowed.
    inputs = torch.randn((num_samples, num_features), dtype=torch.cfloat, device=device)
    param = torch.nn.Parameter(torch.randn((num_features), dtype=torch.float, device=device))
    output = funcPS.apply(inputs, param)
    target = torch.randn_like(output)
    error = output - target
    # |error|^2 summed over all entries -> real scalar, valid for backward().
    loss = (error * error.conj()).real.sum()
    loss.backward()
    print(loss)
    print(param.grad)
def test_zo(device):
    """Smoke-test the zeroth-order forward kernel on the given torch device.

    Builds random inputs with an extra leading ``num_variations`` axis plus a
    per-variation parameter tensor and a complex attenuation vector, runs
    ``mzi_onn_sim.zo.forwardPS``, and prints the result for inspection.
    """
    # Renamed from ``input`` so the builtin is not shadowed.
    inputs = torch.randn((num_variations, num_samples, num_features), dtype=torch.cfloat, device=device)
    param = torch.nn.Parameter(torch.randn((num_variations, num_features), dtype=torch.float, device=device))
    atten = torch.randn((num_features), dtype=torch.cfloat, device=device)
    output = mzi_onn_sim.zo.forwardPS(inputs, param, atten)
    print(output)
if __name__ == "__main__":
    # Always exercise the CPU path; only attempt CUDA when a device is
    # actually available, so the script does not crash on CPU-only machines.
    test_bp('cpu')
    if torch.cuda.is_available():
        test_bp('cuda')
    test_zo('cpu')
    if torch.cuda.is_available():
        test_zo('cuda')