Feature/920: Implemented 2D convolution #1007

Closed
wants to merge 69 commits into main from feature920/distributed-2D-convolution
69 commits
d7463a4
added gepad to heat/core/signal.py
krajsek Apr 27, 2022
7b2d2b9
added convolve2d to heat/core/signal.py
krajsek Apr 27, 2022
8452fc3
refactored
krajsek Apr 27, 2022
d6d8460
refactored
krajsek Apr 27, 2022
7975267
replaced unsqueeze with reshape
krajsek Apr 27, 2022
01f813e
refactorized
krajsek Apr 27, 2022
a8025d8
refactorized
krajsek Apr 28, 2022
98b8ba0
refactorized
krajsek Apr 28, 2022
4dc9950
refactorized
krajsek Apr 28, 2022
296c4dd
added balance step for mode=full/valid
krajsek May 12, 2022
f489fe9
added balance for full and valid mode
krajsek May 25, 2022
5c4a656
init commit
shahpratham Jul 20, 2022
5b4a2b2
implemented 2d convolution with distributed kernel
shahpratham Aug 18, 2022
42632a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2022
0040f93
swap a, v when v is larger and check for different split axis
shahpratham Aug 20, 2022
f665345
Merge branch 'feature920/distributed-2D-convolution' of https://githu…
shahpratham Aug 20, 2022
d1408ae
comparison to none
shahpratham Aug 20, 2022
ba65a04
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
shahpratham Sep 3, 2022
51450d2
supported for all modes
shahpratham Sep 7, 2022
adde144
added tests for convolve2d
shahpratham Sep 9, 2022
149f345
supported non-square matrix
shahpratham Sep 11, 2022
3ea900b
used scipy to compute example for tests
shahpratham Sep 11, 2022
8ccdf3e
reformatted
shahpratham Sep 11, 2022
46e4acd
Merge branch 'main' into feature920/distributed-2D-convolution
shahpratham Sep 13, 2022
5a440d7
Merge branch 'main' into feature920/distributed-2D-convolution
shahpratham Sep 27, 2022
dc79757
Merge branch 'main' into feature920/distributed-2D-convolution
shahpratham Nov 3, 2022
3f2c492
Merge branch 'main' of https://github.com/helmholtz-analytics/heat in…
shahpratham Nov 8, 2022
2ec7b15
Merge branch 'feature920/distributed-2D-convolution' of https://githu…
shahpratham Nov 8, 2022
5a51cec
manual merge convolve() from main
ClaudiaComito Mar 29, 2023
f764f8e
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Mar 29, 2023
b8c67df
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Apr 17, 2023
72bbe32
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Apr 28, 2023
df386e1
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito May 8, 2023
9fb64c8
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 May 22, 2023
f762a79
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 May 25, 2023
90fd61c
Merge branch 'main' into feature920/distributed-2D-convolution
Jun 19, 2023
73aac81
Merge branch 'main' into feature920/distributed-2D-convolution
Jun 19, 2023
ac8cb2f
found where the CI failed: in the second-last test of test_signal one…
Jun 19, 2023
35cef0b
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Jun 21, 2023
e6e5558
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 7, 2023
6f94766
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 21, 2023
3c82dc6
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 28, 2023
bb713e6
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Sep 4, 2023
31ea4af
added inputcheck to heat/core/signal.py
krajsek Sep 17, 2023
707b7f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2023
a948999
added comments
krajsek Sep 17, 2023
deb2154
refactoring
krajsek Sep 17, 2023
2d057a7
added comments to inputcheck
krajsek Sep 17, 2023
88b839b
convolve2d docstring revised
krajsek Sep 17, 2023
3fed4f4
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 Oct 5, 2023
6fac51e
Merge branch 'main' into feature920/distributed-2D-convolution
mrfh92 Oct 16, 2023
7c7f4d4
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Oct 23, 2023
e659d62
set dtype to float on GPU
ClaudiaComito Nov 20, 2023
34f3264
fix device conditional statement
ClaudiaComito Nov 20, 2023
fbddaac
debugging
ClaudiaComito Nov 20, 2023
4eba58b
debugging
ClaudiaComito Nov 20, 2023
89e1b8f
debugging
ClaudiaComito Nov 20, 2023
f785188
debugging devices
ClaudiaComito Nov 20, 2023
45fc3f2
local convolution on root process
ClaudiaComito Nov 20, 2023
64a3320
debugging
ClaudiaComito Nov 20, 2023
1e4ee28
debugging
ClaudiaComito Nov 20, 2023
1e49676
update pre-commit-config
ClaudiaComito Nov 22, 2023
6c6c946
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Dec 4, 2023
0489b18
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Mar 8, 2024
04e7cd2
debugging
ClaudiaComito Mar 8, 2024
583b242
Merge branch 'main' into feature920/distributed-2D-convolution
mtar Apr 3, 2024
37d02a0
Merge branch 'main' into feature920/distributed-2D-convolution
mtar Apr 12, 2024
9d466f4
Merge branch 'main' into feature920/distributed-2D-convolution
ClaudiaComito Aug 15, 2024
69f129d
merge latest main
ClaudiaComito Aug 19, 2024
362 changes: 361 additions & 1 deletion heat/core/signal.py
@@ -10,7 +10,124 @@
from .factories import array, zeros
import torch.nn.functional as fc

__all__ = ["convolve"]
__all__ = ["convolve", "convolve2d"]


def convgenpad(a, signal, pad, boundary, fillvalue):
"""
Adds padding to local PyTorch tensors considering the distributed scheme of the overlying DNDarray.

Parameters
----------
a : DNDarray
Overlying N-dimensional `DNDarray` signal
signal : torch.Tensor
Local PyTorch tensor to be padded
pad: list
list containing the padding per dimension
boundary: str{'fill', 'wrap', 'symm'}, optional
A flag indicating how to handle boundaries:
'fill':
pad input arrays with fillvalue. (default)
'wrap':
circular boundary conditions.
'symm':
symmetrical boundary conditions.
fillvalue: scalar, optional
Value used to pad the input arrays when boundary is 'fill'. Default is 0.
"""
dim = len(signal.shape) - 2
dime = 2 * dim - 1
dimz = 2 * dim - 2
# check if more than one rank is involved
if a.is_distributed() and a.split is not None:
# set the padding of the first rank
if a.comm.rank == 0:
pad[dime - 2 * a.split] = 0
# set the padding of the last rank
elif a.comm.rank == a.comm.size - 1:
pad[dimz - 2 * a.split] = 0
else:
pad[dime - 2 * a.split] = 0
pad[dimz - 2 * a.split] = 0

if boundary == "fill":
signal = fc.pad(signal, pad, mode="constant", value=fillvalue)
elif boundary == "wrap":
signal = fc.pad(signal, pad, mode="circular")
elif boundary == "symm":
signal = fc.pad(signal, pad, mode="reflect")
else:
raise ValueError("Only {'fill', 'wrap', 'symm'} are allowed for boundary")

return signal
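
# Illustrative note (not part of the original patch): torch.nn.functional.pad
# orders the pad list from the last dimension inward, i.e. (left, right, top,
# bottom) for the two trailing axes of a 4D tensor; the dime/dimz arithmetic
# above therefore zeroes exactly the pad entries that face a neighbouring
# rank, whose values arrive through the halo exchange instead. For example:
# >>> import torch
# >>> import torch.nn.functional as fc
# >>> fc.pad(torch.ones(1, 1, 3, 3), [2, 2, 1, 1], mode="constant", value=0).shape
# torch.Size([1, 1, 5, 7])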


def inputcheck(a, v):
"""
Check and preprocess input data.

Parameters
----------
a : scalar, array_like, DNDarray
Input signal data.
v : scalar, array_like, DNDarray
Input filter mask.

Returns
-------
tuple
A tuple containing the processed input signal 'a' and filter mask 'v'.

Raises
------
TypeError
If 'a' or 'v' have unsupported data types.

Description
-----------
This function takes two inputs, 'a' (signal data) and 'v' (filter mask), and performs the following checks and
preprocessing steps:

1. Check if 'a' and 'v' are scalars. If they are, convert them into DNDarray arrays.

2. Check if 'a' and 'v' are instances of the 'DNDarray' class. If not, attempt to convert them into DNDarray arrays.
If conversion is not possible, raise a TypeError.

3. Determine the promoted data type for 'a' and 'v' based on their existing data types. Convert 'a' and 'v' to this
promoted data type to ensure consistent data types.

4. Return a tuple containing the processed 'a' and 'v'.
"""
# Check if 'a' is a scalar and convert to a DNDarray if necessary
if np.isscalar(a):
a = array([a])

# Check if 'v' is a scalar and convert to a DNDarray if necessary
if np.isscalar(v):
v = array([v])

# Check if 'a' is not an instance of DNDarray and try to convert it to a DNDarray
if not isinstance(a, DNDarray):
try:
a = array(a)
except TypeError:
raise TypeError(f"non-supported type for signal: {type(a)}")

# Check if 'v' is not an instance of DNDarray and try to convert it to a DNDarray
if not isinstance(v, DNDarray):
try:
v = array(v)
except TypeError:
raise TypeError(f"non-supported type for filter: {type(v)}")

# Determine the promoted data type for 'a' and 'v' and convert them to this data type
promoted_type = promote_types(a.dtype, v.dtype)
a = a.astype(promoted_type)
v = v.astype(promoted_type)

# Return the processed 'a' and 'v' as a tuple
return a, v
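
# Illustrative note (not part of the original patch): after inputcheck both
# operands are DNDarrays sharing a common, promoted dtype, e.g.:
# >>> a, v = inputcheck(2, [[0.0, 1.0], [1.0, 0.0]])
# >>> isinstance(a, DNDarray) and a.dtype == v.dtype
# True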


def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
@@ -316,3 +433,246 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
a.comm,
balanced=False,
).astype(a.dtype.torch_type())


def convolve2d(a, v, mode="full", boundary="fill", fillvalue=0):
"""
Returns the discrete, linear convolution of two two-dimensional DNDarrays.

Parameters
----------
a : scalar, array_like, DNDarray
Two-dimensional signal
v : scalar, array_like, DNDarray
Two-dimensional filter mask.
mode : {'full', 'valid', 'same'}, optional
'full':
By default, mode is 'full'. This returns the convolution at
each point of overlap, with an output shape of (N+M-1, K+L-1)
for an (N, K) signal and an (M, L) filter mask. At the end-points
of the convolution, the signals do not overlap completely, and
boundary effects may be seen.
'same':
Mode 'same' returns output of shape (N, K). Boundary
effects are still visible. This mode is not supported for
even-sized filter weights.
'valid':
Mode 'valid' returns output of shape (N-M+1, K-L+1). The
convolution product is only given for points where the signals
overlap completely. Values outside the signal boundary have no
effect.
boundary: str{'fill', 'wrap', 'symm'}, optional
A flag indicating how to handle boundaries:
'fill':
pad input arrays with fillvalue. (default)
'wrap':
circular boundary conditions.
'symm':
symmetrical boundary conditions.
fillvalue: scalar, optional
Value used to pad the input arrays when boundary is 'fill'. Default is 0.

Returns
-------
out : DNDarray
Discrete, linear convolution of 'a' and 'v'.

Note: If the filter weight is too large to fit into memory,
using the FFT for convolution is recommended.

Examples
--------
>>> a = ht.ones((5, 5))
>>> v = ht.ones((3, 3))
>>> ht.convolve2d(a, v, mode='valid')
DNDarray([[9., 9., 9.],
[9., 9., 9.],
[9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)

>>> a = ht.ones((5,5), split=1)
>>> v = ht.ones((3,3), split=1)
>>> ht.convolve2d(a, v)
DNDarray([[1., 2., 3., 3., 3., 2., 1.],
[2., 4., 6., 6., 6., 4., 2.],
[3., 6., 9., 9., 9., 6., 3.],
[3., 6., 9., 9., 9., 6., 3.],
[3., 6., 9., 9., 9., 6., 3.],
[2., 4., 6., 6., 6., 4., 2.],
[1., 2., 3., 3., 3., 2., 1.]], dtype=ht.float32, device=cpu:0, split=1)
"""
a, v = inputcheck(a, v)

if a.shape[0] < v.shape[0] and a.shape[1] < v.shape[1]:
a, v = v, a

if len(a.shape) != 2 or len(v.shape) != 2:
raise ValueError("Only 2-dimensional input DNDarrays are allowed")
if a.shape[0] < v.shape[0] or a.shape[1] < v.shape[1]:
raise ValueError("Filter size must not be larger in one dimension and smaller in the other")
if mode == "same" and v.shape[0] % 2 == 0:
raise ValueError("Mode 'same' cannot be used with even-sized kernel")
if (a.split == 0 and v.split == 1) or (a.split == 1 and v.split == 0):
raise ValueError("DNDarrays must have same axis of split")

# compute halo size
if a.split == 0 or a.split is None:
halo_size = int(v.lshape_map[0][0]) // 2
else:
halo_size = int(v.lshape_map[0][1]) // 2
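# e.g. a 3x3 filter yields halo_size 1: one row (split=0) or one column
# (split=1) is exchanged with each neighbouring rank so that windows
# crossing a chunk boundary can still be evaluated locally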

if a.is_distributed():
# fetch halos and store them in a.halo_next/a.halo_prev
a.get_halo(halo_size)
# apply halos to local array
signal = a.array_with_halos
else:
# get local array in case of non-distributed a
signal = a.larray

# check if a local chunk is smaller than the filter size along the split axis
if a.is_distributed() and signal.size()[a.split] < v.lshape_map[0][a.split]:
raise ValueError("Local signal chunk size is smaller than the local filter size.")

if mode == "full":
pad_0 = v.shape[1] - 1
gshape_0 = v.shape[0] + a.shape[0] - 1
pad_1 = v.shape[0] - 1
gshape_1 = v.shape[1] + a.shape[1] - 1
pad = list((pad_0, pad_0, pad_1, pad_1))
gshape = (gshape_0, gshape_1)

elif mode == "same":
pad_0 = v.shape[1] // 2
pad_1 = v.shape[0] // 2
pad = list((pad_0, pad_0, pad_1, pad_1))
gshape = (a.shape[0], a.shape[1])

elif mode == "valid":
pad = list((0,) * 4)
gshape_0 = a.shape[0] - v.shape[0] + 1
gshape_1 = a.shape[1] - v.shape[1] + 1
gshape = (gshape_0, gshape_1)

else:
raise ValueError("Only {'full', 'valid', 'same'} are allowed for mode")

# make signal and filter weight 4D for PyTorch conv2d function
signal = signal.reshape(1, 1, signal.shape[0], signal.shape[1])

# add padding to the borders according to mode
signal = convgenpad(a, signal, pad, boundary, fillvalue)

# flip filter for convolution as PyTorch conv2d computes correlation
v = flip(v, [0, 1])
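# cross-correlation: out[m, n] = sum_{i, j} signal[m + i, n + j] * w[i, j];
# flipping v along both axes turns this into the discrete convolution of a and v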

# compute weight size
if a.split == 0 or a.split is None:
weight_size = int(v.lshape_map[0][0])
current_size = v.larray.shape[0]
else:
weight_size = int(v.lshape_map[0][1])
current_size = v.larray.shape[1]

if current_size != weight_size:
weight_shape = (int(v.lshape_map[0][0]), int(v.lshape_map[0][1]))
target = torch.zeros(weight_shape, dtype=v.larray.dtype, device=v.larray.device)
pad_size = weight_size - current_size
if v.split == 0:
target[pad_size:] = v.larray
else:
target[:, pad_size:] = v.larray
weight = target
else:
weight = v.larray

t_v = weight.clone() # stores temporary weight
weight = weight.reshape(1, 1, weight.shape[0], weight.shape[1])

if v.is_distributed():
size = v.comm.size
rank = v.comm.rank
split_axis = v.split
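# each rank broadcasts its local kernel slice in turn; every rank convolves
# its halo-extended signal chunk with the received slice and accumulates the
# shifted partial results, so the full kernel is never gathered on one rank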
for r in range(size):
rec_v = v.comm.bcast(t_v, root=r)
if rank != r:
t_v1 = rec_v.reshape(1, 1, rec_v.shape[0], rec_v.shape[1])
else:
t_v1 = t_v.reshape(1, 1, t_v.shape[0], t_v.shape[1])
# apply torch convolution operator
local_signal_filtered = fc.conv2d(signal, t_v1)
# unpack 4D result into 2D
local_signal_filtered = local_signal_filtered[0, 0, :]

# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and weight_size % 2 == 0 and a.split == 0:
local_signal_filtered = local_signal_filtered[1:, :]
if a.comm.rank != 0 and weight_size % 2 == 0 and a.split == 1:
local_signal_filtered = local_signal_filtered[:, 1:]

# accumulate filtered signal on the fly
global_signal_filtered = array(
local_signal_filtered, is_split=split_axis, device=a.device, comm=a.comm
)
if r == 0:
# initialize signal_filtered, starting point of slice
signal_filtered = zeros(
gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm
)
start_idx = 0

# accumulate relevant slice of filtered signal
# note: this is a binary operation between unevenly distributed DNDarrays
# and will require communication, see _operations.__binary_op()
if split_axis == 0:
signal_filtered += global_signal_filtered[start_idx : start_idx + gshape[0]]
else:
signal_filtered += global_signal_filtered[:, start_idx : start_idx + gshape[1]]
if r != size - 1:
start_idx += v.lshape_map[r + 1][split_axis]

signal_filtered.balance()
return signal_filtered

else:
# apply torch convolution operator
signal_filtered = fc.conv2d(signal, weight)

# unpack 4D result into 2D
signal_filtered = signal_filtered[0, 0, :]

# if kernel shape along split axis is even we need to get rid of duplicated values
if a.comm.rank != 0 and v.lshape_map[0][0] % 2 == 0 and a.split == 0:
signal_filtered = signal_filtered[1:, :]
elif a.comm.rank != 0 and v.lshape_map[0][1] % 2 == 0 and a.split == 1:
signal_filtered = signal_filtered[:, 1:]

result = DNDarray(
signal_filtered.contiguous(),
gshape,
signal_filtered.dtype,
a.split,
a.device,
a.comm,
a.balanced,
).astype(a.dtype.torch_type())

if mode == "full" or mode == "valid":
result.balance()

return result
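
A minimal usage sketch (not part of the diff above), assuming this branch is installed as heat and launched under MPI, e.g. mpirun -n 2 python demo.py. It mirrors the PR's tests, which validate the distributed result against scipy.signal.convolve2d:

import heat as ht
import numpy as np
from scipy import signal

# distributed signal, replicated 3x3 filter (demo values)
a = ht.arange(25, dtype=ht.float32, split=0).reshape((5, 5))
v = ht.ones((3, 3), dtype=ht.float32)

# distributed 2D convolution from this PR
out = ht.convolve2d(a, v, mode="valid")

# gather to NumPy on every rank and compare with the SciPy reference
expected = signal.convolve2d(a.numpy(), v.numpy(), mode="valid")
assert np.allclose(out.numpy(), expected)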