Skip to content

Commit

Permalink
Merge pull request #6 from sxm13/main
Browse files Browse the repository at this point in the history
add "keep_connect"
  • Loading branch information
sxm13 authored Sep 24, 2024
2 parents 6ea18e6 + ce8efac commit f410ce4
Show file tree
Hide file tree
Showing 8 changed files with 3,078 additions and 39 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ pip install -r requirements.txt

```sh
from PACMANCharge import pmcharge
pmcharge.predict(cif_file="./test/Cu-BTC.cif",charge_type="DDEC6",digits=10,atom_type=True,neutral=True)
pmcharge.predict(cif_file="./test/Cu-BTC.cif",charge_type="DDEC6",digits=10,atom_type=True,neutral=True,keep_connect=True)

```
**Terminal**
```sh
python pmcharge.py folder-name[path] --charge_type[DDEC6/Bader/CM5/REPEAT] --digits[int] --atom_type[bool] --neutral[bool]
python pmcharge.py folder-name[path] --charge_type[DDEC6/Bader/CM5/REPEAT] --digits[int] --atom_type[bool] --neutral[bool] --keep_connect[bool]
```
**Example command:** ```python pmcharge.py test_file/test-1/ --charge_type DDEC6 --digits 10```

Expand All @@ -53,6 +53,7 @@ python pmcharge.py folder-name[path] --charge_type[DDEC6/Bader/CM5/REPEAT] --dig
* digits (default: 6): number of decimal places to print for partial atomic charges. ML models were trained on a 6-digit dataset
* atom_type (default: True): assign the same partial atomic charge to atoms of the same type (atoms are grouped by the similarity of their partial atomic charges up to 3 decimal places)
* neutral (default: True): keep the net charge at zero. We use the "mean" method to neutralize the system, where the excess charge is distributed equally across all atoms
* keep_connect (default: True): retain the atomic and connection information (such as _atom_site_adp_type, bond) for the structure.

# Website & Zenodo
* Predict partial atomic charges using an online APP :point_right: [link](https://pacman-charge-mtap.streamlit.app/)
Expand Down
37 changes: 34 additions & 3 deletions model/cif2data.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import re
import glob
import json
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from CifFile import ReadCif
import pymatgen.core as mg
from ase.io import read
from ase import neighborlist
Expand Down Expand Up @@ -207,6 +209,14 @@ def get_ddec_data(root_cif_dir,dataset_csv,save_ddec_dir):
np.save(save_ddec_dir + mof + '.npy', ddec_data)
f.close()

def get_ddec_data_from_qmof(root_cif_dir, dataset_csv, save_ddec_dir):
    """Export the DDEC charge column of each listed CIF to a ``.npy`` file.

    Args:
        root_cif_dir: Prefix prepended to ``<name>.cif``; it is concatenated
            as a plain string, so it must already end with a path separator.
        dataset_csv: CSV file whose ``name`` column lists the structures.
        save_ddec_dir: Prefix prepended to ``<name>.npy`` for the output,
            concatenated the same way as ``root_cif_dir``.
    """
    names = pd.read_csv(dataset_csv)["name"]
    for entry in names:
        cif_path = "".join((root_cif_dir, entry, ".cif"))
        parsed = ReadCif(cif_path)
        # Only the first data block listed by the parsed file is consulted.
        first_key = parsed.keys()[0]
        block = parsed[first_key]
        ddec_column = block["_atom_site_pbe_ddec_charge"]
        out_path = "".join((save_ddec_dir, entry, ".npy"))
        np.save(out_path, ddec_column)


def get_bader_data(root_cif_dir,dataset_csv,save_bader_dir):
mofs = pd.read_csv(dataset_csv)["name"]
for mof in mofs:
Expand Down Expand Up @@ -234,7 +244,14 @@ def get_bader_data(root_cif_dir,dataset_csv,save_bader_dir):
f.close()
except:
pass


def get_bader_data_from_qmof(root_cif_dir, dataset_csv, save_bader_dir):
    """Export the Bader charge column of each listed CIF to a ``.npy`` file.

    Args:
        root_cif_dir: Prefix prepended to ``<name>.cif``; it is concatenated
            as a plain string, so it must already end with a path separator.
        dataset_csv: CSV file whose ``name`` column lists the structures.
        save_bader_dir: Prefix prepended to ``<name>.npy`` for the output,
            concatenated the same way as ``root_cif_dir``.
    """
    names = pd.read_csv(dataset_csv)["name"]
    for entry in names:
        cif_path = "".join((root_cif_dir, entry, ".cif"))
        parsed = ReadCif(cif_path)
        # Only the first data block listed by the parsed file is consulted.
        first_key = parsed.keys()[0]
        block = parsed[first_key]
        bader_column = block["_atom_site_pbe_bader_charge"]
        out_path = "".join((save_bader_dir, entry, ".npy"))
        np.save(out_path, bader_column)

def get_cm5_data(root_cif_dir,dataset_csv,save_cm5_dir):
mofs = pd.read_csv(dataset_csv)["name"]
for mof in mofs:
Expand Down Expand Up @@ -263,6 +280,13 @@ def get_cm5_data(root_cif_dir,dataset_csv,save_cm5_dir):
except:
pass

def get_cm5_data_from_qmof(root_cif_dir, dataset_csv, save_cm5_dir):
    """Export the CM5 charge column of each listed CIF to a ``.npy`` file.

    Args:
        root_cif_dir: Prefix prepended to ``<name>.cif``; it is concatenated
            as a plain string, so it must already end with a path separator.
        dataset_csv: CSV file whose ``name`` column lists the structures.
        save_cm5_dir: Prefix prepended to ``<name>.npy`` for the output,
            concatenated the same way as ``root_cif_dir``.
    """
    names = pd.read_csv(dataset_csv)["name"]
    for entry in names:
        cif_path = "".join((root_cif_dir, entry, ".cif"))
        parsed = ReadCif(cif_path)
        # Only the first data block listed by the parsed file is consulted.
        first_key = parsed.keys()[0]
        block = parsed[first_key]
        cm5_column = block["_atom_site_pbe_cm5_charge"]
        out_path = "".join((save_cm5_dir, entry, ".npy"))
        np.save(out_path, cm5_column)

def get_repeat_data(root_cif_dir,save_repeat_dir):
mofs = glob.glob(os.path.join(root_cif_dir, '*.cif'))
for mof in tqdm(mofs[:]):
Expand All @@ -279,6 +303,13 @@ def get_repeat_data(root_cif_dir,save_repeat_dir):
repeat_data.append(repeat.replace("\n",""))
np.save(save_repeat_dir + mof + '.npy', repeat_data)
f.close()

except:
pass
pass

def get_repeat_data_from_arcmof(root_cif_dir,save_repeat_dir):
    """Save the per-atom charge column of every CIF in a folder as ``.npy``.

    Each ``*.cif`` file found directly inside ``root_cif_dir`` is parsed and
    the charge column of its first data block is written to
    ``save_repeat_dir + <name> + '.npy'``, where ``<name>`` is the CIF file
    name without its extension.

    Args:
        root_cif_dir: Directory that is searched for ``*.cif`` files.
        save_repeat_dir: Prefix prepended to ``<name>.npy``; it is
            concatenated as a plain string, so it must already end with a
            path separator.
    """
    for cif_path in glob.glob(os.path.join(root_cif_dir, '*.cif')):
        # basename/splitext handle both "/" and "\\" separators and strip only
        # the trailing extension, unlike replace(".cif", "").split("/").
        name = os.path.splitext(os.path.basename(cif_path))[0]
        # glob already returned a usable path, so read it directly instead of
        # rebuilding it from root_cif_dir (which required a trailing slash).
        data_cif = ReadCif(cif_path)
        # NOTE(review): this reads the CM5 column although the function is
        # named for REPEAT charges -- looks like a copy-paste from
        # get_cm5_data_from_qmof; confirm the correct CIF key for this dataset.
        charge = data_cif[data_cif.keys()[0]]["_atom_site_pbe_cm5_charge"]
        np.save(save_repeat_dir + name + '.npy', charge)
Binary file modified model4pre/__pycache__/cif2data.cpython-39.pyc
Binary file not shown.
68 changes: 38 additions & 30 deletions model4pre/cif2data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
import numpy as np
import pymatgen.core as mg
from CifFile import ReadCif
from ase.io import read,write
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.io.cif import CifParser
Expand Down Expand Up @@ -146,7 +147,7 @@ def average_and_replace(numbers, di):
numbers[i] = avg
return numbers

def write4cif(mof,chg,digits,atom_type,neutral,charge_type):
def write4cif(mof,chg,digits,atom_type,neutral,charge_type,keep_connect):
name = mof.split('.cif')[0]
chg = chg.numpy()
dia = int(digits)
Expand Down Expand Up @@ -190,36 +191,43 @@ def write4cif(mof,chg,digits,atom_type,neutral,charge_type):

if neutral==False:
print("net charge: "+str(sum(charges)))
if keep_connect:
mof = ReadCif(name + ".cif")
mof.first_block().AddToLoop("_atom_site_type_symbol",{'_atom_site_charge':[str(q) for q in charges]})
with open(name + "_pacman.cif", 'w') as f:
f.write("# " + charge_type + "charges by PACMAN v1.3 (https://github.com/mtap-research/PACMAN-charge/)\n" +
f"data_{name.split('/')[-1]}" + str(mof.first_block()))
print("Compelete and save as "+ name + "_pacman.cif")
else:
with open(name + ".cif", 'r') as file:
lines = file.readlines()
lines[0] = "# "+charge_type+" charges by PACMAN v1.3 (https://github.com/mtap-research/PACMAN-charge/)\n"
lines[1] = "data_" + name.split("/")[-1] + "_pacman\n"
for i, line in enumerate(lines):
if '_atom_site_occupancy' in line:
lines.insert(i + 1, " _atom_site_charge\n")
break
charge_index = 0
for j in range(i + 2, len(lines)):
if charge_index < len(charges):
lines[j] = lines[j].strip() + " " + str(charges[charge_index]) + "\n"
charge_index += 1
else:
break

with open(name + ".cif", 'r') as file:
lines = file.readlines()
lines[0] = "# "+charge_type+" charges by PACMAN v1.1 (https://github.com/mtap-research/PACMAN-charge/)\n"
lines[1] = "data_" + name.split("/")[-1] + "_pacman\n"
for i, line in enumerate(lines):
if '_atom_site_occupancy' in line:
lines.insert(i + 1, " _atom_site_charge\n")
break
charge_index = 0
for j in range(i + 2, len(lines)):
if charge_index < len(charges):
lines[j] = lines[j].strip() + " " + str(charges[charge_index]) + "\n"
charge_index += 1
else:
break

with open(name + "_pacman.cif", 'w') as file:
file.writelines(lines)
file.close()
with open(name + "_pacman.cif", 'w') as file:
file.writelines(lines)
file.close()

with open(name + "_pacman.cif", 'r') as file:
content = file.read()
file.close()
with open(name + "_pacman.cif", 'r') as file:
content = file.read()
file.close()

new_content = content.replace('_space_group_name_H-M_alt', '_symmetry_space_group_name_H-M')
new_content = new_content.replace('_space_group_IT_number', '_symmetry_Int_Tables_number')
new_content = new_content.replace('_space_group_symop_operation_xyz', '_symmetry_equiv_pos_as_xyz')
new_content = content.replace('_space_group_name_H-M_alt', '_symmetry_space_group_name_H-M')
new_content = new_content.replace('_space_group_IT_number', '_symmetry_Int_Tables_number')
new_content = new_content.replace('_space_group_symop_operation_xyz', '_symmetry_equiv_pos_as_xyz')

with open(name + "_pacman.cif", 'wb') as file:
file.write(new_content.encode('utf-8'))
file.close()
print("Compelete and save as "+ name + "_pacman.cif")
with open(name + "_pacman.cif", 'wb') as file:
file.write(new_content.encode('utf-8'))
file.close()
print("Compelete and save as "+ name + "_pacman.cif")
12 changes: 9 additions & 3 deletions pmcharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ def main():
parser.add_argument('--digits', type=int, default=6, help='Number of decimal places to print for partial atomic charges')
parser.add_argument('--atom_type', type=bool, default=True, help='Keep the same partial atomic charge for the same atom types')
parser.add_argument('--neutral', type=bool, default=True, help='Keep the net charge is zero')
parser.add_argument('--keep_connect', type=bool, default=True, help='Keep information from original CIF file')
args = parser.parse_args()

path = args.folder_name
charge_type = args.charge_type
digits = args.digits
atom_type = args.atom_type
neutral = args.neutral
keep_connect = args.keep_connect
if os.path.isfile(path):
print("please input a folder, not a file")
elif os.path.isdir(path):
Expand Down Expand Up @@ -65,6 +67,7 @@ def main():
print("Digits: " + str(digits))
print("Atom Type:" + str(atom_type))
print("Neutral: " + str(neutral))
print("Keep Connect: " + str(keep_connect))

cif_files = glob.glob(os.path.join(path, '*.cif'))
print("writing cif: ***_pacman.cif")
Expand All @@ -74,7 +77,10 @@ def main():
i = 0
for cif in tqdm(cif_files):
try:
ase_format(cif)
if keep_connect:
pass
else:
ase_format(cif)
cif_data = CIF2json(cif)
pos = pre4pre(cif)
# num_atom = n_atom(cif)
Expand Down Expand Up @@ -114,7 +120,7 @@ def main():
# model_bandgap.eval()

gcn = GCN(chg_1-3, chg_2, 128, 7, 256,5)
chkpt = torch.load(model_charge_name, map_location=torch.device(device), weights_only=True)
chkpt = torch.load(model_charge_name, map_location=torch.device(device)) #, weights_only=True
model4chg = SemiFullGN(chg_1,chg_2,128,8,256)
model4chg.to(device)
model4chg.load_state_dict(chkpt['state_dict'])
Expand Down Expand Up @@ -146,7 +152,7 @@ def main():

chg = model4chg(*input_var2)
chg = charge_nor.denorm(chg.data.cpu())
write4cif(cif,chg,digits,atom_type,neutral,charge_type)
write4cif(cif,chg,digits,atom_type,neutral,charge_type,keep_connect)
except:
print("Fail predict: " + cif)
fail[str(i)]=[cif]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ tqdm>=4.15
pandas>=0.20.3
scikit-learn>=0.19.1
joblib>= 0.13.2
torch
torch
PyCifRW
Loading

0 comments on commit f410ce4

Please sign in to comment.