|
| 1 | +"""Pydantic model for default configuration and validation.""" |
| 2 | +"""Implementation based on the template of ALIGNN.""" |
| 3 | + |
| 4 | +import subprocess |
| 5 | +from typing import Optional, Union |
| 6 | +import os |
| 7 | +from pydantic import root_validator |
| 8 | + |
| 9 | +# vfrom pydantic import Field, root_validator, validator |
| 10 | +from pydantic.typing import Literal |
| 11 | +from matformer.utils import BaseSettings |
| 12 | +from matformer.models.pyg_att import MatformerConfig |
| 13 | + |
| 14 | +# from typing import List |
| 15 | + |
| 16 | +try: |
| 17 | + VERSION = ( |
| 18 | + subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip() |
| 19 | + ) |
| 20 | +except Exception as exp: |
| 21 | + VERSION = "NA" |
| 22 | + pass |
| 23 | + |
| 24 | + |
| 25 | +FEATURESET_SIZE = {"basic": 11, "atomic_number": 1, "cfid": 438, "cgcnn": 92} |
| 26 | + |
| 27 | + |
| 28 | +TARGET_ENUM = Literal[ |
| 29 | + "formation_energy_peratom", |
| 30 | + "optb88vdw_bandgap", |
| 31 | + "bulk_modulus_kv", |
| 32 | + "shear_modulus_gv", |
| 33 | + "mbj_bandgap", |
| 34 | + "slme", |
| 35 | + "magmom_oszicar", |
| 36 | + "spillage", |
| 37 | + "kpoint_length_unit", |
| 38 | + "encut", |
| 39 | + "optb88vdw_total_energy", |
| 40 | + "epsx", |
| 41 | + "epsy", |
| 42 | + "epsz", |
| 43 | + "mepsx", |
| 44 | + "mepsy", |
| 45 | + "mepsz", |
| 46 | + "max_ir_mode", |
| 47 | + "min_ir_mode", |
| 48 | + "n-Seebeck", |
| 49 | + "p-Seebeck", |
| 50 | + "n-powerfact", |
| 51 | + "p-powerfact", |
| 52 | + "ncond", |
| 53 | + "pcond", |
| 54 | + "nkappa", |
| 55 | + "pkappa", |
| 56 | + "ehull", |
| 57 | + "exfoliation_energy", |
| 58 | + "dfpt_piezo_max_dielectric", |
| 59 | + "dfpt_piezo_max_eij", |
| 60 | + "dfpt_piezo_max_dij", |
| 61 | + "gap pbe", |
| 62 | + "e_form", |
| 63 | + "e_hull", |
| 64 | + "energy_per_atom", |
| 65 | + "formation_energy_per_atom", |
| 66 | + "band_gap", |
| 67 | + "e_above_hull", |
| 68 | + "mu_b", |
| 69 | + "bulk modulus", |
| 70 | + "shear modulus", |
| 71 | + "elastic anisotropy", |
| 72 | + "U0", |
| 73 | + "HOMO", |
| 74 | + "LUMO", |
| 75 | + "R2", |
| 76 | + "ZPVE", |
| 77 | + "omega1", |
| 78 | + "mu", |
| 79 | + "alpha", |
| 80 | + "homo", |
| 81 | + "lumo", |
| 82 | + "gap", |
| 83 | + "r2", |
| 84 | + "zpve", |
| 85 | + "U", |
| 86 | + "H", |
| 87 | + "G", |
| 88 | + "Cv", |
| 89 | + "A", |
| 90 | + "B", |
| 91 | + "C", |
| 92 | + "all", |
| 93 | + "target", |
| 94 | + "max_efg", |
| 95 | + "avg_elec_mass", |
| 96 | + "avg_hole_mass", |
| 97 | + "_oqmd_band_gap", |
| 98 | + "_oqmd_delta_e", |
| 99 | + "_oqmd_stability", |
| 100 | + "edos_up", |
| 101 | + "pdos_elast", |
| 102 | + "bandgap", |
| 103 | + "energy_total", |
| 104 | + "net_magmom", |
| 105 | + "b3lyp_homo", |
| 106 | + "b3lyp_lumo", |
| 107 | + "b3lyp_gap", |
| 108 | + "b3lyp_scharber_pce", |
| 109 | + "b3lyp_scharber_voc", |
| 110 | + "b3lyp_scharber_jsc", |
| 111 | + "log_kd_ki", |
| 112 | + "max_co2_adsp", |
| 113 | + "min_co2_adsp", |
| 114 | + "lcd", |
| 115 | + "pld", |
| 116 | + "void_fraction", |
| 117 | + "surface_area_m2g", |
| 118 | + "surface_area_m2cm3", |
| 119 | + "indir_gap", |
| 120 | + "f_enp", |
| 121 | + "final_energy", |
| 122 | + "energy_per_atom", |
| 123 | +] |
| 124 | + |
| 125 | + |
| 126 | +class TrainingConfig(BaseSettings): |
| 127 | + """Training config defaults and validation.""" |
| 128 | + |
| 129 | + version: str = VERSION |
| 130 | + |
| 131 | + # dataset configuration |
| 132 | + dataset: Literal[ |
| 133 | + "dft_3d", |
| 134 | + "megnet", |
| 135 | + ] = "dft_3d" |
| 136 | + target: TARGET_ENUM = "formation_energy_peratom" |
| 137 | + atom_features: Literal["basic", "atomic_number", "cfid", "cgcnn"] = "cgcnn" |
| 138 | + neighbor_strategy: Literal["k-nearest", "voronoi", "pairwise-k-nearest"] = "k-nearest" |
| 139 | + id_tag: Literal["jid", "id", "_oqmd_entry_id"] = "jid" |
| 140 | + |
| 141 | + # logging configuration |
| 142 | + |
| 143 | + # training configuration |
| 144 | + random_seed: Optional[int] = 123 |
| 145 | + classification_threshold: Optional[float] = None |
| 146 | + n_val: Optional[int] = None |
| 147 | + n_test: Optional[int] = None |
| 148 | + n_train: Optional[int] = None |
| 149 | + train_ratio: Optional[float] = 0.8 |
| 150 | + val_ratio: Optional[float] = 0.1 |
| 151 | + test_ratio: Optional[float] = 0.1 |
| 152 | + target_multiplication_factor: Optional[float] = None |
| 153 | + epochs: int = 300 |
| 154 | + batch_size: int = 64 |
| 155 | + weight_decay: float = 0 |
| 156 | + learning_rate: float = 1e-2 |
| 157 | + filename: str = "sample" |
| 158 | + warmup_steps: int = 2000 |
| 159 | + criterion: Literal["mse", "l1", "poisson", "zig"] = "mse" |
| 160 | + optimizer: Literal["adamw", "sgd"] = "adamw" |
| 161 | + scheduler: Literal["onecycle", "none", "step"] = "onecycle" |
| 162 | + pin_memory: bool = False |
| 163 | + save_dataloader: bool = False |
| 164 | + write_checkpoint: bool = True |
| 165 | + write_predictions: bool = True |
| 166 | + store_outputs: bool = True |
| 167 | + progress: bool = True |
| 168 | + log_tensorboard: bool = False |
| 169 | + standard_scalar_and_pca: bool = False |
| 170 | + use_canonize: bool = True |
| 171 | + num_workers: int = 2 |
| 172 | + cutoff: float = 8.0 |
| 173 | + max_neighbors: int = 12 |
| 174 | + keep_data_order: bool = False |
| 175 | + distributed: bool = False |
| 176 | + n_early_stopping: Optional[int] = None # typically 50 |
| 177 | + output_dir: str = os.path.abspath(".") # typically 50 |
| 178 | + matrix_input: bool = False |
| 179 | + pyg_input: bool = False |
| 180 | + use_lattice: bool = False |
| 181 | + use_angle: bool = False |
| 182 | + |
| 183 | + # model configuration |
| 184 | + model = MatformerConfig(name="matformer") |
| 185 | + print(model) |
| 186 | + @root_validator() |
| 187 | + def set_input_size(cls, values): |
| 188 | + """Automatically configure node feature dimensionality.""" |
| 189 | + values["model"].atom_input_features = FEATURESET_SIZE[ |
| 190 | + values["atom_features"] |
| 191 | + ] |
| 192 | + |
| 193 | + return values |
0 commit comments