From d03702a690743cbfc73638c0c144c0d0e96663ad Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Fri, 1 Nov 2024 11:21:00 +0800
Subject: [PATCH] fix bugs and enable more pd unitests

---
 deepmd/pd/infer/deep_eval.py                  | 365 +++++++-----------
 deepmd/pd/utils/env.py                        |   3 +-
 source/tests/pd/model/test_forward_lower.py   |   6 +-
 .../tests/pd/model/test_make_hessian_model.py |   2 +-
 .../pd/model/test_permutation_denoise.py      |   6 +-
 source/tests/pd/model/test_rot_denoise.py     |   6 +-
 source/tests/pd/model/test_smooth.py          |   6 +-
 source/tests/pd/model/test_smooth_denoise.py  |   6 +-
 source/tests/pd/model/test_trans_denoise.py   |   6 +-
 source/tests/pd/model/test_unused_params.py   |   6 +-
 10 files changed, 158 insertions(+), 254 deletions(-)

diff --git a/deepmd/pd/infer/deep_eval.py b/deepmd/pd/infer/deep_eval.py
index a8347ac7c0..d939c6ef7f 100644
--- a/deepmd/pd/infer/deep_eval.py
+++ b/deepmd/pd/infer/deep_eval.py
@@ -10,6 +10,7 @@
 import numpy as np
 import paddle
 
+from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
 from deepmd.dpmodel.output_def import (
     ModelOutputDef,
     OutputVariableCategory,
@@ -32,12 +33,18 @@
 from deepmd.infer.deep_pot import (
     DeepPot,
 )
+from deepmd.infer.deep_property import (
+    DeepProperty,
+)
 from deepmd.infer.deep_wfc import (
     DeepWFC,
 )
 from deepmd.pd.model.model import (
     get_model,
 )
+from deepmd.pd.model.network.network import (
+    TypeEmbedNetConsistent,
+)
 from deepmd.pd.train.wrapper import (
     ModelWrapper,
 )
@@ -47,9 +54,11 @@
 from deepmd.pd.utils.env import (
     DEVICE,
     GLOBAL_PD_FLOAT_PRECISION,
+    RESERVED_PRECISON_DICT,
     enable_prim,
 )
 from deepmd.pd.utils.utils import (
+    to_numpy_array,
     to_paddle_tensor,
 )
 
@@ -58,7 +67,7 @@
 
 
 class DeepEval(DeepEvalBackend):
-    """Paddle backend implementaion of DeepEval.
+    """Paddle backend implementation of DeepEval.
 
     Parameters
     ----------
@@ -85,7 +94,7 @@ def __init__(
         *args: Any,
         auto_batch_size: Union[bool, int, AutoBatchSize] = True,
         neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
-        head: Optional[str] = None,
+        head: Optional[Union[str, int]] = None,
         **kwargs: Any,
     ):
         enable_prim(True)
@@ -96,9 +105,12 @@ def __init__(
             if "model" in state_dict:
                 state_dict = state_dict["model"]
             self.input_param = state_dict["_extra_state"]["model_params"]
+            self.model_def_script = self.input_param
             self.multi_task = "model_dict" in self.input_param
             if self.multi_task:
                 model_keys = list(self.input_param["model_dict"].keys())
+                if isinstance(head, int):
+                    head = model_keys[0]
                 assert (
                     head is not None
                 ), f"Head must be set for multitask model! Available heads are: {model_keys}"
@@ -120,7 +132,6 @@ def __init__(
         else:
             # self.dp = paddle.jit.load(self.model_path.split(".json")[0])
             raise ValueError(f"Unknown model file format: {self.model_path}!")
-
         self.rcut = self.dp.model["Default"].get_rcut()
         self.type_map = self.dp.model["Default"].get_type_map()
         if isinstance(auto_batch_size, bool):
@@ -158,6 +169,9 @@ def get_dim_aparam(self) -> int:
         """Get the number (dimension) of atomic parameters of this DP."""
         return self.dp.model["Default"].get_dim_aparam()
 
+    def get_intensive(self) -> bool:
+        return self.dp.model["Default"].get_intensive()
+
     @property
     def model_type(self) -> type["DeepEvalWrapper"]:
         """The the evaluator of the model type."""
@@ -174,6 +188,8 @@ def model_type(self) -> type["DeepEvalWrapper"]:
             return DeepGlobalPolar
         elif "wfc" in model_output_type:
             return DeepWFC
+        elif "property" in model_output_type:
+            return DeepProperty
         else:
             raise RuntimeError("Unknown model type")
 
@@ -190,6 +206,10 @@ def get_numb_dos(self) -> int:
         """Get the number of DOS."""
         return self.dp.model["Default"].get_numb_dos()
 
+    def get_task_dim(self) -> int:
+        """Get the output dimension."""
+        return self.dp.model["Default"].get_task_dim()
+
     def get_has_efield(self):
         """Check if the model has efield."""
         return False
@@ -365,6 +385,7 @@ def _eval_model(
         request_defs: list[OutputVariableDef],
     ):
         model = self.dp.to(DEVICE)
+        prec = NP_PRECISION_DICT[RESERVED_PRECISON_DICT[GLOBAL_PD_FLOAT_PRECISION]]
 
         nframes = coords.shape[0]
         if len(atom_types.shape) == 1:
@@ -374,15 +395,21 @@ def _eval_model(
             natoms = len(atom_types[0])
 
         coord_input = paddle.to_tensor(
-            coords.reshape([nframes, natoms, 3]),
+            coords.reshape([nframes, natoms, 3]).astype(prec),
             dtype=GLOBAL_PD_FLOAT_PRECISION,
-        ).to(DEVICE)
-        type_input = paddle.to_tensor(atom_types, dtype=paddle.int64).to(DEVICE)
+            place=DEVICE,
+        )
+        type_input = paddle.to_tensor(
+            atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISON_DICT[paddle.int64]]),
+            dtype=paddle.int64,
+            place=DEVICE,
+        )
         if cells is not None:
             box_input = paddle.to_tensor(
                 cells.reshape([nframes, 3, 3]),
                 dtype=GLOBAL_PD_FLOAT_PRECISION,
-            ).to(DEVICE)
+                place=DEVICE,
+            )
         else:
             box_input = None
         if fparam is not None:
@@ -421,7 +448,7 @@ def _eval_model(
             else:
                 shape = self._get_output_shape(odef, nframes, natoms)
                 results.append(
-                    np.full(np.abs(shape), np.nan)  # pylint: disable=no-explicit-dtype
+                    np.full(np.abs(shape), np.nan, dtype=prec)
                 )  # this is kinda hacky
         return tuple(results)
 
@@ -447,17 +474,20 @@ def _eval_model_spin(
         coord_input = paddle.to_tensor(
             coords.reshape([nframes, natoms, 3]),
             dtype=GLOBAL_PD_FLOAT_PRECISION,
-        ).to(DEVICE)
-        type_input = paddle.to_tensor(atom_types, dtype=paddle.int64).to(DEVICE)
+            place=DEVICE,
+        )
+        type_input = paddle.to_tensor(atom_types, dtype=paddle.int64, place=DEVICE)
         spin_input = paddle.to_tensor(
             spins.reshape([nframes, natoms, 3]),
             dtype=GLOBAL_PD_FLOAT_PRECISION,
-        ).to(DEVICE)
+            place=DEVICE,
+        )
         if cells is not None:
             box_input = paddle.to_tensor(
                 cells.reshape([nframes, 3, 3]),
                 dtype=GLOBAL_PD_FLOAT_PRECISION,
-            ).to(DEVICE)
+                place=DEVICE,
+            )
         else:
             box_input = None
         if fparam is not None:
@@ -498,7 +528,13 @@ def _eval_model_spin(
             else:
                 shape = self._get_output_shape(odef, nframes, natoms)
                 results.append(
-                    np.full(np.abs(shape), np.nan)  # pylint: disable=no-explicit-dtype
+                    np.full(
+                        np.abs(shape),
+                        np.nan,
+                        dtype=NP_PRECISION_DICT[
+                            RESERVED_PRECISON_DICT[GLOBAL_PD_FLOAT_PRECISION]
+                        ],
+                    )
                 )  # this is kinda hacky
         return tuple(results)
 
@@ -523,222 +559,91 @@ def _get_output_shape(self, odef, nframes, natoms):
         else:
             raise RuntimeError("unknown category")
 
+    def eval_typeebd(self) -> np.ndarray:
+        """Evaluate output of type embedding network by using this model.
 
-# For tests only
-def eval_model(
-    model,
-    coords: Union[np.ndarray, paddle.Tensor],
-    cells: Optional[Union[np.ndarray, paddle.Tensor]],
-    atom_types: Union[np.ndarray, paddle.to_tensor, list[int]],
-    spins: Optional[Union[np.ndarray, paddle.Tensor]] = None,
-    atomic: bool = False,
-    infer_batch_size: int = 2,
-    denoise: bool = False,
-):
-    model = model.to(DEVICE)
-    energy_out = []
-    atomic_energy_out = []
-    force_out = []
-    force_mag_out = []
-    virial_out = []
-    atomic_virial_out = []
-    updated_coord_out = []
-    logits_out = []
-    err_msg = (
-        f"All inputs should be the same format, "
-        f"but found {type(coords)}, {type(cells)}, {type(atom_types)} instead! "
-    )
-    return_tensor = True
-    if isinstance(coords, paddle.Tensor):
-        if cells is not None:
-            assert isinstance(cells, paddle.Tensor), err_msg
-        if spins is not None:
-            assert isinstance(spins, paddle.Tensor), err_msg
-        assert isinstance(atom_types, paddle.Tensor) or isinstance(atom_types, list)
-        atom_types = paddle.to_tensor(atom_types, dtype=paddle.int64).to(DEVICE)
-    elif isinstance(coords, np.ndarray):
-        if cells is not None:
-            assert isinstance(cells, np.ndarray), err_msg
-        if spins is not None:
-            assert isinstance(spins, np.ndarray), err_msg
-        assert isinstance(atom_types, np.ndarray) or isinstance(atom_types, list)
-        atom_types = np.array(atom_types, dtype=np.int32)
-        return_tensor = False
-
-    nframes = coords.shape[0]
-    if len(atom_types.shape) == 1:
-        natoms = len(atom_types)
-        if isinstance(atom_types, paddle.Tensor):
-            atom_types = paddle.tile(atom_types.unsqueeze(0), [nframes, 1]).reshape(
-                [nframes, -1]
-            )
-        else:
-            atom_types = np.tile(atom_types, nframes).reshape([nframes, -1])
-    else:
-        natoms = len(atom_types[0])
-
-    coord_input = paddle.to_tensor(
-        coords.reshape([-1, natoms, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION
-    ).to(DEVICE)
-    spin_input = None
-    if spins is not None:
-        spin_input = paddle.to_tensor(
-            spins.reshape([-1, natoms, 3]),
-            dtype=GLOBAL_PD_FLOAT_PRECISION,
-        ).to(DEVICE)
-    has_spin = getattr(model, "has_spin", False)
-    if callable(has_spin):
-        has_spin = has_spin()
-    type_input = paddle.to_tensor(atom_types, dtype=paddle.int64).to(DEVICE)
-    box_input = None
-    if cells is None:
-        pbc = False
-    else:
-        pbc = True
-        box_input = paddle.to_tensor(
-            cells.reshape([-1, 3, 3]), dtype=GLOBAL_PD_FLOAT_PRECISION
-        ).to(DEVICE)
-    num_iter = int((nframes + infer_batch_size - 1) / infer_batch_size)
-
-    for ii in range(num_iter):
-        batch_coord = coord_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        batch_atype = type_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        batch_box = None
-        batch_spin = None
-        if spin_input is not None:
-            batch_spin = spin_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        if pbc:
-            batch_box = box_input[ii * infer_batch_size : (ii + 1) * infer_batch_size]
-        input_dict = {
-            "coord": batch_coord,
-            "atype": batch_atype,
-            "box": batch_box,
-            "do_atomic_virial": atomic,
-        }
-        if has_spin:
-            input_dict["spin"] = batch_spin
-        batch_output = model(**input_dict)
-        if isinstance(batch_output, tuple):
-            batch_output = batch_output[0]
-        if not return_tensor:
-            if "energy" in batch_output:
-                energy_out.append(batch_output["energy"].numpy())
-            if "atom_energy" in batch_output:
-                atomic_energy_out.append(batch_output["atom_energy"].numpy())
-            if "force" in batch_output:
-                force_out.append(batch_output["force"].numpy())
-            if "force_mag" in batch_output:
-                force_mag_out.append(batch_output["force_mag"].numpy())
-            if "virial" in batch_output:
-                virial_out.append(batch_output["virial"].numpy())
-            if "atom_virial" in batch_output:
-                atomic_virial_out.append(batch_output["atom_virial"].numpy())
-            if "updated_coord" in batch_output:
-                updated_coord_out.append(batch_output["updated_coord"].numpy())
-            if "logits" in batch_output:
-                logits_out.append(batch_output["logits"].numpy())
-        else:
-            if "energy" in batch_output:
-                energy_out.append(batch_output["energy"])
-            if "atom_energy" in batch_output:
-                atomic_energy_out.append(batch_output["atom_energy"])
-            if "force" in batch_output:
-                force_out.append(batch_output["force"])
-            if "force_mag" in batch_output:
-                force_mag_out.append(batch_output["force_mag"])
-            if "virial" in batch_output:
-                virial_out.append(batch_output["virial"])
-            if "atom_virial" in batch_output:
-                atomic_virial_out.append(batch_output["atom_virial"])
-            if "updated_coord" in batch_output:
-                updated_coord_out.append(batch_output["updated_coord"])
-            if "logits" in batch_output:
-                logits_out.append(batch_output["logits"])
-    if not return_tensor:
-        energy_out = (
-            np.concatenate(energy_out) if energy_out else np.zeros([nframes, 1])  # pylint: disable=no-explicit-dtype
-        )
-        atomic_energy_out = (
-            np.concatenate(atomic_energy_out)
-            if atomic_energy_out
-            else np.zeros([nframes, natoms, 1])  # pylint: disable=no-explicit-dtype
-        )
-        force_out = (
-            np.concatenate(force_out) if force_out else np.zeros([nframes, natoms, 3])  # pylint: disable=no-explicit-dtype
-        )
-        force_mag_out = (
-            np.concatenate(force_mag_out)
-            if force_mag_out
-            else np.zeros([nframes, natoms, 3])  # pylint: disable=no-explicit-dtype
-        )
-        virial_out = (
-            np.concatenate(virial_out) if virial_out else np.zeros([nframes, 3, 3])  # pylint: disable=no-explicit-dtype
-        )
-        atomic_virial_out = (
-            np.concatenate(atomic_virial_out)
-            if atomic_virial_out
-            else np.zeros([nframes, natoms, 3, 3])  # pylint: disable=no-explicit-dtype
-        )
-        updated_coord_out = (
-            np.concatenate(updated_coord_out) if updated_coord_out else None
-        )
-        logits_out = np.concatenate(logits_out) if logits_out else None
-    else:
-        energy_out = (
-            paddle.concat(energy_out)
-            if energy_out
-            else paddle.zeros([nframes, 1], dtype=GLOBAL_PD_FLOAT_PRECISION).to(DEVICE)
-        )
-        atomic_energy_out = (
-            paddle.concat(atomic_energy_out)
-            if atomic_energy_out
-            else paddle.zeros([nframes, natoms, 1], dtype=GLOBAL_PD_FLOAT_PRECISION).to(
-                DEVICE
-            )
-        )
-        force_out = (
-            paddle.concat(force_out)
-            if force_out
-            else paddle.zeros([nframes, natoms, 3], dtype=GLOBAL_PD_FLOAT_PRECISION).to(
-                DEVICE
-            )
-        )
-        force_mag_out = (
-            paddle.concat(force_mag_out)
-            if force_mag_out
-            else paddle.zeros([nframes, natoms, 3], dtype=GLOBAL_PD_FLOAT_PRECISION).to(
-                DEVICE
-            )
-        )
-        virial_out = (
-            paddle.concat(virial_out)
-            if virial_out
-            else paddle.zeros([nframes, 3, 3], dtype=GLOBAL_PD_FLOAT_PRECISION).to(
-                DEVICE
-            )
-        )
-        atomic_virial_out = (
-            paddle.concat(atomic_virial_out)
-            if atomic_virial_out
-            else paddle.zeros(
-                [nframes, natoms, 3, 3], dtype=GLOBAL_PD_FLOAT_PRECISION
-            ).to(DEVICE)
-        )
-        updated_coord_out = (
-            paddle.concat(updated_coord_out) if updated_coord_out else None
+        Returns
+        -------
+        np.ndarray
+            The output of type embedding network. The shape is [ntypes, o_size] or [ntypes + 1, o_size],
+            where ntypes is the number of types, and o_size is the number of nodes
+            in the output layer. If there are multiple type embedding networks,
+            these outputs will be concatenated along the second axis.
+
+        Raises
+        ------
+        KeyError
+            If the model does not enable type embedding.
+
+        See Also
+        --------
+        deepmd.pd.model.network.network.TypeEmbedNetConsistent :
+            The type embedding network.
+        """
+        out = []
+        for mm in self.dp.model["Default"].modules():
+            if mm.original_name == TypeEmbedNetConsistent.__name__:
+                out.append(mm(DEVICE))
+        if not out:
+            raise KeyError("The model has no type embedding networks.")
+        typeebd = paddle.concat(out, axis=1)
+        return to_numpy_array(typeebd)
+
+    def get_model_def_script(self) -> str:
+        """Get model definition script."""
+        return self.model_def_script
+
+    def eval_descriptor(
+        self,
+        coords: np.ndarray,
+        cells: Optional[np.ndarray],
+        atom_types: np.ndarray,
+        fparam: Optional[np.ndarray] = None,
+        aparam: Optional[np.ndarray] = None,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Evaluate descriptors by using this DP.
+
+        Parameters
+        ----------
+        coords
+            The coordinates of atoms.
+            The array should be of size nframes x natoms x 3
+        cells
+            The cell of the region.
+            If None then non-PBC is assumed, otherwise using PBC.
+            The array should be of size nframes x 9
+        atom_types
+            The atom types
+            The list should contain natoms ints
+        fparam
+            The frame parameter.
+            The array can be of size :
+            - nframes x dim_fparam.
+            - dim_fparam. Then all frames are assumed to be provided with the same fparam.
+        aparam
+            The atomic parameter
+            The array can be of size :
+            - nframes x natoms x dim_aparam.
+            - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
+            - dim_aparam. Then all frames and atoms are provided with the same aparam.
+
+        Returns
+        -------
+        descriptor
+            Descriptors.
+        """
+        model = self.dp.model["Default"]
+        model.set_eval_descriptor_hook(True)
+        self.eval(
+            coords,
+            cells,
+            atom_types,
+            atomic=False,
+            fparam=fparam,
+            aparam=aparam,
+            **kwargs,
         )
-        logits_out = paddle.concat(logits_out) if logits_out else None
-    if denoise:
-        return updated_coord_out, logits_out
-    else:
-        results_dict = {
-            "energy": energy_out,
-            "force": force_out,
-            "virial": virial_out,
-        }
-        if has_spin:
-            results_dict["force_mag"] = force_mag_out
-        if atomic:
-            results_dict["atom_energy"] = atomic_energy_out
-            results_dict["atom_virial"] = atomic_virial_out
-        return results_dict
+        descriptor = model.eval_descriptor()
+        model.set_eval_descriptor_hook(False)
+        return to_numpy_array(descriptor)
diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py
index 37b6259b61..e9593d4c50 100644
--- a/deepmd/pd/utils/env.py
+++ b/deepmd/pd/utils/env.py
@@ -15,8 +15,6 @@
     set_default_nthreads,
 )
 
-log = logging.getLogger(__name__)
-
 SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False)
 try:
     # only linux
@@ -87,6 +85,7 @@ def enable_prim(enable: bool = True):
 
         core.set_prim_eager_enabled(True)
         core._set_prim_all_enabled(True)
+        log = logging.getLogger(__name__)
         log.info("Enable prim in eager and static mode.")
 
 
diff --git a/source/tests/pd/model/test_forward_lower.py b/source/tests/pd/model/test_forward_lower.py
index efd0b638d8..dc348b5a37 100644
--- a/source/tests/pd/model/test_forward_lower.py
+++ b/source/tests/pd/model/test_forward_lower.py
@@ -5,9 +5,6 @@
 import numpy as np
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -22,6 +19,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pd/model/test_make_hessian_model.py b/source/tests/pd/model/test_make_hessian_model.py
index 30171342aa..79a7c4f163 100644
--- a/source/tests/pd/model/test_make_hessian_model.py
+++ b/source/tests/pd/model/test_make_hessian_model.py
@@ -137,7 +137,7 @@ def ff(xx):
         )
 
 
-@unittest.skip("TODO")
+@unittest.skip("Skip temporarily")
 class TestDPModel(unittest.TestCase, HessianTest):
     def setUp(self):
         paddle.seed(2)
diff --git a/source/tests/pd/model/test_permutation_denoise.py b/source/tests/pd/model/test_permutation_denoise.py
index f147e360f7..0f3dc9e871 100644
--- a/source/tests/pd/model/test_permutation_denoise.py
+++ b/source/tests/pd/model/test_permutation_denoise.py
@@ -5,9 +5,6 @@
 import numpy as np
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -18,6 +15,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pd/model/test_rot_denoise.py b/source/tests/pd/model/test_rot_denoise.py
index bd1c858339..9526084efe 100644
--- a/source/tests/pd/model/test_rot_denoise.py
+++ b/source/tests/pd/model/test_rot_denoise.py
@@ -5,9 +5,6 @@
 import numpy as np
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -18,6 +15,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation_denoise import (
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pd/model/test_smooth.py b/source/tests/pd/model/test_smooth.py
index 7ad7152b60..796b15faf4 100644
--- a/source/tests/pd/model/test_smooth.py
+++ b/source/tests/pd/model/test_smooth.py
@@ -5,9 +5,6 @@
 import numpy as np
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -18,6 +15,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (  # model_dpau,
     model_dos,
     model_dpa1,
diff --git a/source/tests/pd/model/test_smooth_denoise.py b/source/tests/pd/model/test_smooth_denoise.py
index 1563981e96..d94f15863d 100644
--- a/source/tests/pd/model/test_smooth_denoise.py
+++ b/source/tests/pd/model/test_smooth_denoise.py
@@ -5,9 +5,6 @@
 import numpy as np
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -18,6 +15,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation_denoise import (
     model_dpa2,
 )
diff --git a/source/tests/pd/model/test_trans_denoise.py b/source/tests/pd/model/test_trans_denoise.py
index 600a96ad8e..8317d4d2ae 100644
--- a/source/tests/pd/model/test_trans_denoise.py
+++ b/source/tests/pd/model/test_trans_denoise.py
@@ -5,9 +5,6 @@
 import numpy as np
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -18,6 +15,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation_denoise import (
     model_dpa1,
     model_dpa2,
diff --git a/source/tests/pd/model/test_unused_params.py b/source/tests/pd/model/test_unused_params.py
index e634ecb022..3424e9dafa 100644
--- a/source/tests/pd/model/test_unused_params.py
+++ b/source/tests/pd/model/test_unused_params.py
@@ -4,9 +4,6 @@
 
 import paddle
 
-from deepmd.pd.infer.deep_eval import (
-    eval_model,
-)
 from deepmd.pd.model.model import (
     get_model,
 )
@@ -17,6 +14,9 @@
 from ...seed import (
     GLOBAL_SEED,
 )
+from ..common import (
+    eval_model,
+)
 from .test_permutation import (
     model_dpa2,
 )