Skip to content

Commit 2032767

Browse files
committed
Implement diff for XTensorVariables
1 parent 2ca8212 commit 2032767

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pytensor/xtensor/type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,15 @@ def cumsum(self, dim):
535535
def cumprod(self, dim):
536536
return px.reduction.cumprod(self, dim)
537537

538+
def diff(self, dim, n=1):
539+
"""Compute the n-th discrete difference along the given dimension."""
540+
slice1 = {dim: slice(1, None)}
541+
slice2 = {dim: slice(None, -1)}
542+
x = self
543+
for _ in range(n):
544+
x = x[slice1] - x[slice2]
545+
return x
546+
538547
# Reshaping and reorganizing
539548
# https://docs.xarray.dev/en/latest/api.html#id8
540549
def transpose(

tests/xtensor/test_indexing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,22 @@ def test_indexing_renames_into_update_variable():
491491
expected_result = x_test.copy()
492492
expected_result[idx_test] = y_test
493493
xr_assert_allclose(result, expected_result)
494+
495+
496+
@pytest.mark.parametrize("n", ["implicit", 1, 2])
497+
@pytest.mark.parametrize("dim", ["a", "b"])
498+
def test_diff(dim, n):
499+
x = xtensor(dims=("a", "b"), shape=(7, 11))
500+
if n == "implicit":
501+
out = x.diff(dim)
502+
else:
503+
out = x.diff(dim, n=n)
504+
505+
fn = xr_function([x], out)
506+
x_test = xr_arange_like(x)
507+
res = fn(x_test)
508+
if n == "implicit":
509+
expected_res = x_test.diff(dim)
510+
else:
511+
expected_res = x_test.diff(dim, n=n)
512+
xr_assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)