Skip to content

Commit

Permalink
merge latest main
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito committed Aug 19, 2024
1 parent 9d466f4 commit 69f129d
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions heat/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,12 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
convolution product is only given for points where the signals
overlap completely. Values outside the signal boundary have no
effect.
Examples
--------
Note how the convolution operator flips the second array
before "sliding" the two across one another:
>>> a = ht.ones(10)
>>> v = ht.arange(3).astype(ht.float)
>>> ht.convolve(a, v, mode='full')
Expand All @@ -175,13 +177,15 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
>>> v = ht.arange(3, split = 0).astype(ht.float)
>>> ht.convolve(a, v, mode='valid')
DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])
[0/3] DNDarray([3., 3., 3.])
[1/3] DNDarray([3., 3., 3.])
[2/3] DNDarray([3., 3.])
>>> a = ht.ones(10, split = 0)
>>> v = ht.arange(3, split = 0)
>>> ht.convolve(a, v)
DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)
[0/3] DNDarray([0., 1., 3., 3.])
[1/3] DNDarray([3., 3., 3., 3.])
[2/3] DNDarray([3., 3., 3., 2.])
Expand Down Expand Up @@ -214,7 +218,23 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
[ 0., 40., 81., 83., 85., 87., 44.],
[ 0., 0., 45., 46., 47., 48., 49.]], dtype=ht.float64, device=cpu:0, split=0)
"""
a, v = inputcheck(a, v)
if np.isscalar(a):
a = array([a])
if np.isscalar(v):
v = array([v])
if not isinstance(a, DNDarray):
try:
a = array(a)
except TypeError:
raise TypeError(f"non-supported type for signal: {type(a)}")
if not isinstance(v, DNDarray):
try:
v = array(v)
except TypeError:
raise TypeError(f"non-supported type for filter: {type(v)}")
promoted_type = promote_types(a.dtype, v.dtype)
a = a.astype(promoted_type)
v = v.astype(promoted_type)

# check if the filter is longer than the signal and swap them if necessary
if v.shape[-1] > a.shape[-1]:
Expand Down Expand Up @@ -383,7 +403,12 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:

# accumulate relevant slice of filtered signal
# note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op()
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape]
try:
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape]
except (ValueError, TypeError):
signal_filtered = (
signal_filtered + global_signal_filtered[start_idx : start_idx + gshape]
)
if r != size - 1:
start_idx += v.lshape_map[r + 1][0].item()
return signal_filtered
Expand Down Expand Up @@ -598,12 +623,21 @@ def convolve2d(a, v, mode="full", boundary="fill", fillvalue=0):

# accumulate relevant slice of filtered signal
# note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op()
# print(
# "DEVICES: signal_filtered, global_signal_filtered, start_idx, gshape",
# signal_filtered.device,
# global_signal_filtered.device,
# start_idx,
# gshape,
# )
print(
"DEVICES: signal_filtered, global_signal_filtered, start_idx, gshape",
signal_filtered.device,
global_signal_filtered.device,
start_idx,
gshape,
"DEBUGGING: signal_filtered.split, global_signal_filtered.split, gshapes, lshapes",
signal_filtered.split,
global_signal_filtered.split,
signal_filtered.gshape,
global_signal_filtered.gshape,
signal_filtered.lshape,
global_signal_filtered.lshape,
)
if split_axis == 0:
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape[0]]
Expand Down

0 comments on commit 69f129d

Please sign in to comment.