From 69f129d7144c2f9aadee99483cee21d68dc21e3a Mon Sep 17 00:00:00 2001 From: Claudia Comito <39374113+ClaudiaComito@users.noreply.github.com> Date: Mon, 19 Aug 2024 06:29:03 +0200 Subject: [PATCH] merge latest main --- heat/core/signal.py | 48 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/heat/core/signal.py b/heat/core/signal.py index 48b02dabb..e7ead1d47 100644 --- a/heat/core/signal.py +++ b/heat/core/signal.py @@ -159,10 +159,12 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray: convolution product is only given for points where the signals overlap completely. Values outside the signal boundary have no effect. + Examples -------- Note how the convolution operator flips the second array before "sliding" the two across one another: + >>> a = ht.ones(10) >>> v = ht.arange(3).astype(ht.float) >>> ht.convolve(a, v, mode='full') @@ -175,6 +177,7 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray: >>> v = ht.arange(3, split = 0).astype(ht.float) >>> ht.convolve(a, v, mode='valid') DNDarray([3., 3., 3., 3., 3., 3., 3., 3.]) + [0/3] DNDarray([3., 3., 3.]) [1/3] DNDarray([3., 3., 3.]) [2/3] DNDarray([3., 3.]) @@ -182,6 +185,7 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray: >>> v = ht.arange(3, split = 0) >>> ht.convolve(a, v) DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0) + [0/3] DNDarray([0., 1., 3., 3.]) [1/3] DNDarray([3., 3., 3., 3.]) [2/3] DNDarray([3., 3., 3., 2.]) @@ -214,7 +218,23 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray: [ 0., 40., 81., 83., 85., 87., 44.], [ 0., 0., 45., 46., 47., 48., 49.]], dtype=ht.float64, device=cpu:0, split=0) """ - a, v = inputcheck(a, v) + if np.isscalar(a): + a = array([a]) + if np.isscalar(v): + v = array([v]) + if not isinstance(a, DNDarray): + try: + a = array(a) + except TypeError: + raise TypeError(f"non-supported type for signal: {type(a)}") + if not isinstance(v, DNDarray): + try: + v = array(v) + except TypeError: + raise TypeError(f"non-supported type for filter: {type(v)}") + promoted_type = promote_types(a.dtype, v.dtype) + a = a.astype(promoted_type) + v = v.astype(promoted_type) # check if the filter is longer than the signal and swap them if necessary if v.shape[-1] > a.shape[-1]: @@ -383,7 +403,12 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray: # accumulate relevant slice of filtered signal # note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op() - signal_filtered += global_signal_filtered[start_idx : start_idx + gshape] + try: + signal_filtered += global_signal_filtered[start_idx : start_idx + gshape] + except (ValueError, TypeError): + signal_filtered = ( + signal_filtered + global_signal_filtered[start_idx : start_idx + gshape] + ) if r != size - 1: start_idx += v.lshape_map[r + 1][0].item() return signal_filtered @@ -598,12 +623,21 @@ def convolve2d(a, v, mode="full", boundary="fill", fillvalue=0): # accumulate relevant slice of filtered signal # note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op() + # print( + # "DEVICES: signal_filtered, global_signal_filtered, start_idx, gshape", + # signal_filtered.device, + # global_signal_filtered.device, + # start_idx, + # gshape, + # ) print( - "DEVICES: signal_filtered, global_signal_filtered, start_idx, gshape", - signal_filtered.device, - global_signal_filtered.device, - start_idx, - gshape, + "DEBUGGING: signal_filtered.split, global_signal_filtered.split, gshapes, lshapes", + signal_filtered.split, + global_signal_filtered.split, + signal_filtered.gshape, + global_signal_filtered.gshape, + signal_filtered.lshape, + global_signal_filtered.lshape, ) if split_axis == 0: signal_filtered += global_signal_filtered[start_idx : start_idx + gshape[0]]