Skip to content

Commit

Permalink
update CPU training and content in README
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Nov 28, 2023
1 parent a38d4f0 commit 400ece7
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

3. 安装 deepmd-kit

由于 DeepMD-kit 大量代码基于 TensorFlow 编写，暂时没有完全迁移到 PaddlePaddle 上，因此运行前需要安装 TensorFlow 2.12。

``` sh
git clone https://github.com/HydrogenSulfate/deepmd-kit.git -b add_ddle_backend
cd deepmd-kit
Expand Down Expand Up @@ -52,8 +54,11 @@ python ./custom_op_test.py
``` sh
# 进入案例目录
cd examples/water/se_e2_a
# 运行训练
# 运行 GPU 训练
dp train ./input.json
# 运行 CPU 训练(速度极慢,不推荐运行,仅作为跑通测试)
dp train ./input.json --cpu
```

### 2.3 评估
Expand Down
7 changes: 7 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ def main_parser() -> argparse.ArgumentParser:
help="Skip calculating neighbor statistics. Sel checking, automatic sel, and model compression will be disabled.",
)

parser_train.add_argument(
"--cpu",
action="store_true",
default=False,
help="Training on CPU",
)

# * freeze script ******************************************************************
parser_frz = subparsers.add_parser(
"freeze",
Expand Down
30 changes: 7 additions & 23 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ def train(
RuntimeError
if distributed training job name is wrong
"""
if kwargs.get("cpu", False):
import paddle

paddle.set_device("cpu")
print("[NOTE]", "=" * 10, "Running paddle code on CPU", "=" * 10)

run_opt = RunOptions(
init_model=init_model,
restart=restart,
Expand Down Expand Up @@ -386,10 +392,7 @@ def get_nbor_stat(jdata, rcut, one_type: bool = False):

neistat = NeighborStat(ntypes, rcut, one_type=one_type)

min_nbor_dist, max_nbor_size = neistat.get_stat(
train_data
) # 0.8854385688525511, [38 72]
# paddle: 0.8854385614395142 [38 72]
min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)

# moved from trainer.py as duplicated
# TODO: this is a simple fix but we should have a clear
Expand Down Expand Up @@ -466,25 +469,6 @@ def update_one_sel(jdata, descriptor):
"not less than %d, but you set it to %d. The accuracy"
" of your model may get worse." % (ii, tt, dd)
)
"""
descriptor:
{
'type': 'se_e2_a',
'sel': [46, 92],
'rcut_smth': 0.5,
'rcut': 6.0,
'neuron': [25, 50, 100],
'resnet_dt': False,
'axis_neuron': 16,
'seed': 1,
'activation_function': 'tanh',
'type_one_side': False,
'precision': 'default',
'trainable': True,
'exclude_types': [],
'set_davg_zero': False
}
"""
if descriptor["type"] in ("se_atten",):
descriptor["sel"] = sel = sum(sel)
return descriptor
Expand Down

0 comments on commit 400ece7

Please sign in to comment.