File tree 5 files changed +48
-0
lines changed
5 files changed +48
-0
lines changed Original file line number Diff line number Diff line change @@ -708,6 +708,20 @@ false
708
708
[torch .LongStorage of size 1 ]
709
709
```
710
710
711
+ <a name =" torch.Tensor.isSameSizeAs " />
712
+ ### [ boolean] isSameSizeAs(tensor) ###
713
+
714
+ Returns ` true ` iff the dimensions of the ` Tensor ` and the argument ` Tensor ` are exactly the same.
715
+ ``` lua
716
+ > x = torch .Tensor (4 ,5 )
717
+ > y = torch .Tensor (4 ,5 )
718
+ > = x :isSameSizeAs (y )
719
+ true
720
+ > y = torch .Tensor (4 , 6 )
721
+ > = x :isSameSizeAs (y )
722
+ false
723
+ ```
724
+
711
725
<a name =" torch.Tensor.nElement " />
712
726
### [ number] nElement() ###
713
727
Original file line number Diff line number Diff line change @@ -512,6 +512,14 @@ static int torch_Tensor_(isContiguous)(lua_State *L)
512
512
return 1 ;
513
513
}
514
514
515
+ static int torch_Tensor_ (isSameSizeAs )(lua_State * L )
516
+ {
517
+ THTensor * tensor1 = luaT_checkudata (L , 1 , torch_Tensor );
518
+ THTensor * tensor2 = luaT_checkudata (L , 2 , torch_Tensor );
519
+ lua_pushboolean (L , THTensor_ (isSameSizeAs )(tensor1 , tensor2 ));
520
+ return 1 ;
521
+ }
522
+
515
523
static int torch_Tensor_ (nElement )(lua_State * L )
516
524
{
517
525
THTensor * tensor = luaT_checkudata (L , 1 , torch_Tensor );
@@ -1148,6 +1156,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
1148
1156
{"t" , torch_Tensor_ (t )},
1149
1157
{"unfold" , torch_Tensor_ (unfold )},
1150
1158
{"isContiguous" , torch_Tensor_ (isContiguous )},
1159
+ {"isSameSizeAs" , torch_Tensor_ (isSameSizeAs )},
1151
1160
{"nElement" , torch_Tensor_ (nElement )},
1152
1161
{"copy" , torch_Tensor_ (copy )},
1153
1162
{"apply" , torch_Tensor_ (apply )},
Original file line number Diff line number Diff line change @@ -531,6 +531,19 @@ int THTensor_(isContiguous)(const THTensor *self)
531
531
return 1 ;
532
532
}
533
533
534
+ int THTensor_ (isSameSizeAs )(const THTensor * self , const THTensor * src )
535
+ {
536
+ int d ;
537
+ if (self -> nDimension != src -> nDimension )
538
+ return 0 ;
539
+ for (d = 0 ; d < self -> nDimension ; ++ d )
540
+ {
541
+ if (self -> size [d ] != src -> size [d ])
542
+ return 0 ;
543
+ }
544
+ return 1 ;
545
+ }
546
+
534
547
long THTensor_ (nElement )(const THTensor * self )
535
548
{
536
549
if (self -> nDimension == 0 )
Original file line number Diff line number Diff line change @@ -103,6 +103,7 @@ TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src);
103
103
TH_API void THTensor_ (squeeze1d )(THTensor * self , THTensor * src , int dimension_ );
104
104
105
105
TH_API int THTensor_ (isContiguous )(const THTensor * self );
106
+ TH_API int THTensor_ (isSameSizeAs )(const THTensor * self , const THTensor * src );
106
107
TH_API long THTensor_ (nElement )(const THTensor * self );
107
108
108
109
TH_API void THTensor_ (retain )(THTensor * self );
Original file line number Diff line number Diff line change @@ -1426,6 +1426,17 @@ function torchtest.view()
1426
1426
mytester :asserteq ((target_tensor - tensor ):abs ():max (), 0 , ' Error in viewAs' )
1427
1427
end
1428
1428
1429
+ function torchtest .isSameSizeAs ()
1430
+ local t1 = torch .Tensor (3 , 4 , 9 , 10 )
1431
+ local t2 = torch .Tensor (3 , 4 )
1432
+ local t3 = torch .Tensor (1 , 9 , 3 , 3 )
1433
+ local t4 = torch .Tensor (3 , 4 , 9 , 10 )
1434
+
1435
+ mytester :assert (t1 :isSameSizeAs (t2 ) == false , " wrong answer " )
1436
+ mytester :assert (t1 :isSameSizeAs (t3 ) == false , " wrong answer " )
1437
+ mytester :assert (t1 :isSameSizeAs (t4 ) == true , " wrong answer " )
1438
+ end
1439
+
1429
1440
function torch .test (tests )
1430
1441
math.randomseed (os.time ())
1431
1442
if torch .getdefaulttensortype () == ' torch.FloatTensor' then
You can’t perform that action at this time.
0 commit comments