Skip to content

Commit 95f59c9

Browse files
author
Henry Isaacson
committed
test
1 parent ffb7f57 commit 95f59c9

14 files changed

+291
-208
lines changed

src/beignet/func/_molecular_dynamics/_partition/__cell_dimensions.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def _cell_dimensions(
1616
box_size = float(box_size)
1717

1818
if box_size < minimum_cell_size:
19-
raise ValueError('Box size must be at least as large as minimum cell size.')
19+
raise ValueError("Box size must be at least as large as minimum cell size.")
2020

2121
if isinstance(box_size, Tensor):
2222
if box_size.dtype in {torch.int32, torch.int64}:
@@ -36,7 +36,9 @@ def _cell_dimensions(
3636

3737
for cells in flattened_cells_per_side:
3838
if cells.item() < 3:
39-
raise ValueError('Box must be at least 3x the size of the grid spacing in each dimension.')
39+
raise ValueError(
40+
"Box must be at least 3x the size of the grid spacing in each dimension."
41+
)
4042

4143
cell_count = functools.reduce(
4244
operator.mul,
@@ -45,16 +47,18 @@ def _cell_dimensions(
4547
)
4648

4749
elif box_size.dim() == 0:
48-
cell_count = cells_per_side ** spatial_dimension
50+
cell_count = cells_per_side**spatial_dimension
4951

5052
else:
51-
raise ValueError(f'Box must be either: a scalar, a vector, or a matrix. Found {box_size}.')
53+
raise ValueError(
54+
f"Box must be either: a scalar, a vector, or a matrix. Found {box_size}."
55+
)
5256

5357
else:
5458
cells_per_side = math.floor(box_size / minimum_cell_size)
5559

5660
cell_size = box_size / cells_per_side
5761

58-
cell_count = cells_per_side ** spatial_dimension
62+
cell_count = cells_per_side**spatial_dimension
5963

6064
return box_size, cell_size, cells_per_side, int(cell_count)

src/beignet/func/_molecular_dynamics/_partition/__cell_size.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@ def _cell_size(box: Tensor, minimum_unit_size: Tensor) -> Tensor:
88

99
else:
1010
raise ValueError("Box and minimum unit size must be of the same shape.")
11-

src/beignet/func/_molecular_dynamics/_partition/__hash_constants.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,15 @@ def _hash_constants(spatial_dimensions: int, cells_per_side: Tensor) -> Tensor:
3333
If the size of `cells_per_side` is not zero or `spatial_dimensions`.
3434
"""
3535
if cells_per_side.numel() == 1:
36-
constants = [
37-
cells_per_side ** dim for dim in range(spatial_dimensions)
38-
]
36+
constants = [cells_per_side**dim for dim in range(spatial_dimensions)]
3937
return torch.tensor([constants], dtype=torch.int32)
4038

4139
elif cells_per_side.numel() == spatial_dimensions:
4240
one = torch.tensor([[1]], dtype=torch.int32)
43-
cells_per_side = torch.cat(
44-
(one, cells_per_side[:-1].unsqueeze(0)),
45-
dim=1
46-
)
41+
cells_per_side = torch.cat((one, cells_per_side[:-1].unsqueeze(0)), dim=1)
4742
return torch.cumprod(cells_per_side.flatten(), dim=0)
4843

4944
else:
50-
raise ValueError("Cells per side must either: have 0 dimensions, be the same size as spatial dimensions.")
45+
raise ValueError(
46+
"Cells per side must either: have 0 dimensions, be the same size as spatial dimensions."
47+
)

src/beignet/func/_molecular_dynamics/_partition/__segment_sum.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,33 @@ def _segment_sum(
1212
**kwargs,
1313
) -> Tensor:
1414
"""
15-
Computes the sum of segments of a tensor along the first dimension.
16-
17-
Parameters
18-
----------
19-
input : Tensor
20-
A tensor containing the input values to be summed.
21-
indexes : Tensor
22-
A 1D tensor containing the segment indexes for summation.
23-
Should have the same length as the first dimension of the `input` tensor.
24-
n : Optional[int], optional
25-
The number of segments, by default `n` is set to `max(indexes) + 1`.
26-
27-
Returns
28-
-------
29-
Tensor
30-
A tensor where each entry contains the sum of the corresponding segment
31-
from the `input` tensor.
32-
"""
15+
Computes the sum of segments of a tensor along the first dimension.
16+
17+
Parameters
18+
----------
19+
input : Tensor
20+
A tensor containing the input values to be summed.
21+
indexes : Tensor
22+
A 1D tensor containing the segment indexes for summation.
23+
Should have the same length as the first dimension of the `input` tensor.
24+
n : Optional[int], optional
25+
The number of segments, by default `n` is set to `max(indexes) + 1`.
26+
27+
Returns
28+
-------
29+
Tensor
30+
A tensor where each entry contains the sum of the corresponding segment
31+
from the `input` tensor.
32+
"""
3333
if indexes.ndim == 1:
3434
indexes = torch.repeat_interleave(indexes, math.prod([*input.shape[1:]])).view(
3535
*[indexes.shape[0], *input.shape[1:]]
3636
)
3737

3838
if input.size(0) != indexes.size(0):
39-
raise ValueError("The length of the indexes tensor must match the size of the first dimension of the input tensor.")
39+
raise ValueError(
40+
"The length of the indexes tensor must match the size of the first dimension of the input tensor."
41+
)
4042

4143
if n is None:
4244
n = indexes.max().item() + 1

src/beignet/func/_molecular_dynamics/_partition/__shift.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,3 @@ def _shift(a: Tensor, b: Tensor) -> Tensor:
2424
"""
2525

2626
return torch.roll(a, shifts=tuple(b), dims=tuple(range(len(b))))
27-

tests/beignet/func/test__cell_dimensions.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import torch
66
from torch import Tensor
77

8-
from beignet.func._molecular_dynamics._partition.__cell_dimensions import \
9-
_cell_dimensions
8+
from beignet.func._molecular_dynamics._partition.__cell_dimensions import (
9+
_cell_dimensions,
10+
)
1011

1112

1213
@st.composite
@@ -19,8 +20,11 @@ def _cell_dimensions_strategy(draw):
1920
st.floats(min_value=3.0, max_value=10.0),
2021
st.lists(
2122
st.floats(min_value=3.0, max_value=10.0),
22-
min_size=spatial_dimension, max_size=spatial_dimension
23-
).map(torch.tensor).map(lambda x: x.float()),
23+
min_size=spatial_dimension,
24+
max_size=spatial_dimension,
25+
)
26+
.map(torch.tensor)
27+
.map(lambda x: x.float()),
2428
)
2529
)
2630

@@ -39,7 +43,9 @@ def _cell_dimensions_strategy(draw):
3943
(0, torch.tensor([100]), 10.0, AssertionError),
4044
],
4145
)
42-
def test_cell_dimensions_exceptions(spatial_dimension, box_size, minimum_cell_size, expected_exception):
46+
def test_cell_dimensions_exceptions(
47+
spatial_dimension, box_size, minimum_cell_size, expected_exception
48+
):
4349
if expected_exception is not None:
4450
with pytest.raises(expected_exception):
4551
_cell_dimensions(spatial_dimension, box_size, minimum_cell_size)
@@ -64,9 +70,7 @@ def test__cell_dimensions(data):
6470
return
6571

6672
box_size_out, cell_size, cells_per_side, cell_count = _cell_dimensions(
67-
spatial_dimension,
68-
box_size,
69-
minimum_cell_size
73+
spatial_dimension, box_size, minimum_cell_size
7074
)
7175

7276
if isinstance(box_size, (int, float)):
@@ -79,6 +83,10 @@ def test__cell_dimensions(data):
7983

8084
torch.testing.assert_allclose(box_size / cells_per_side.float(), cell_size)
8185

82-
expected_cell_count = int(torch.prod(cells_per_side).item()) if isinstance(cells_per_side, Tensor) else int(cells_per_side ** spatial_dimension)
86+
expected_cell_count = (
87+
int(torch.prod(cells_per_side).item())
88+
if isinstance(cells_per_side, Tensor)
89+
else int(cells_per_side**spatial_dimension)
90+
)
8391

8492
assert cell_count == expected_cell_count

tests/beignet/func/test__cell_size.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,31 @@
33
import pytest
44
from hypothesis import given
55

6-
from beignet.func._molecular_dynamics._partition.__cell_size import \
7-
_cell_size
6+
from beignet.func._molecular_dynamics._partition.__cell_size import _cell_size
87

98

109
@st.composite
1110
def _cell_size_strategy(draw):
1211
shape = draw(st.integers(min_value=1, max_value=10))
1312

1413
box = torch.tensor(
15-
draw(st.lists(st.floats(min_value=1.0, max_value=100.0), min_size=shape, max_size=shape)),
16-
dtype=torch.float32
14+
draw(
15+
st.lists(
16+
st.floats(min_value=1.0, max_value=100.0),
17+
min_size=shape,
18+
max_size=shape,
19+
)
20+
),
21+
dtype=torch.float32,
1722
)
1823

1924
minimum_unit_size = torch.tensor(
20-
draw(st.lists(st.floats(min_value=1.0, max_value=10.0), min_size=shape, max_size=shape)),
21-
dtype=torch.float32
25+
draw(
26+
st.lists(
27+
st.floats(min_value=1.0, max_value=10.0), min_size=shape, max_size=shape
28+
)
29+
),
30+
dtype=torch.float32,
2231
)
2332

2433
return box, minimum_unit_size

tests/beignet/func/test__hash_constants.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import torch
44
from hypothesis import given
55

6-
from beignet.func._molecular_dynamics._partition.__hash_constants import \
7-
_hash_constants
6+
from beignet.func._molecular_dynamics._partition.__hash_constants import _hash_constants
87

98

109
@st.composite
@@ -14,10 +13,13 @@ def _hash_constants_strategy(draw):
1413
cells_per_side = draw(
1514
st.one_of(
1615
st.integers(min_value=1, max_value=10).map(
17-
lambda x: torch.tensor([x], dtype=torch.int32)),
18-
st.lists(st.integers(min_value=1, max_value=10),
19-
min_size=spatial_dimensions, max_size=spatial_dimensions)
20-
.map(lambda x: torch.tensor(x, dtype=torch.int32))
16+
lambda x: torch.tensor([x], dtype=torch.int32)
17+
),
18+
st.lists(
19+
st.integers(min_value=1, max_value=10),
20+
min_size=spatial_dimensions,
21+
max_size=spatial_dimensions,
22+
).map(lambda x: torch.tensor(x, dtype=torch.int32)),
2123
)
2224
)
2325

@@ -27,15 +29,24 @@ def _hash_constants_strategy(draw):
2729
@pytest.mark.parametrize(
2830
"spatial_dimensions, cells_per_side, expected_result, expected_exception",
2931
[
30-
(3, torch.tensor([4], dtype=torch.int32),
31-
torch.tensor([[1, 4, 16]], dtype=torch.int32), None),
32-
(3, torch.tensor([4, 4, 4], dtype=torch.int32),
33-
torch.tensor([1, 4, 16], dtype=torch.int32), None),
32+
(
33+
3,
34+
torch.tensor([4], dtype=torch.int32),
35+
torch.tensor([[1, 4, 16]], dtype=torch.int32),
36+
None,
37+
),
38+
(
39+
3,
40+
torch.tensor([4, 4, 4], dtype=torch.int32),
41+
torch.tensor([1, 4, 16], dtype=torch.int32),
42+
None,
43+
),
3444
(3, torch.tensor([4, 4], dtype=torch.int32), None, ValueError),
3545
],
3646
)
37-
def test_hash_constants(spatial_dimensions, cells_per_side, expected_result,
38-
expected_exception):
47+
def test_hash_constants(
48+
spatial_dimensions, cells_per_side, expected_result, expected_exception
49+
):
3950
if expected_exception is not None:
4051
with pytest.raises(expected_exception):
4152
_hash_constants(spatial_dimensions, cells_per_side)
@@ -52,7 +63,8 @@ def test__hash_constants(data):
5263
if cells_per_side.numel() == 1:
5364
expected_result = torch.tensor(
5465
[[cells_per_side.item() ** i for i in range(spatial_dimensions)]],
55-
dtype=torch.int32)
66+
dtype=torch.int32,
67+
)
5668
else:
5769
if cells_per_side.numel() != spatial_dimensions:
5870
with pytest.raises(ValueError):
@@ -61,8 +73,10 @@ def test__hash_constants(data):
6173
return
6274

6375
augmented = torch.cat(
64-
(torch.tensor([1], dtype=torch.int32).view(1, 1),
65-
cells_per_side[:-1].view(1, -1)),
76+
(
77+
torch.tensor([1], dtype=torch.int32).view(1, 1),
78+
cells_per_side[:-1].view(1, -1),
79+
),
6680
dim=1,
6781
)
6882

tests/beignet/func/test__iota.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,25 @@ def _iota_strategy(draw):
1111
max_dimensions = 5
1212
dim = draw(st.integers(min_value=0, max_value=max_dimensions - 1))
1313

14-
shape = tuple(draw(
15-
st.lists(st.integers(min_value=1, max_value=10), min_size=1,
16-
max_size=max_dimensions)))
14+
shape = tuple(
15+
draw(
16+
st.lists(
17+
st.integers(min_value=1, max_value=10),
18+
min_size=1,
19+
max_size=max_dimensions,
20+
)
21+
)
22+
)
1723

1824
kwargs = {
19-
"dtype": draw(st.sampled_from(
20-
[torch.int32, torch.int64, torch.float32, torch.float64])),
21-
"device": draw(st.sampled_from(
22-
["cpu", "cuda"]) if torch.cuda.is_available() else st.just("cpu"))
25+
"dtype": draw(
26+
st.sampled_from([torch.int32, torch.int64, torch.float32, torch.float64])
27+
),
28+
"device": draw(
29+
st.sampled_from(["cpu", "cuda"])
30+
if torch.cuda.is_available()
31+
else st.just("cpu")
32+
),
2333
}
2434

2535
return shape, dim, kwargs
@@ -60,7 +70,7 @@ def test__iota(data):
6070
if len(shape) > 1:
6171
assert torch.equal(
6272
result.select(dim, idx),
63-
torch.tensor(idx, **kwargs).expand(*result.select(dim, idx).shape)
73+
torch.tensor(idx, **kwargs).expand(*result.select(dim, idx).shape),
6474
)
6575
else:
6676
assert result[idx].item() == idx

0 commit comments

Comments
 (0)