Skip to content

Commit

Permalink
Update test to include non-uniform case
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKnodt committed Oct 6, 2022
1 parent f124dc4 commit 2603985
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
18 changes: 11 additions & 7 deletions pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,20 +1559,24 @@ def volume_centroid(self):
"""
v_idxs = self.faces_padded().split([1, 1, 1], dim=-1)
verts = self.verts_padded()

v0, v1, v2 = [torch.gather(verts, 1, idx.expand(-1, -1, 3)) for idx in v_idxs]
valid = (self.faces_padded() != -1).all(dim=-1, keepdim=True)

v0, v1, v2 = [
torch.gather(
verts,
1,
idx.where(valid, torch.zeros_like(idx)).expand(-1, -1, 3),
).where(valid, torch.zeros_like(idx, dtype=verts.dtype))
for idx in v_idxs
]

tetra_center = (v0 + v1 + v2) / 4
signed_tetra_vol = (v0 * torch.cross(v1, v2, dim=-1)).sum(
dim=-1, keepdim=True
) / 6
denom = signed_tetra_vol.sum(dim=-2)
# clamp the denominator to prevent instability for degenerate meshes.
denom = torch.where(
denom < 0,
denom.clamp(max=-1e-5),
denom.clamp(min=1e-5)
)
denom = torch.where(denom < 0, denom.clamp(max=-1e-5), denom.clamp(min=1e-5))
return (tetra_center * signed_tetra_vol).sum(dim=-2) / denom

def submeshes(
Expand Down
23 changes: 17 additions & 6 deletions tests/test_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,13 +1299,24 @@ def test_assigned_normals(self):
self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts))

def test_centroid(self):
meshes = init_simple_mesh()
# Check that it returns a valid value for multiple meshes with an inconsistent number
# of vertices
meshes.volume_centroid()

cube = init_cube_meshes()
self.assertClose(cube.volume_centroid(), torch.tensor([
[0.5] * 3,
[1.5] * 3,
[2.5] * 3,
[3.5] * 3,
]))
self.assertClose(
cube.volume_centroid(),
torch.tensor(
[
[0.5] * 3,
[1.5] * 3,
[2.5] * 3,
[3.5] * 3,
]
),
)

def test_submeshes(self):
empty_mesh = Meshes([], [])
# Four cubes with offsets [0, 1, 2, 3].
Expand Down

0 comments on commit 2603985

Please sign in to comment.