Feature/920: Implemented 2D convolution #1007

Draft · wants to merge 69 commits into base: main

Changes from 38 commits

Commits (69)
d7463a4
added gepad to heat/core/signal.py
krajsek Apr 27, 2022
7b2d2b9
added convolve2d to heat/core/signal.py
krajsek Apr 27, 2022
8452fc3
refactored
krajsek Apr 27, 2022
d6d8460
refactored
krajsek Apr 27, 2022
7975267
replaced unsqueeze with reshape
krajsek Apr 27, 2022
01f813e
refactorized
krajsek Apr 27, 2022
a8025d8
refactorized
krajsek Apr 28, 2022
98b8ba0
refactorized
krajsek Apr 28, 2022
4dc9950
refactorized
krajsek Apr 28, 2022
296c4dd
added balance step for mode=full/valid
krajsek May 12, 2022
f489fe9
added balance for full and valid mode
krajsek May 25, 2022
5c4a656
init commit
shahpratham Jul 20, 2022
5b4a2b2
implemented 2d convolution with distributed kernel
shahpratham Aug 18, 2022
42632a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2022
0040f93
swap a, v when v is larger and check for different split axis
shahpratham Aug 20, 2022
f665345
Merge branch 'feature920/distributed-2D-convolution' of https://githu…
shahpratham Aug 20, 2022
d1408ae
comparison to none
shahpratham Aug 20, 2022
ba65a04
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
shahpratham Sep 3, 2022
51450d2
supported for all modes
shahpratham Sep 7, 2022
adde144
added tests for convolve2d
shahpratham Sep 9, 2022
149f345
supported non-square matrix
shahpratham Sep 11, 2022
3ea900b
used scipy to compute example for tests
shahpratham Sep 11, 2022
8ccdf3e
reformatted
shahpratham Sep 11, 2022
46e4acd
Merge branch 'main' into feature920/distributed-2D-convolution
shahpratham Sep 13, 2022
5a440d7
Merge branch 'main' into feature920/distributed-2D-convolution
shahpratham Sep 27, 2022
dc79757
Merge branch 'main' into feature920/distributed-2D-convolution
shahpratham Nov 3, 2022
3f2c492
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
shahpratham Nov 8, 2022
2ec7b15
Merge branch 'feature920/distributed-2D-convolution' of https://githu…
shahpratham Nov 8, 2022
5a51cec
manual merge convolve() from main
ClaudiaComito Mar 29, 2023
f764f8e
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Mar 29, 2023
b8c67df
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Apr 17, 2023
72bbe32
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Apr 28, 2023
df386e1
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito May 8, 2023
9fb64c8
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 May 22, 2023
f762a79
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 May 25, 2023
90fd61c
Merge branch 'main' into feature920/distributed-2D-convolution
Jun 19, 2023
73aac81
Merge branch 'main' into feature920/distributed-2D-convolution
Jun 19, 2023
ac8cb2f
found where the CI failed: in the second-last test of test_signal one…
Jun 19, 2023
35cef0b
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Jun 21, 2023
e6e5558
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 7, 2023
6f94766
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 21, 2023
3c82dc6
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 28, 2023
bb713e6
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Sep 4, 2023
31ea4af
added inputcheck to heat/core/signal.py
krajsek Sep 17, 2023
707b7f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2023
a948999
added comments
krajsek Sep 17, 2023
deb2154
refactoring
krajsek Sep 17, 2023
2d057a7
added comments to inputcheck
krajsek Sep 17, 2023
88b839b
convolve2d docstring revised
krajsek Sep 17, 2023
3fed4f4
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 Oct 5, 2023
6fac51e
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 Oct 16, 2023
7c7f4d4
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Oct 23, 2023
e659d62
set dtype to float on GPU
ClaudiaComito Nov 20, 2023
34f3264
fix device conditional statement
ClaudiaComito Nov 20, 2023
fbddaac
debugging
ClaudiaComito Nov 20, 2023
4eba58b
debugging
ClaudiaComito Nov 20, 2023
89e1b8f
debugging
ClaudiaComito Nov 20, 2023
f785188
debugging devices
ClaudiaComito Nov 20, 2023
45fc3f2
local convolution on root process
ClaudiaComito Nov 20, 2023
64a3320
debugging
ClaudiaComito Nov 20, 2023
1e4ee28
debugging
ClaudiaComito Nov 20, 2023
1e49676
update pre-commit-config
ClaudiaComito Nov 22, 2023
6c6c946
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Dec 4, 2023
0489b18
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Mar 8, 2024
04e7cd2
debugging
ClaudiaComito Mar 8, 2024
583b242
Merge branch 'main' into feature920/distributed-2D-convolution
mtar Apr 3, 2024
37d02a0
Merge branch 'main' into feature920/distributed-2D-convolution
mtar Apr 12, 2024
9d466f4
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 15, 2024
69f129d
merge latest main
ClaudiaComito Aug 19, 2024
294 changes: 288 additions & 6 deletions heat/core/signal.py
@@ -10,13 +10,62 @@
from .factories import array, zeros
import torch.nn.functional as fc

__all__ = ["convolve"]
__all__ = ["convolve", "convolve2d"]


def convgenpad(a, signal, pad, boundary, fillvalue):
"""
Adds padding to local PyTorch tensors considering the distributed scheme of the overlying DNDarray.

Parameters
----------
a : DNDarray
Overlying N-dimensional `DNDarray` signal
signal : torch.Tensor
Local PyTorch tensor to be padded
pad : list
List containing the padding per dimension
boundary: str{‘fill’, ‘wrap’, ‘symm’}, optional
A flag indicating how to handle boundaries:
'fill':
pad input arrays with fillvalue. (default)
'wrap':
circular boundary conditions.
'symm':
symmetrical boundary conditions.
fillvalue: scalar, optional
Value to fill pad input arrays with. Default is 0.
"""
dim = len(signal.shape) - 2
dime = 2 * dim - 1
dimz = 2 * dim - 2
# check if more than one rank is involved
if a.is_distributed() and a.split is not None:
# set the padding of the first rank
if a.comm.rank == 0:
pad[dime - 2 * a.split] = 0
# set the padding of the last rank
elif a.comm.rank == a.comm.size - 1:
pad[dimz - 2 * a.split] = 0
else:
pad[dime - 2 * a.split] = 0
pad[dimz - 2 * a.split] = 0

if boundary == "fill":
signal = fc.pad(signal, pad, mode="constant", value=fillvalue)
elif boundary == "wrap":
signal = fc.pad(signal, pad, mode="circular")
elif boundary == "symm":
signal = fc.pad(signal, pad, mode="reflect")
else:
raise ValueError("Only {'fill', 'wrap', 'symm'} are allowed for boundary")

return signal
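As an illustration only (not part of this diff), a minimal non-distributed call of the helper could look as follows; the pad list is handed straight to torch.nn.functional.pad, so it uses that function's last-dimension-first ordering, and the import path assumes the helper stays module-level in heat/core/signal.py as in this PR:

import heat as ht
from heat.core.signal import convgenpad

a = ht.ones((4, 4))                    # non-distributed DNDarray (split=None)
signal = a.larray.reshape(1, 1, 4, 4)  # local torch tensor in conv2d layout
# pad by one zero on every side; on a distributed DNDarray the pads along the
# split axis would instead be suppressed on interior ranks
padded = convgenpad(a, signal, [1, 1, 1, 1], "fill", 0)
print(padded.shape)                    # torch.Size([1, 1, 6, 6])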


def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
"""
Returns the discrete, linear convolution of two one-dimensional `DNDarray`s or scalars.

Parameters
----------
a : DNDarray or scalar
@@ -39,12 +88,10 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
convolution product is only given for points where the signals
overlap completely. Values outside the signal boundary have no
effect.

Examples
--------
Note how the convolution operator flips the second array
before "sliding" the two across one another:

>>> a = ht.ones(10)
>>> v = ht.arange(3).astype(ht.float)
>>> ht.convolve(a, v, mode='full')
@@ -57,15 +104,13 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
>>> v = ht.arange(3, split = 0).astype(ht.float)
>>> ht.convolve(a, v, mode='valid')
DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])

[0/3] DNDarray([3., 3., 3.])
[1/3] DNDarray([3., 3., 3.])
[2/3] DNDarray([3., 3.])
>>> a = ht.ones(10, split = 0)
>>> v = ht.arange(3, split = 0)
>>> ht.convolve(a, v)
DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)

[0/3] DNDarray([0., 1., 3., 3.])
[1/3] DNDarray([3., 3., 3., 3.])
[2/3] DNDarray([3., 3., 3., 2.])
@@ -204,3 +249,240 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
a.comm,
balanced=False,
).astype(a.dtype.torch_type())


def convolve2d(a, v, mode="full", boundary="fill", fillvalue=0):
"""
Returns the discrete, linear convolution of two two-dimensional `DNDarray`s.

Parameters
----------
a : DNDarray
Two-dimensional signal `DNDarray` of shape (N, P)
v : DNDarray
Two-dimensional filter weight `DNDarray` of shape (M, Q).
mode : {'full', 'valid', 'same'}, optional
'full':
By default, mode is 'full'. This returns the convolution at
each point of overlap, with an output shape of (N+M-1, P+Q-1). At
the end-points of the convolution, the signals do not overlap
completely, and boundary effects may be seen.
'same':
Mode 'same' returns output of the same shape as 'a'. Boundary
effects are still visible. This mode is not supported for
even-sized filter weights.
'valid':
Mode 'valid' returns output of shape (N-M+1, P-Q+1). The
convolution product is only given for points where the signals
overlap completely. Values outside the signal boundary have no
effect.
boundary: str{‘fill’, ‘wrap’, ‘symm’}, optional
A flag indicating how to handle boundaries:
'fill':
pad input arrays with fillvalue. (default)
'wrap':
circular boundary conditions.
'symm':
symmetrical boundary conditions.
fillvalue: scalar, optional
Value to fill pad input arrays with. Default is 0.

Returns
-------
out : DNDarray
Discrete, linear convolution of 'a' and 'v'.

Note: If the filter weight is too large to fit into memory, using an FFT-based convolution is recommended.

Examples
--------
>>> a = ht.ones((5, 5))
>>> v = ht.ones((3, 3))
>>> ht.convolve2d(a, v, mode='valid')
DNDarray([[9., 9., 9.],
[9., 9., 9.],
[9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)

>>> a = ht.ones((5,5), split=1)
>>> v = ht.ones((3,3), split=1)
>>> ht.convolve2d(a, v)
DNDarray([[1., 2., 3., 3., 3., 2., 1.],
[2., 4., 6., 6., 6., 4., 2.],
[3., 6., 9., 9., 9., 6., 3.],
[3., 6., 9., 9., 9., 6., 3.],
[3., 6., 9., 9., 9., 6., 3.],
[2., 4., 6., 6., 6., 4., 2.],
[1., 2., 3., 3., 3., 2., 1.]], dtype=ht.float32, device=cpu:0, split=1)
"""
if np.isscalar(a):
a = array([a])
if np.isscalar(v):
v = array([v])
if not isinstance(a, DNDarray):
try:
a = array(a)
except TypeError:
raise TypeError("non-supported type for signal: {}".format(type(a)))
if not isinstance(v, DNDarray):
try:
v = array(v)
except TypeError:
raise TypeError("non-supported type for filter: {}".format(type(v)))
promoted_type = promote_types(a.dtype, v.dtype)
a = a.astype(promoted_type)
v = v.astype(promoted_type)
Review comment (Member):
Lines 318 to 334 seem to be independent of the dimension and could maybe be wrapped into a sanitize function that is then used for any N-D convolution.
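A minimal sketch of such a dimension-independent sanitize helper (name and placement are hypothetical, using only names already imported in this module) could wrap the conversion and type-promotion steps above:

def _sanitize_conv_inputs(a, v):
    # hypothetical helper: promote scalars and array-likes to DNDarrays and
    # cast both operands to a common dtype, independent of dimensionality
    if np.isscalar(a):
        a = array([a])
    if np.isscalar(v):
        v = array([v])
    if not isinstance(a, DNDarray):
        try:
            a = array(a)
        except TypeError:
            raise TypeError("non-supported type for signal: {}".format(type(a)))
    if not isinstance(v, DNDarray):
        try:
            v = array(v)
        except TypeError:
            raise TypeError("non-supported type for filter: {}".format(type(v)))
    promoted_type = promote_types(a.dtype, v.dtype)
    return a.astype(promoted_type), v.astype(promoted_type)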


if a.shape[0] < v.shape[0] and a.shape[1] < v.shape[1]:
a, v = v, a

if len(a.shape) != 2 or len(v.shape) != 2:
raise ValueError("Only 2-dimensional input DNDarrays are allowed")
if a.shape[0] < v.shape[0] or a.shape[1] < v.shape[1]:
raise ValueError("Filter size must not be greater than the signal size")
Review comment (Contributor):
Is this still valid?

Reply (Collaborator, Author):
Yes, when a is of dimension (4, 15) and v is of dimension (5, 10).

Review comment (Member):
Maybe a more precise error message: "Filter size must not be larger in one dimension and smaller in the other", as the user might not be aware of the swapping in case of the filter size being larger in both dimensions.

if mode == "same" and v.shape[0] % 2 == 0:
raise ValueError("Mode 'same' cannot be used with even-sized kernel")
if (a.split == 0 and v.split == 1) or (a.split == 1 and v.split == 0):
raise ValueError("DNDarrays must have same axis of split")

# compute halo size
if a.split == 0 or a.split is None:
halo_size = int(v.lshape_map[0][0]) // 2
else:
halo_size = int(v.lshape_map[0][1]) // 2

# fetch halos and store them in a.halo_next/a.halo_prev
# print("qqa: ", halo_size)
a.get_halo(halo_size)
Review comment (Member):
Fetching halos is only necessary in distributed mode.
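A possible guard along those lines (illustrative only, not part of this diff):

# fetch halos only when the signal is actually distributed
if a.is_distributed():
    a.get_halo(halo_size)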


# apply halos to local array
signal = a.array_with_halos

# check if a local chunk is smaller than the filter size
if a.is_distributed() and signal.size()[0] < v.lshape_map[0][0]:
raise ValueError("Local signal chunk size is smaller than the local filter size.")

if mode == "full":
pad_0 = v.shape[1] - 1
gshape_0 = v.shape[0] + a.shape[0] - 1
pad_1 = v.shape[0] - 1
gshape_1 = v.shape[1] + a.shape[1] - 1
pad = list((pad_0, pad_0, pad_1, pad_1))
gshape = (gshape_0, gshape_1)

elif mode == "same":
pad_0 = v.shape[1] // 2
pad_1 = v.shape[0] // 2
pad = list((pad_0, pad_0, pad_1, pad_1))
gshape = (a.shape[0], a.shape[1])

elif mode == "valid":
pad = list((0,) * 4)
gshape_0 = a.shape[0] - v.shape[0] + 1
gshape_1 = a.shape[1] - v.shape[1] + 1
gshape = (gshape_0, gshape_1)

else:
raise ValueError("Only {'full', 'valid', 'same'} are allowed for mode")

# make the signal 4D for the PyTorch conv2d function (the filter weight is reshaped below)
signal = signal.reshape(1, 1, signal.shape[0], signal.shape[1])

# add padding to the borders according to mode
signal = convgenpad(a, signal, pad, boundary, fillvalue)

# flip filter for convolution as PyTorch conv2d computes correlation
v = flip(v, [0, 1])

# compute weight size
if a.split == 0 or a.split is None:
weight_size = int(v.lshape_map[0][0])
current_size = v.larray.shape[0]
else:
weight_size = int(v.lshape_map[0][1])
current_size = v.larray.shape[1]

if current_size != weight_size:
weight_shape = (int(v.lshape_map[0][0]), int(v.lshape_map[0][1]))
target = torch.zeros(weight_shape, dtype=v.larray.dtype, device=v.larray.device)
pad_size = weight_size - current_size
if v.split == 0:
target[pad_size:] = v.larray
else:
target[:, pad_size:] = v.larray
weight = target
else:
weight = v.larray

t_v = weight # stores temporary weight
weight = weight.reshape(1, 1, weight.shape[0], weight.shape[1])

if v.is_distributed():
size = v.comm.size
split_axis = v.split
for r in range(size):
rec_v = v.comm.bcast(t_v, root=r)
t_v1 = rec_v.reshape(1, 1, rec_v.shape[0], rec_v.shape[1])

# apply torch convolution operator
local_signal_filtered = fc.conv2d(signal, t_v1)

# unpack 4D result into 2D
local_signal_filtered = local_signal_filtered[0, 0, :]

# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and weight_size % 2 == 0 and a.split == 0:
local_signal_filtered = local_signal_filtered[1:, :]
if a.comm.rank != 0 and weight_size % 2 == 0 and a.split == 1:
local_signal_filtered = local_signal_filtered[:, 1:]

# accumulate filtered signal on the fly
global_signal_filtered = array(
local_signal_filtered, is_split=split_axis, device=a.device, comm=a.comm
)
if r == 0:
# initialize signal_filtered, starting point of slice
signal_filtered = zeros(
gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm
)
start_idx = 0

# accumulate relevant slice of filtered signal
# note, this is a binary operation between unevenly distributed dndarrays and will require communication, check out _operations.__binary_op()
if split_axis == 0:
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape[0]]
else:
signal_filtered += global_signal_filtered[:, start_idx : start_idx + gshape[1]]
if r != size - 1:
start_idx += v.lshape_map[r + 1][split_axis]

signal_filtered.balance()
return signal_filtered

else:
# apply torch convolution operator
signal_filtered = fc.conv2d(signal, weight)

# unpack 4D result into 2D
signal_filtered = signal_filtered[0, 0, :]

# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and v.lshape_map[0][0] % 2 == 0 and a.split == 0:
signal_filtered = signal_filtered[1:, :]
elif a.comm.rank != 0 and v.lshape_map[0][1] % 2 == 0 and a.split == 1:
signal_filtered = signal_filtered[:, 1:]

result = DNDarray(
signal_filtered.contiguous(),
gshape,
signal_filtered.dtype,
a.split,
a.device,
a.comm,
a.balanced,
).astype(a.dtype.torch_type())

if mode == "full" or mode == "valid":
result.balance()

return result
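As a side note to the docstring remark on filter weights that are too large for memory, a minimal non-distributed FFT-based alternative can be sketched with scipy, which this PR already uses in its tests; this is an illustration, not part of the diff:

import numpy as np
from scipy.signal import fftconvolve

a = np.ones((5, 5))
v = np.ones((3, 3))
out = fftconvolve(a, v, mode="valid")  # FFT-based 2D convolution
print(out)  # 3x3 array of 9.0, matching ht.convolve2d(a, v, mode='valid') above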