Feature/920: Implemented 2D convolution #1007
@@ -10,13 +10,62 @@
from .factories import array, zeros
import torch.nn.functional as fc

-__all__ = ["convolve"]
+__all__ = ["convolve", "convolve2d"]


def convgenpad(a, signal, pad, boundary, fillvalue):
    """
    Adds padding to local PyTorch tensors, taking the distribution scheme of the overlying DNDarray into account.

    Parameters
    ----------
    a : DNDarray
        Overlying N-dimensional `DNDarray` signal
    signal : torch.Tensor
        Local PyTorch tensor to be padded
    pad : list
        List containing the padding per dimension
    boundary : str {'fill', 'wrap', 'symm'}, optional
        A flag indicating how to handle boundaries:
        'fill':
            Pad input arrays with fillvalue. (default)
        'wrap':
            Circular boundary conditions.
        'symm':
            Symmetrical boundary conditions.
    fillvalue : scalar, optional
        Value to fill the padded input arrays with. Default is 0.
    """
    dim = len(signal.shape) - 2
    dime = 2 * dim - 1
    dimz = 2 * dim - 2
    # check whether more than one rank is involved
    if a.is_distributed() and a.split is not None:
        # first rank: drop the padding on the side shared with the next rank
        if a.comm.rank == 0:
            pad[dime - 2 * a.split] = 0
        # last rank: drop the padding on the side shared with the previous rank
        elif a.comm.rank == a.comm.size - 1:
            pad[dimz - 2 * a.split] = 0
        else:
            pad[dime - 2 * a.split] = 0
            pad[dimz - 2 * a.split] = 0

    if boundary == "fill":
        signal = fc.pad(signal, pad, mode="constant", value=fillvalue)
    elif boundary == "wrap":
        signal = fc.pad(signal, pad, mode="circular")
    elif boundary == "symm":
        signal = fc.pad(signal, pad, mode="reflect")
    else:
        raise ValueError("Only {'fill', 'wrap', 'symm'} are allowed for boundary")

    return signal
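
A minimal single-process sketch (not part of this diff) of how the three boundary flags map onto `torch.nn.functional.pad` modes, assuming a 4D (NCHW) tensor as `conv2d` expects; the tensor `x` and padding below are made up for illustration. One caveat worth noting, hedged: torch's 'reflect' mode mirrors without repeating the edge value, which may differ slightly from SciPy's edge-including 'symm' semantics.

import torch
import torch.nn.functional as fc

x = torch.arange(9.0).reshape(1, 1, 3, 3)
pad = [1, 1, 1, 1]  # (left, right, top, bottom) -- last dimension first

filled = fc.pad(x, pad, mode="constant", value=0)  # boundary='fill'
wrapped = fc.pad(x, pad, mode="circular")          # boundary='wrap'
mirrored = fc.pad(x, pad, mode="reflect")          # boundary='symm'
print(filled.shape, wrapped.shape, mirrored.shape)  # each torch.Size([1, 1, 5, 5])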
def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
    """
    Returns the discrete, linear convolution of two one-dimensional `DNDarray`s or scalars.

    Parameters
    ----------
    a : DNDarray or scalar

@@ -39,12 +88,10 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
        convolution product is only given for points where the signals
        overlap completely. Values outside the signal boundary have no
        effect.

    Examples
    --------
    Note how the convolution operator flips the second array
    before "sliding" the two across one another:

    >>> a = ht.ones(10)
    >>> v = ht.arange(3).astype(ht.float)
    >>> ht.convolve(a, v, mode='full')

@@ -57,15 +104,13 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
    >>> v = ht.arange(3, split = 0).astype(ht.float)
    >>> ht.convolve(a, v, mode='valid')
    DNDarray([3., 3., 3., 3., 3., 3., 3., 3.])

    [0/3] DNDarray([3., 3., 3.])
    [1/3] DNDarray([3., 3., 3.])
    [2/3] DNDarray([3., 3.])
    >>> a = ht.ones(10, split = 0)
    >>> v = ht.arange(3, split = 0)
    >>> ht.convolve(a, v)
    DNDarray([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.], dtype=ht.float32, device=cpu:0, split=0)

    [0/3] DNDarray([0., 1., 3., 3.])
    [1/3] DNDarray([3., 3., 3., 3.])
    [2/3] DNDarray([3., 3., 3., 2.])

@@ -204,3 +249,240 @@ def convolve(a: DNDarray, v: DNDarray, mode: str = "full") -> DNDarray:
        a.comm,
        balanced=False,
    ).astype(a.dtype.torch_type())
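
A quick single-process sketch (not part of this diff) of the relationship the docstring example above illustrates: it reproduces the 'full' output by padding with M-1 zeros and flipping the kernel, since PyTorch's `conv1d` actually computes cross-correlation.

import torch
import torch.nn.functional as fc

a = torch.ones(10)
v = torch.arange(3, dtype=torch.float32)
full = fc.conv1d(
    fc.pad(a.reshape(1, 1, -1), [2, 2]),  # pad by M-1 zeros for 'full' mode
    v.flip(0).reshape(1, 1, -1),          # flip: correlation -> convolution
)[0, 0]
print(full)  # tensor([0., 1., 3., 3., 3., 3., 3., 3., 3., 3., 3., 2.])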
def convolve2d(a, v, mode="full", boundary="fill", fillvalue=0):
    """
    Returns the discrete, linear convolution of two two-dimensional HeAT DNDarrays.

    Parameters
    ----------
    a : DNDarray
        Two-dimensional signal `DNDarray` of shape (N1, N2)
    v : DNDarray
        Two-dimensional filter weight `DNDarray` of shape (M1, M2)
    mode : {'full', 'valid', 'same'}, optional
        'full':
            By default, mode is 'full'. This returns the convolution at
            each point of overlap, with an output shape of (N1+M1-1, N2+M2-1).
            At the end-points of the convolution, the signals do not overlap
            completely, and boundary effects may be seen.
        'same':
            Mode 'same' returns output of shape (N1, N2). Boundary
            effects are still visible. This mode is not supported for
            even-sized filter weights.
        'valid':
            Mode 'valid' returns output of shape (N1-M1+1, N2-M2+1). The
            convolution product is only given for points where the signals
            overlap completely. Values outside the signal boundary have no
            effect.
    boundary : str {'fill', 'wrap', 'symm'}, optional
        A flag indicating how to handle boundaries:
        'fill':
            Pad input arrays with fillvalue. (default)
        'wrap':
            Circular boundary conditions.
        'symm':
            Symmetrical boundary conditions.
    fillvalue : scalar, optional
        Value to fill the padded input arrays with. Default is 0.

    Returns
    -------
    out : DNDarray
        Discrete, linear convolution of 'a' and 'v'.

    Note
    ----
    If the filter weight is too large to fit into memory, using the FFT
    for convolution is recommended.

    Examples
    --------
    >>> a = ht.ones((5, 5))
    >>> v = ht.ones((3, 3))
    >>> ht.convolve2d(a, v, mode='valid')
    DNDarray([[9., 9., 9.],
              [9., 9., 9.],
              [9., 9., 9.]], dtype=ht.float32, device=cpu:0, split=None)

    >>> a = ht.ones((5, 5), split=1)
    >>> v = ht.ones((3, 3), split=1)
    >>> ht.convolve2d(a, v)
    DNDarray([[1., 2., 3., 3., 3., 2., 1.],
              [2., 4., 6., 6., 6., 4., 2.],
              [3., 6., 9., 9., 9., 6., 3.],
              [3., 6., 9., 9., 9., 6., 3.],
              [3., 6., 9., 9., 9., 6., 3.],
              [2., 4., 6., 6., 6., 4., 2.],
              [1., 2., 3., 3., 3., 2., 1.]], dtype=ht.float32, device=cpu:0, split=1)
    """
    if np.isscalar(a):
        a = array([a])
    if np.isscalar(v):
        v = array([v])
    if not isinstance(a, DNDarray):
        try:
            a = array(a)
        except TypeError:
            raise TypeError("non-supported type for signal: {}".format(type(a)))
    if not isinstance(v, DNDarray):
        try:
            v = array(v)
        except TypeError:
            raise TypeError("non-supported type for filter: {}".format(type(v)))
    promoted_type = promote_types(a.dtype, v.dtype)
    a = a.astype(promoted_type)
    v = v.astype(promoted_type)

    if a.shape[0] < v.shape[0] and a.shape[1] < v.shape[1]:
        a, v = v, a

    if len(a.shape) != 2 or len(v.shape) != 2:
        raise ValueError("Only 2-dimensional input DNDarrays are allowed")
    if a.shape[0] < v.shape[0] or a.shape[1] < v.shape[1]:
        raise ValueError("Filter size must not be greater than the signal size")

Review comment: Is this still valid?
Review reply: yes, when
Review comment: Maybe a more precise error message: "Filter size must not be larger in one dimension and smaller in the other", as the user might not be aware of the swapping in case of the filter size being larger in both dimensions.
    if mode == "same" and v.shape[0] % 2 == 0:
        raise ValueError("Mode 'same' cannot be used with even-sized kernel")
    if (a.split == 0 and v.split == 1) or (a.split == 1 and v.split == 0):
        raise ValueError("DNDarrays must have same axis of split")

    # compute halo size
    if a.split == 0 or a.split is None:
        halo_size = int(v.lshape_map[0][0]) // 2
    else:
        halo_size = int(v.lshape_map[0][1]) // 2

    # fetch halos and store them in a.halo_next/a.halo_prev
    a.get_halo(halo_size)

Review comment: Fetching halos is only necessary in distributed mode.

    # apply halos to local array
    signal = a.array_with_halos

    # check if a local chunk is smaller than the filter size
    if a.is_distributed() and signal.size()[0] < v.lshape_map[0][0]:
        raise ValueError("Local signal chunk size is smaller than the local filter size.")
    # note: fc.pad lists the last (column) dimension first, so pad_0
    # (derived from v.shape[1]) pads columns and pad_1 pads rows
    if mode == "full":
        pad_0 = v.shape[1] - 1
        gshape_0 = v.shape[0] + a.shape[0] - 1
        pad_1 = v.shape[0] - 1
        gshape_1 = v.shape[1] + a.shape[1] - 1
        pad = list((pad_0, pad_0, pad_1, pad_1))
        gshape = (gshape_0, gshape_1)

    elif mode == "same":
        pad_0 = v.shape[1] // 2
        pad_1 = v.shape[0] // 2
        pad = list((pad_0, pad_0, pad_1, pad_1))
        gshape = (a.shape[0], a.shape[1])

    elif mode == "valid":
        pad = list((0,) * 4)
        gshape_0 = a.shape[0] - v.shape[0] + 1
        gshape_1 = a.shape[1] - v.shape[1] + 1
        gshape = (gshape_0, gshape_1)

    else:
        raise ValueError("Only {'full', 'valid', 'same'} are allowed for mode")
    # make signal and filter weight 4D for the PyTorch conv2d function
    signal = signal.reshape(1, 1, signal.shape[0], signal.shape[1])

    # add padding to the borders according to mode
    signal = convgenpad(a, signal, pad, boundary, fillvalue)

    # flip the filter for convolution, as PyTorch's conv2d computes cross-correlation
    v = flip(v, [0, 1])
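
A hedged sketch (not part of this diff) of why the flip is needed: PyTorch's `conv2d` is a cross-correlation, so flipping the kernel along both spatial axes recovers the textbook convolution. SciPy is used here only as an independent single-process reference.

import torch
import torch.nn.functional as fc
from scipy import signal

x = torch.rand(1, 1, 5, 5, dtype=torch.float64)
k = torch.rand(1, 1, 3, 3, dtype=torch.float64)

conv = fc.conv2d(x, k.flip([2, 3]))  # flipped kernel -> true convolution
ref = signal.convolve2d(x[0, 0].numpy(), k[0, 0].numpy(), mode="valid")
print(torch.allclose(conv[0, 0], torch.from_numpy(ref)))  # True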
    # compute weight size
    if a.split == 0 or a.split is None:
        weight_size = int(v.lshape_map[0][0])
        current_size = v.larray.shape[0]
    else:
        weight_size = int(v.lshape_map[0][1])
        current_size = v.larray.shape[1]

    # zero-pad smaller local filter chunks to the chunk size of rank 0
    if current_size != weight_size:
        weight_shape = (int(v.lshape_map[0][0]), int(v.lshape_map[0][1]))
        target = torch.zeros(weight_shape, dtype=v.larray.dtype, device=v.larray.device)
        pad_size = weight_size - current_size
        if v.split == 0:
            target[pad_size:] = v.larray
        else:
            target[:, pad_size:] = v.larray
        weight = target
    else:
        weight = v.larray

    t_v = weight  # stores temporary weight
    weight = weight.reshape(1, 1, weight.shape[0], weight.shape[1])
    if v.is_distributed():
        size = v.comm.size
        split_axis = v.split
        for r in range(size):
            rec_v = v.comm.bcast(t_v, root=r)
            t_v1 = rec_v.reshape(1, 1, rec_v.shape[0], rec_v.shape[1])

            # apply the torch convolution operator
            local_signal_filtered = fc.conv2d(signal, t_v1)

            # unpack the 4D result into a 2D tensor
            local_signal_filtered = local_signal_filtered[0, 0, :]

            # if the kernel shape along the split axis is even, we need to get rid of duplicated values
            if a.comm.rank != 0 and weight_size % 2 == 0 and a.split == 0:
                local_signal_filtered = local_signal_filtered[1:, :]
            if a.comm.rank != 0 and weight_size % 2 == 0 and a.split == 1:
                local_signal_filtered = local_signal_filtered[:, 1:]

            # accumulate filtered signal on the fly
            global_signal_filtered = array(
                local_signal_filtered, is_split=split_axis, device=a.device, comm=a.comm
            )
            if r == 0:
                # initialize signal_filtered, starting point of slice
                signal_filtered = zeros(
                    gshape, dtype=a.dtype, split=a.split, device=a.device, comm=a.comm
                )
                start_idx = 0

            # accumulate the relevant slice of the filtered signal
            # note: this is a binary operation between unevenly distributed DNDarrays
            # and will require communication, see _operations.__binary_op()
            if split_axis == 0:
                signal_filtered += global_signal_filtered[start_idx : start_idx + gshape[0]]
            else:
                signal_filtered += global_signal_filtered[:, start_idx : start_idx + gshape[1]]
            if r != size - 1:
                start_idx += v.lshape_map[r + 1][split_axis]

        signal_filtered.balance()
        return signal_filtered
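
A serial sketch (not part of this diff) of the identity the broadcast loop above exploits: convolution is linear in the kernel, so convolving with each rank's row-block of the already-flipped kernel and summing the results at the matching row offsets reproduces the full result. The sizes and the two-block split are made up for illustration.

import torch
import torch.nn.functional as fc

a = torch.rand(1, 1, 6, 6, dtype=torch.float64)
v = torch.rand(1, 1, 4, 3, dtype=torch.float64)
pad = [2, 2, 3, 3]                    # 'full' mode: (M2-1, M2-1, M1-1, M1-1)

w = v.flip([2, 3])                    # flip once, as the code above does
full = fc.conv2d(fc.pad(a, pad), w)   # reference: shape (1, 1, 9, 8)

acc = torch.zeros_like(full)
start = 0
for block in w.split(2, dim=2):       # two "ranks", two kernel rows each
    part = fc.conv2d(fc.pad(a, pad), block)
    acc += part[:, :, start : start + full.shape[2]]  # shift by rows seen so far
    start += block.shape[2]
print(torch.allclose(acc, full))      # True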
    else:
        # apply the torch convolution operator
        signal_filtered = fc.conv2d(signal, weight)

        # unpack the 4D result into a 2D tensor
        signal_filtered = signal_filtered[0, 0, :]

        # if the kernel shape along the split axis is even, we need to get rid of duplicated values
        if a.comm.rank != 0 and v.lshape_map[0][0] % 2 == 0 and a.split == 0:
            signal_filtered = signal_filtered[1:, :]
        elif a.comm.rank != 0 and v.lshape_map[0][1] % 2 == 0 and a.split == 1:
            signal_filtered = signal_filtered[:, 1:]

        result = DNDarray(
            signal_filtered.contiguous(),
            gshape,
            signal_filtered.dtype,
            a.split,
            a.device,
            a.comm,
            a.balanced,
        ).astype(a.dtype.torch_type())

        if mode == "full" or mode == "valid":
            result.balance()

        return result
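
For readers wanting a single-process reference for the mode/boundary semantics described in the docstring, SciPy's `convolve2d` exposes the same flags; a hedged cross-check with made-up inputs:

import numpy as np
from scipy import signal

a = np.ones((5, 5))
v = np.ones((3, 3))
# 'same' keeps the signal shape; with 'wrap', every 3x3 window sums to 9
print(signal.convolve2d(a, v, mode="same", boundary="wrap"))  # all entries 9.0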
Review comment: Lines 318 to 334 seem to be independent of the dimension and could maybe be wrapped into a sanitize function that is then used for any N-D convolution.