Skip to content

Commit 70d324e

Browse files
author
Henry Isaacson
committed
cleanup
1 parent 563764c commit 70d324e

21 files changed

+1706
-1475
lines changed

src/beignet/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from ._invert_rotation_matrix import invert_rotation_matrix
3131
from ._invert_rotation_vector import invert_rotation_vector
3232
from ._invert_transform import invert_transform
33+
from ._iota import iota
34+
from ._pairwise_displacement import pairwise_displacement
3335
from ._quaternion_identity import quaternion_identity
3436
from ._quaternion_magnitude import quaternion_magnitude
3537
from ._quaternion_mean import quaternion_mean
@@ -68,6 +70,8 @@
6870
rotation_vector_to_rotation_matrix,
6971
)
7072
from ._translation_identity import translation_identity
73+
from ._segment_sum import segment_sum
74+
from ._square_distance import square_distance
7175
from .special import error_erf, error_erfc
7276

7377
__all__ = [

src/beignet/_iota.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from torch import Tensor
3+
4+
5+
def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor:
6+
r"""Generate a tensor with a specified shape where elements along the given dimension
7+
are sequential integers starting from 0.
8+
9+
Parameters
10+
----------
11+
shape : tuple[int, ...]
12+
The shape of the resulting tensor.
13+
dim : int, optional
14+
The dimension along which to vary the values (default is 0).
15+
16+
Returns
17+
-------
18+
Tensor
19+
A tensor of the specified shape with sequential integers along the specified dimension.
20+
21+
Raises
22+
------
23+
IndexError
24+
If `dim` is out of the range of `shape`.
25+
"""
26+
dimensions = []
27+
28+
for index, _ in enumerate(shape):
29+
if index != dim:
30+
dimension = 1
31+
32+
else:
33+
dimension = shape[index]
34+
35+
dimensions = [*dimensions, dimension]
36+
37+
return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)

src/beignet/_pairwise_displacement.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
from torch import Tensor
22

33

4-
def pairwise_displacement(Ra: Tensor, Rb: Tensor) -> Tensor:
4+
def pairwise_displacement(input: Tensor, other: Tensor) -> Tensor:
55
r"""Compute a matrix of pairwise displacements given two sets of positions.
66
77
Parameters
88
----------
9-
Ra : Tensor
9+
input : Tensor
1010
Vector of positions
11-
Rb : Tensor
11+
12+
other : Tensor
1213
Vector of positions
1314
1415
Returns:
15-
Tensor(shape=[spatial_dim]
16-
Matrix of displacements
16+
-------
17+
output : Tensor, shape [spatial_dimensions]
18+
Matrix of displacements
1719
"""
18-
if len(Ra.shape) != 1:
19-
msg = (
20+
if len(input.shape) != 1:
21+
message = (
2022
"Can only compute displacements between vectors. To compute "
2123
"displacements between sets of vectors use vmap or TODO."
2224
)
23-
raise ValueError(msg)
25+
raise ValueError(message)
2426

25-
if Ra.shape != Rb.shape:
26-
msg = "Can only compute displacement between vectors of equal dimension."
27-
raise ValueError(msg)
27+
if input.shape != other.shape:
28+
message = "Can only compute displacement between vectors of equal dimension."
29+
raise ValueError(message)
2830

29-
return Ra - Rb
31+
return input - other

src/beignet/_periodic_displacement.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,23 @@
22
from torch import Tensor
33

44

5-
def periodic_displacement(box: float | Tensor, dR: Tensor) -> Tensor:
5+
def periodic_displacement(input: Tensor, position: Tensor) -> Tensor:
66
r"""Wraps displacement vectors into a hypercube.
77
8-
Parameters
8+
Parameters:
99
----------
1010
box : float or Tensor
1111
Specification of hypercube size. Either:
12-
(a) float if all sides have equal length.
12+
(a) scalar if all sides have equal length.
1313
(b) Tensor of shape (spatial_dim,) if sides have different lengths.
14+
1415
dR : Tensor
1516
Matrix of displacements with shape (..., spatial_dim).
1617
17-
Returns
18+
Returns:
1819
-------
19-
Tensor
20+
output : Tensor, shape=(...)
2021
Matrix of wrapped displacements with shape (..., spatial_dim).
2122
"""
22-
distances = (
23-
torch.remainder(dR + box * torch.tensor(0.5, dtype=torch.float32), box)
24-
- torch.tensor(0.5, dtype=torch.float32) * box
25-
)
26-
return distances
23+
output = torch.remainder(position + input * 0.5, input) - 0.5 * input
24+
return output

src/beignet/_segment_sum.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import math
2+
from typing import Optional
3+
4+
import torch
5+
from torch import Tensor
6+
7+
8+
def segment_sum(
9+
input: Tensor,
10+
indexes: Tensor,
11+
n: Optional[int] = None,
12+
**kwargs,
13+
) -> Tensor:
14+
r"""Computes the sum of segments of a tensor along the first dimension.
15+
16+
Parameters
17+
----------
18+
input : Tensor
19+
A tensor containing the input values to be summed.
20+
21+
indexes : Tensor
22+
A 1D tensor containing the segment indexes for summation.
23+
Should have the same length as the first dimension of the `input` tensor.
24+
25+
n : Optional[int], optional
26+
The number of segments, by default `n` is set to `max(indexes) + 1`.
27+
28+
Returns
29+
-------
30+
Tensor
31+
A tensor where each entry contains the sum of the corresponding segment
32+
from the `input` tensor.
33+
"""
34+
if indexes.ndim == 1:
35+
indexes = torch.repeat_interleave(indexes, math.prod([*input.shape[1:]])).view(
36+
*[indexes.shape[0], *input.shape[1:]]
37+
)
38+
39+
if input.size(0) != indexes.size(0):
40+
raise ValueError(
41+
"The length of the indexes tensor must match the size of the first dimension of the input tensor."
42+
)
43+
44+
if n is None:
45+
n = indexes.max().item() + 1
46+
47+
valid_mask = indexes < n
48+
valid_indexes = indexes[valid_mask]
49+
valid_input = input[valid_mask]
50+
51+
output = torch.zeros(n, *input.shape[1:], device=input.device)
52+
53+
return output.scatter_add(0, valid_indexes, valid_input.to(torch.float32)).to(
54+
**kwargs
55+
)

src/beignet/_square_distance.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Callable, Any
2+
3+
import torch
4+
from torch import Tensor
5+
6+
7+
def _square_distance(input: Tensor) -> Tensor:
8+
"""Computes square distances.
9+
10+
Args:
11+
input: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`.
12+
Returns:
13+
Matrix of squared distances; `Tensor(shape=[...])`.
14+
"""
15+
return torch.sum(input**2, dim=-1)
16+
17+
18+
def _safe_mask(
19+
mask: Tensor, fn: Callable, operand: Tensor, placeholder: Any = 0
20+
) -> Tensor:
21+
r"""Applies a function to elements of a tensor where a mask is True, and replaces elements where the mask is False with a placeholder.
22+
23+
Parameters
24+
----------
25+
mask : Tensor
26+
A boolean tensor indicating which elements to apply the function to.
27+
fn : Callable[[Tensor], Tensor]
28+
The function to apply to the masked elements.
29+
operand : Tensor
30+
The tensor to apply the function to.
31+
placeholder : Any, optional
32+
The value to use for elements where the mask is False (default is 0).
33+
34+
Returns
35+
-------
36+
Tensor
37+
A tensor with the function applied to the masked elements and the placeholder value elsewhere.
38+
"""
39+
masked = torch.where(mask, operand, torch.tensor(0, dtype=operand.dtype))
40+
41+
return torch.where(mask, fn(masked), torch.tensor(placeholder, dtype=operand.dtype))
42+
43+
44+
def square_distance(dR: Tensor) -> Tensor:
45+
r"""Computes distances.
46+
47+
Args:
48+
dR: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`.
49+
Returns:
50+
Matrix of distances; `Tensor(shape=[...])`.
51+
"""
52+
return _safe_mask(_square_distance(dR) > 0, torch.sqrt, _square_distance(dR))

src/beignet/func/__dataclass.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@ def _set(self: dataclasses.dataclass, **kwargs):
2222
else:
2323
metadata_fields = [*metadata_fields, name]
2424

25-
def _iterate_cls(_x) -> List[Tuple]:
25+
def _iterate_cls(_x) -> list[list]:
2626
data_iterable = []
2727

2828
for k in data_fields:
29-
data_iterable.append(getattr(_x, k))
29+
data_iterable = [*data_iterable, getattr(_x, k)]
3030

3131
metadata_iterable = []
3232

3333
for k in metadata_fields:
34-
metadata_iterable.append(getattr(_x, k))
34+
metadata_iterable = [*metadata_iterable, getattr(_x, k)]
3535

3636
return [data_iterable, metadata_iterable]
3737

@@ -46,7 +46,7 @@ def _iterable_to_cls(meta, data):
4646
dataclass_cls,
4747
_iterate_cls,
4848
_iterable_to_cls,
49-
"prescient.func",
49+
"beignet.func",
5050
)
5151

5252
return dataclass_cls

0 commit comments

Comments
 (0)