Skip to content

Commit

Permalink
chore: test consistency of rotation matrix (#4550)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **Tests**
- Updated test methods in multiple descriptor test files to return
two-element tuples instead of single-element tuples
- Modified `build_tf_descriptor` method to include rotation matrix in
the return value

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jan 14, 2025
1 parent b7effe5 commit 33df869
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def build_tf_descriptor(self, obj, natoms, coords, atype, box, suffix):
)
# ensure get_dim_out gives the correct shape
t_des = tf.reshape(t_des, [1, natoms[0], obj.get_dim_out()])
return [t_des], {
return [t_des, obj.get_rot_mat()], {
t_coord: coords,
t_type: atype,
t_natoms: natoms,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)
return (ret[0], ret[1])

@property
def rtol(self) -> float:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,4 @@ def eval_jax(self, jax_obj: Any) -> Any:
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)
return (ret[0], ret[1])
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/test_se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)
return (ret[0], ret[1])

@property
def rtol(self) -> float:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/descriptor/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
return (ret[0],)
return (ret[0], ret[1])

@property
def rtol(self) -> float:
Expand Down

0 comments on commit 33df869

Please sign in to comment.