diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py
index d9a2b6dc3459..3a28b5727c6b 100644
--- a/ivy/functional/frontends/torch/tensor.py
+++ b/ivy/functional/frontends/torch/tensor.py
@@ -805,35 +805,24 @@ def new_empty(
         return torch_frontend.tensor(_data)
 
     def unfold(self, dimension, size, step):
-        # Ensure the dimension size is large enough for unfolding
-        if self.shape[dimension] < size:
-            raise ValueError(
-                f"Dimension size ({self.shape[dimension]}) is "
-                "smaller than the unfolding size ({size})."
-            )
-
         slices = []
         self_shape = tuple(self.shape)
-
-        # Create sliding window slices
         for i in range(0, self_shape[dimension] - size + 1, step):
             slicing = [slice(None)] * len(self.shape)
             slicing[dimension] = slice(i, i + size)
             slices.append(self.ivy_array[tuple(slicing)])
-
-        # Stack the slices along a new dimension at 'dimension + 1'
-        stacked = torch_frontend.stack(slices, dim=dimension + 1)
-
-        # Reshape the tensor to insert a new window dimension
+        stacked = torch_frontend.stack(slices, dim=dimension)
         new_shape = list(self.shape)
         num_slices = (self.shape[dimension] - size) // step + 1
-
-        # Replace size of the unfolded dimension with the number of slices
         new_shape[dimension] = num_slices
-
-        # Append the window size at the end (correct behavior)
-        new_shape.append(size)
-        return stacked.reshape(new_shape)
+        if dimension == -1:
+            new_shape.insert(dimension, size)
+        else:
+            new_shape.insert(dimension + 1, size)
+        reshaped = stacked.reshape(new_shape)
+        dims = list(range(len(stacked.shape)))
+        dims[-2], dims[-1] = dims[-1], dims[-2]
+        return reshaped.permute(*dims)
 
     def long(self, memory_format=None):
         self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)
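
For reference, torch.Tensor.unfold replaces the unfolded dimension with the number of windows and appends the window size as a new last axis. A quick shape check of the target semantics the frontend should match (a sketch, assuming a local torch install; the inputs here are illustrative only):

    import torch

    x = torch.arange(1.0, 8.0)         # shape (7,)
    print(x.unfold(0, 2, 1).shape)     # torch.Size([6, 2]): 6 windows of size 2

    y = torch.randn(2, 5, 3)
    # unfolding dim 1 with size=2, step=2 gives (5 - 2) // 2 + 1 = 2 windows
    print(y.unfold(1, 2, 2).shape)     # torch.Size([2, 2, 3, 2]): window size last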
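
One edge case worth flagging, if I'm reading the axis bookkeeping right: after stack(..., dim=dimension) the window axis sits at dimension + 1, so swapping only the last two axes moves it to the end exactly when it is already second-to-last (dimension == ndim - 2, or the special-cased -1); for other axes the result would come back transposed. A free-standing sketch of a variant that moves the window axis explicitly for any axis; unfold_sketch is a hypothetical helper, not part of this patch, and it assumes ivy.stack and ivy.permute_dims from ivy's functional API:

    import ivy

    def unfold_sketch(arr, dimension, size, step):
        # hypothetical standalone variant of the frontend method above
        ndim = len(arr.shape)
        dimension = dimension % ndim  # normalize negative axes
        slices = []
        for i in range(0, arr.shape[dimension] - size + 1, step):
            slicing = [slice(None)] * ndim
            slicing[dimension] = slice(i, i + size)
            slices.append(arr[tuple(slicing)])
        # num_slices lands at `dimension`; the shrunken window axis shifts
        # to `dimension + 1`
        stacked = ivy.stack(slices, axis=dimension)
        # torch puts the window axis last, so move it there outright rather
        # than swapping only the final two axes
        perm = list(range(ndim + 1))
        perm.append(perm.pop(dimension + 1))
        return ivy.permute_dims(stacked, axes=perm)

With dimension=0, size=2, step=1 on a shape-(7,) array this should give shape (6, 2), in line with the torch reference above.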