Basis transfer code (#216)
* temp

* update basis_api document

* update loss and argcheck for overlap specific weight

* Update argcheck.py

* Update loss.py

align unit

* Update loss.py
floatingCatty authored Jan 1, 2025
1 parent 03abc02 commit 73ff3c8
Showing 7 changed files with 201 additions and 5 deletions.
1 change: 1 addition & 0 deletions dptb/data/AtomicData.py
@@ -107,6 +107,7 @@
    AtomicDataDict.HAMILTONIAN_KEY, # new # should be nested
    AtomicDataDict.OVERLAP_KEY, # new # should be nested
    AtomicDataDict.ENERGY_EIGENVALUE_KEY, # new # should be nested
    AtomicDataDict.EIGENVECTOR_KEY, # new # should be nested
    AtomicDataDict.ENERGY_WINDOWS_KEY, # new,
    AtomicDataDict.BAND_WINDOW_KEY, # new,
    AtomicDataDict.NODE_SOC_SWITCH_KEY # new
1 change: 1 addition & 0 deletions dptb/data/_keys.py
@@ -42,6 +42,7 @@
ATOM_TYPE_KEY: Final[str] = "atom_types"
# [n_batch, n_kpoint, n_orb]
ENERGY_EIGENVALUE_KEY: Final[str] = "eigenvalue"
EIGENVECTOR_KEY: Final[str] = "eigenvector"

# [n_batch, 2]
ENERGY_WINDOWS_KEY = "ewindow"
3 changes: 2 additions & 1 deletion dptb/nn/__init__.py
@@ -4,14 +4,15 @@
from .dftbsk import DFTBSK
from .hamiltonian import E3Hamiltonian, SKHamiltonian
from .hr2hk import HR2HK
from .energy import Eigenvalues
from .energy import Eigenvalues, Eigh

__all__ = [
    build_model,
    E3Hamiltonian,
    SKHamiltonian,
    HR2HK,
    Eigenvalues,
    Eigh,
    NNENV,
    NNSK,
    MIX,
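With Eigh exported alongside Eigenvalues, downstream code can build the solver straight from dptb.nn. A minimal usage sketch, not part of this commit: idp and data are assumed to be an OrbitalMapper and an AtomicDataDict prepared by the caller, the import path follows the rest of the package, and the overlap field names are the AtomicDataDict keys used elsewhere in this diff.

from dptb.data import AtomicDataDict
from dptb.nn import Eigh

solver = Eigh(
    idp=idp,  # OrbitalMapper for the model basis (assumed available)
    s_edge_field=AtomicDataDict.EDGE_OVERLAP_KEY,
    s_node_field=AtomicDataDict.NODE_OVERLAP_KEY,
    s_out_field=AtomicDataDict.OVERLAP_KEY,
)
data = solver(data, nk=32)  # diagonalize the k-point grid in chunks of 32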
84 changes: 84 additions & 0 deletions dptb/nn/energy.py
@@ -80,6 +80,90 @@ def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDataDict.Type:

            eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
        data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
        if nested:
            data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([kpoints])
        else:
            data[AtomicDataDict.KPOINT_KEY] = kpoints

        return data

class Eigh(nn.Module):
    def __init__(
            self,
            idp: Union[OrbitalMapper, None]=None,
            h_edge_field: str = AtomicDataDict.EDGE_FEATURES_KEY,
            h_node_field: str = AtomicDataDict.NODE_FEATURES_KEY,
            h_out_field: str = AtomicDataDict.HAMILTONIAN_KEY,
            eigval_field: str = AtomicDataDict.ENERGY_EIGENVALUE_KEY,
            eigvec_field: str = AtomicDataDict.EIGENVECTOR_KEY,
            s_edge_field: str = None,
            s_node_field: str = None,
            s_out_field: str = None,
            dtype: Union[str, torch.dtype] = torch.float32,
            device: Union[str, torch.device] = torch.device("cpu")):
        super(Eigh, self).__init__()

        self.h2k = HR2HK(
            idp=idp,
            edge_field=h_edge_field,
            node_field=h_node_field,
            out_field=h_out_field,
            dtype=dtype,
            device=device,
        )

        if s_edge_field is not None:
            self.s2k = HR2HK(
                idp=idp,
                overlap=True,
                edge_field=s_edge_field,
                node_field=s_node_field,
                out_field=s_out_field,
                dtype=dtype,
                device=device,
            )

            self.overlap = True
        else:
            self.overlap = False

        self.eigval_field = eigval_field
        self.eigvec_field = eigvec_field
        self.h_out_field = h_out_field
        self.s_out_field = s_out_field

    def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None) -> AtomicDataDict.Type:
        kpoints = data[AtomicDataDict.KPOINT_KEY]
        if kpoints.is_nested:
            nested = True
            assert kpoints.size(0) == 1
            kpoints = kpoints[0]
        else:
            nested = False
        num_k = kpoints.shape[0]
        eigvals = []
        eigvecs = []
        if nk is None:
            nk = num_k
        for i in range(int(np.ceil(num_k / nk))):
            data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk]
            data = self.h2k(data)
            if self.overlap:
                data = self.s2k(data)
                # reduce the generalized problem to standard form via the Cholesky factor of S
                chklowt = torch.linalg.cholesky(data[self.s_out_field])
                chklowtinv = torch.linalg.inv(chklowt)
                data[self.h_out_field] = chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv, dim0=1, dim1=2).conj()

            eigval, eigvec = torch.linalg.eigh(data[self.h_out_field])
            if self.overlap:
                # map the eigenvectors back to the original non-orthogonal basis;
                # chklowtinv only exists in the overlap branch
                eigvec = torch.transpose(chklowtinv, dim0=1, dim1=2).conj() @ eigvec
            eigvecs.append(torch.transpose(eigvec, dim0=1, dim1=2))
            eigvals.append(eigval)

        data[self.eigval_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)])
        data[self.eigvec_field] = torch.cat(eigvecs, dim=0)

        if nested:
            data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([kpoints])
        else:
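The forward pass above solves the generalized eigenproblem H C = S C diag(E) by Cholesky reduction: with S = L L^H it diagonalizes the standard-form matrix L^{-1} H L^{-H}, then maps the eigenvectors back through L^{-H} and stores them transposed, so each row holds one state. A self-contained sketch of that identity on random matrices (illustrative shapes only, independent of dptb):

import torch

nk, n = 4, 6  # k-points, orbitals (illustrative)
A = torch.randn(nk, n, n, dtype=torch.complex128)
H = 0.5 * (A + A.mH)                                      # Hermitian Hamiltonian per k-point
B = torch.randn(nk, n, n, dtype=torch.complex128)
S = B @ B.mH + n * torch.eye(n, dtype=torch.complex128)   # Hermitian positive-definite overlap

L = torch.linalg.cholesky(S)   # S = L L^H
Linv = torch.linalg.inv(L)
Ht = Linv @ H @ Linv.mH        # standard-form Hamiltonian, as in Eigh.forward
eigval, Y = torch.linalg.eigh(Ht)
C = Linv.mH @ Y                # eigenvectors of the original generalized problem

# verify H C = S C diag(E) for every k-point
assert torch.allclose(H @ C, S @ C @ torch.diag_embed(eigval).to(C.dtype), atol=1e-10)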
105 changes: 104 additions & 1 deletion dptb/nnops/loss.py
@@ -296,6 +296,109 @@ def forward(
        # hopping_loss += self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt))

        # return hopping_loss + onsite_loss

@Loss.register("eig_ham")
class EigHamLoss(nn.Module):
    def __init__(
            self,
            basis: Dict[str, Union[str, list]]=None,
            idp: Union[OrbitalMapper, None]=None,
            overlap: bool=False,
            onsite_shift: bool=False,
            dtype: Union[str, torch.dtype] = torch.float32,
            device: Union[str, torch.device] = torch.device("cpu"),
            diff_on: bool=False,
            eout_weight: float=0.01,
            diff_weight: float=0.01,
            diff_valence: dict=None,
            spin_deg: int = 2,
            coeff_ham: float=1.,
            coeff_ovp: float=1.,
            **kwargs,
    ):
        super(EigHamLoss, self).__init__()
        self.loss1 = nn.L1Loss()
        self.loss2 = nn.MSELoss()
        self.overlap = overlap
        self.device = device
        self.onsite_shift = onsite_shift
        self.coeff_ham = coeff_ham
        assert self.coeff_ham <= 1.
        self.coeff_ovp = coeff_ovp

        if basis is not None:
            self.idp = OrbitalMapper(basis, method="e3tb", device=self.device)
            if idp is not None:
                assert idp == self.idp, "The provided idp and basis must describe the same basis set."
        else:
            assert idp is not None, "Either basis or idp should be provided."
            self.idp = idp

        self.eigloss = EigLoss(
            idp=self.idp,
            overlap=overlap,
            diff_on=diff_on,
            eout_weight=eout_weight,
            diff_weight=diff_weight,
            diff_valence=diff_valence,
            spin_deg=spin_deg,
            dtype=dtype,
            device=device,
        )

    def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
        # mask the data

        if self.onsite_shift:
            batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
            # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
            if batch.max() == 0:  # batch.max() == 0 means the batch holds a single frame
                mu = mu.mean().detach()
                ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
            elif batch.max() >= 1:
                slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                slices = [0] + slices
                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
                mu = mu.detach()
                ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
                for i in range(1, batch.max().item()+1):
                    edge_mu_index[data["__slices__"]["edge_index"][i]:data["__slices__"]["edge_index"][i+1]] += i
                ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu[edge_mu_index, None] * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]

        pre = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
        tgt = ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
        onsite_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

        pre = data[AtomicDataDict.EDGE_FEATURES_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
        tgt = ref_data[AtomicDataDict.EDGE_FEATURES_KEY][self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
        hopping_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

        if self.overlap:
            pre = data[AtomicDataDict.EDGE_OVERLAP_KEY][self.idp.mask_to_erme[data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
            tgt = ref_data[AtomicDataDict.EDGE_OVERLAP_KEY][self.idp.mask_to_erme[ref_data[AtomicDataDict.EDGE_TYPE_KEY].flatten()]]
            overlap_loss = 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

            pre = data[AtomicDataDict.NODE_OVERLAP_KEY][self.idp.mask_to_nrme[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
            tgt = ref_data[AtomicDataDict.NODE_OVERLAP_KEY][self.idp.mask_to_nrme[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
            overlap_loss += 0.5*(self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt)))

            ham_loss = (1/3) * (hopping_loss + onsite_loss + (self.coeff_ovp / self.coeff_ham) * overlap_loss)
        else:
            ham_loss = 0.5 * (onsite_loss + hopping_loss)

        eigloss = self.eigloss(data, ref_data)

        return self.coeff_ham * ham_loss + (1 - self.coeff_ham) * eigloss


@Loss.register("hamil_abs")
@@ -1004,4 +1107,4 @@ def __cal_norm__(self, irreps: Irreps, x: torch.Tensor):
            tensor = tensor.norm(dim=-1)
            out.append(tensor)

        return torch.cat(out, dim=-1).squeeze(0)
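The scalar returned by EigHamLoss.forward blends the two objectives: the eigenvalue loss enters with weight 1 - coeff_ham, while coeff_ovp only rescales the overlap term inside the averaged Hamiltonian penalty. A sketch with illustrative numbers (not dptb code) of the overlap-fitting branch:

def eig_ham_total(onsite, hopping, overlap, eig, coeff_ham=0.9, coeff_ovp=1.0):
    # mirrors the combination in EigHamLoss.forward when self.overlap is True
    ham = (hopping + onsite + (coeff_ovp / coeff_ham) * overlap) / 3
    return coeff_ham * ham + (1 - coeff_ham) * eig

# coeff_ham = 1 recovers a pure Hamiltonian fit; lowering it mixes in the band-structure error
print(eig_ham_total(onsite=0.02, hopping=0.03, overlap=0.01, eig=0.5))  # ~0.0683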
8 changes: 7 additions & 1 deletion dptb/utils/argcheck.py
@@ -837,6 +837,11 @@ def loss_options():
        Argument("spin_deg", int, optional=True, default=2, doc="The spin degeneracy of band structure. Default: 2"),
    ]

    eig_ham = [
        Argument("coeff_ham", float, optional=True, default=1., doc="The coefficient of the Hamiltonian penalty. Default: 1"),
        Argument("coeff_ovp", float, optional=True, default=1., doc="The coefficient of the overlap penalty. Default: 1"),
    ]

    skints = [
        Argument("skdata", str, optional=False, doc="The path to the skfile or sk database."),
    ]
@@ -848,6 +853,7 @@ def loss_options():
        Argument("hamil_abs", dict, sub_fields=hamil),
        Argument("hamil_blas", dict, sub_fields=hamil),
        Argument("hamil_wt", dict, sub_fields=hamil+wt),
        Argument("eig_ham", dict, sub_fields=hamil+eigvals+eig_ham),
    ], optional=False, doc=doc_method)
@@ -1750,4 +1756,4 @@ def normalize_skf2nnsk(data):
    base.check_value(data, strict=True)

    return data
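These options feed the eig_ham variant registered above, and they pair with the relaxed dataset check in config_check.py below: a training set may now declare both get_Hamiltonian and get_eigenvalues, which this loss consumes together. A hypothetical loss_options fragment (values illustrative):

loss_options = {
    "train": {
        "method": "eig_ham",
        "coeff_ham": 0.9,  # weight of the Hamiltonian/overlap penalty; 1 - coeff_ham goes to the eigenvalue loss
        "coeff_ovp": 1.0,  # relative weight of the overlap penalty inside the Hamiltonian term
    }
}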


4 changes: 2 additions & 2 deletions dptb/utils/config_check.py
@@ -37,8 +37,8 @@ def check_config_train(
    if train_data_config.get("get_Hamiltonian") and not train_data_config.get("get_eigenvalues"):
        assert jdata['train_options']['loss_options']['train'].get("method").startswith("hamil")

    if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
        raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")
    # if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
    #     raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")

    #if jdata["data_options"].get("validation"):
