Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Jan 12, 2024
1 parent 3021f6b commit 41c586f
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions deepmd/descriptor/se_a_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ def __init__(
nei_type = np.append(nei_type, ii * np.ones(self.sel_a[ii])) # like a mask
# self.nei_type = tf.constant(nei_type, dtype=tf.int32)
# self.nei_type = paddle.to_tensor(nei_type, dtype="int32")
self.register_buffer("nei_type", paddle.to_tensor(nei_type, dtype="int32"))
self.register_buffer(
"buffer_ntypes_spin", paddle.to_tensor(nei_type, dtype="int32")
)

nets = []
# self._pass_filter => self._filter => self._filter_lower
for type_input in range(self.ntypes):
layer = []
for type_i in range(self.ntypes):
Expand Down Expand Up @@ -536,7 +538,7 @@ def _filter_lower(
is_exclude=False,
):
"""Input env matrix, returns R.G."""
outputs_size = [1] + self.filter_neuron
outputs_size = [1, *self.filter_neuron]
# cut-out inputs
# with natom x (nei_type_i x 4)
inputs_i = paddle.slice(
Expand Down Expand Up @@ -663,10 +665,12 @@ def _filter(
nframes = 1
# natom x (nei x 4)
shape = inputs.shape
outputs_size = [1] + self.filter_neuron
outputs_size = [1, *self.filter_neuron]
outputs_size_2 = self.n_axis_neuron # 16
all_excluded = all(
[
# FIXME: the bracket '[]' is needed when convert to static model, will be
# removed when fixed.
[ # noqa
(type_input, type_i) in self.exclude_types # set()
for type_i in range(self.ntypes)
]
Expand Down

0 comments on commit 41c586f

Please sign in to comment.