Skip to content

Commit 64821c9

Browse files
committed
adding torch.isSameSizeAs
1 parent 7f99066 commit 64821c9

File tree

5 files changed

+48
-0
lines changed

5 files changed

+48
-0
lines changed

doc/tensor.md

+14
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,20 @@ false
708708
[torch.LongStorage of size 1]
709709
```
710710

711+
<a name="torch.Tensor.isSameSizeAs"/>
712+
### [boolean] isSameSizeAs(tensor) ###
713+
714+
Returns `true` iff the dimensions of the `Tensor` and the argument `Tensor` are exactly the same.
715+
```lua
716+
> x = torch.Tensor(4,5)
717+
> y = torch.Tensor(4,5)
718+
> = x:isSameSizeAs(y)
719+
true
720+
> y = torch.Tensor(4, 6)
721+
> = x:isSameSizeAs(y)
722+
false
723+
```
724+
711725
<a name="torch.Tensor.nElement"/>
712726
### [number] nElement() ###
713727

generic/Tensor.c

+9
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,14 @@ static int torch_Tensor_(isContiguous)(lua_State *L)
512512
return 1;
513513
}
514514

515+
static int torch_Tensor_(isSameSizeAs)(lua_State *L)
516+
{
517+
THTensor *tensor1 = luaT_checkudata(L, 1, torch_Tensor);
518+
THTensor *tensor2 = luaT_checkudata(L, 2, torch_Tensor);
519+
lua_pushboolean(L, THTensor_(isSameSizeAs)(tensor1, tensor2));
520+
return 1;
521+
}
522+
515523
static int torch_Tensor_(nElement)(lua_State *L)
516524
{
517525
THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
@@ -1148,6 +1156,7 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
11481156
{"t", torch_Tensor_(t)},
11491157
{"unfold", torch_Tensor_(unfold)},
11501158
{"isContiguous", torch_Tensor_(isContiguous)},
1159+
{"isSameSizeAs", torch_Tensor_(isSameSizeAs)},
11511160
{"nElement", torch_Tensor_(nElement)},
11521161
{"copy", torch_Tensor_(copy)},
11531162
{"apply", torch_Tensor_(apply)},

lib/TH/generic/THTensor.c

+13
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,19 @@ int THTensor_(isContiguous)(const THTensor *self)
531531
return 1;
532532
}
533533

534+
int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor* src)
535+
{
536+
int d;
537+
if (self->nDimension != src->nDimension)
538+
return 0;
539+
for(d = 0; d < self->nDimension; ++d)
540+
{
541+
if(self->size[d] != src->size[d])
542+
return 0;
543+
}
544+
return 1;
545+
}
546+
534547
long THTensor_(nElement)(const THTensor *self)
535548
{
536549
if(self->nDimension == 0)

lib/TH/generic/THTensor.h

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ TH_API void THTensor_(squeeze)(THTensor *self, THTensor *src);
103103
TH_API void THTensor_(squeeze1d)(THTensor *self, THTensor *src, int dimension_);
104104

105105
TH_API int THTensor_(isContiguous)(const THTensor *self);
106+
TH_API int THTensor_(isSameSizeAs)(const THTensor *self, const THTensor *src);
106107
TH_API long THTensor_(nElement)(const THTensor *self);
107108

108109
TH_API void THTensor_(retain)(THTensor *self);

test/test.lua

+11
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,17 @@ function torchtest.view()
14261426
mytester:asserteq((target_tensor-tensor):abs():max(), 0, 'Error in viewAs')
14271427
end
14281428

1429+
function torchtest.isSameSizeAs()
1430+
local t1 = torch.Tensor(3, 4, 9, 10)
1431+
local t2 = torch.Tensor(3, 4)
1432+
local t3 = torch.Tensor(1, 9, 3, 3)
1433+
local t4 = torch.Tensor(3, 4, 9, 10)
1434+
1435+
mytester:assert(t1:isSameSizeAs(t2) == false, "wrong answer ")
1436+
mytester:assert(t1:isSameSizeAs(t3) == false, "wrong answer ")
1437+
mytester:assert(t1:isSameSizeAs(t4) == true, "wrong answer ")
1438+
end
1439+
14291440
function torch.test(tests)
14301441
math.randomseed(os.time())
14311442
if torch.getdefaulttensortype() == 'torch.FloatTensor' then

0 commit comments

Comments
 (0)