forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestBatched.test_if_noelse.expect
39 lines (39 loc) · 1.56 KB
/
TestBatched.test_if_noelse.expect
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
graph(%a.1_data : Tensor,
%a.1_mask : Tensor,
%a.1_dims : Tensor,
%b_data : Tensor,
%b_mask : Tensor,
%b_dims : Tensor):
%6 : int = prim::Constant[value=1]()
%7 : Tensor = aten::gt(%a.1_data, %b_data)
%8 : Tensor = aten::mul(%a.1_mask, %b_mask)
%9 : Long() = prim::NumToTensor(%6)
%alpha : float = prim::Float(%9)
%data : Tensor = aten::add(%a.1_data, %b_data, %alpha)
%mask : Tensor = aten::mul(%a.1_mask, %b_mask)
%dims : Tensor = aten::__or__(%a.1_dims, %b_dims)
%14 : bool = prim::Constant[value=1]()
%15 : int = prim::Constant[value=1]()
%16 : Tensor = aten::type_as(%8, %7)
%data.2 : Tensor = aten::mul(%7, %16)
%18 : int = aten::dim(%data.2)
%19 : bool = aten::eq(%18, %15)
%cond_data : Tensor, %cond_mask : Tensor = prim::If(%19)
block0():
%22 : int = aten::dim(%data)
%23 : int = aten::sub(%22, %15)
%data.4 : Tensor = prim::Loop(%23, %14, %data.2)
block0(%25 : int, %26 : Tensor):
%27 : int = aten::dim(%26)
%data.3 : Tensor = aten::unsqueeze(%26, %27)
-> (%14, %data.3)
%cond_data.1 : Tensor = aten::expand_as(%data.4, %data)
%cond_mask.1 : Tensor = aten::expand_as(%data.4, %mask)
-> (%cond_data.1, %cond_mask.1)
block1():
-> (%data.2, %data.2)
%res_data : Tensor = aten::where(%cond_data, %data, %a.1_data)
%res_mask : Tensor = aten::where(%cond_mask, %mask, %a.1_mask)
%res_dims : Tensor = aten::__or__(%dims, %a.1_dims)
%34 : (Tensor, Tensor, Tensor) = prim::TupleConstruct(%res_data, %res_mask, %res_dims)
return (%34)