Skip to content

Commit

Permalink
Revert "fix: Fix frontend torch.Tensor.unfold method to output correc…
Browse files Browse the repository at this point in the history
…t dimensions"

This reverts commit 9321152.
  • Loading branch information
hmahmood24 committed Sep 15, 2024
1 parent de95d39 commit 8183a22
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,35 +805,24 @@ def new_empty(
return torch_frontend.tensor(_data)

def unfold(self, dimension, size, step):
# Ensure the dimension size is large enough for unfolding
if self.shape[dimension] < size:
raise ValueError(
f"Dimension size ({self.shape[dimension]}) is "
"smaller than the unfolding size ({size})."
)

slices = []
self_shape = tuple(self.shape)

# Create sliding window slices
for i in range(0, self_shape[dimension] - size + 1, step):
slicing = [slice(None)] * len(self.shape)
slicing[dimension] = slice(i, i + size)
slices.append(self.ivy_array[tuple(slicing)])

# Stack the slices along a new dimension at 'dimension + 1'
stacked = torch_frontend.stack(slices, dim=dimension + 1)

# Reshape the tensor to insert a new window dimension
stacked = torch_frontend.stack(slices, dim=dimension)
new_shape = list(self.shape)
num_slices = (self.shape[dimension] - size) // step + 1

# Replace size of the unfolded dimension with the number of slices
new_shape[dimension] = num_slices

# Append the window size at the end (correct behavior)
new_shape.append(size)
return stacked.reshape(new_shape)
if dimension == -1:
new_shape.insert(dimension, size)
else:
new_shape.insert(dimension + 1, size)
reshaped = stacked.reshape(new_shape)
dims = list(range(len(stacked.shape)))
dims[-2], dims[-1] = dims[-1], dims[-2]
return reshaped.permute(*dims)

def long(self, memory_format=None):
self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)
Expand Down

0 comments on commit 8183a22

Please sign in to comment.