Skip to content

Commit

Permalink
reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
shahpratham committed Sep 11, 2022
1 parent 3ea900b commit 8ccdf3e
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions heat/core/tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def test_convolve2d(self):
dis_kernel_even = ht.arange(16, split=0).reshape((4, 4)).astype(ht.int)

with self.assertRaises(TypeError):
signal_wrong_type = [[0, 1, 2, "tre", 4]]*5
signal_wrong_type = [[0, 1, 2, "tre", 4]] * 5
ht.convolve2d(signal_wrong_type, kernel_odd)
with self.assertRaises(TypeError):
filter_wrong_type = [[ 1, "pizza", "pineapple"]]*3
filter_wrong_type = [[1, "pizza", "pineapple"]] * 3
ht.convolve2d(dis_signal, filter_wrong_type, mode="full")
with self.assertRaises(ValueError):
ht.convolve2d(dis_signal, kernel_odd, mode="invalid")
Expand All @@ -127,15 +127,21 @@ def test_convolve2d(self):
# odd kernel size
conv = ht.convolve2d(dis_signal, kernel_odd, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i, i : len(full_odd) - i], gathered))
self.assertTrue(
ht.equal(full_odd[i : len(full_odd) - i, i : len(full_odd) - i], gathered)
)

conv = ht.convolve2d(dis_signal, dis_kernel_odd, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i, i : len(full_odd) - i], gathered))
self.assertTrue(
ht.equal(full_odd[i : len(full_odd) - i, i : len(full_odd) - i], gathered)
)

conv = ht.convolve2d(signal, dis_kernel_odd, mode=mode)
gathered = manipulations.resplit(conv, axis=None)
self.assertTrue(ht.equal(full_odd[i : len(full_odd) - i, i : len(full_odd) - i], gathered))
self.assertTrue(
ht.equal(full_odd[i : len(full_odd) - i, i : len(full_odd) - i], gathered)
)

# different data types
conv = ht.convolve2d(dis_signal.astype(ht.float), kernel_odd)
Expand Down Expand Up @@ -167,8 +173,8 @@ def test_convolve2d(self):

# distributed large signal and kernel
np.random.seed(12)
np_a = np.random.randint(1000, size = (140, 250))
np_b = np.random.randint(1000, size = (39, 17))
np_a = np.random.randint(1000, size=(140, 250))
np_b = np.random.randint(1000, size=(39, 17))
sc_conv = sig.convolve2d(np_a, np_b, mode=mode)

a = ht.array(np_a, split=0, dtype=ht.int32)
Expand Down

0 comments on commit 8ccdf3e

Please sign in to comment.