Skip to content

Commit

Permalink
Lammps PBC fix (#670)
Browse files Browse the repository at this point in the history
* lammps pbc fix

* refactor variables in lammps interface

* update example model
  • Loading branch information
ken-sc01 authored Nov 11, 2024
1 parent 662ea1c commit b443376
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 57 deletions.
Binary file modified interfaces/lammps/examples/aspirin/best_model
Binary file not shown.
103 changes: 46 additions & 57 deletions interfaces/lammps/pair_schnetpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,15 @@ void PairSCHNETPACK::compute(int eflag, int vflag){
// Total number of bonds (sum of number of neighbors)
int nedges = std::accumulate(numneigh, numneigh+ntotal, 0);

torch::Tensor pos_tensor = torch::zeros({nlocal, 3});
torch::Tensor tag2type_tensor = torch::zeros({nlocal}, torch::TensorOptions().dtype(torch::kInt64));
torch::Tensor periodic_shift_tensor = torch::zeros({3});
torch::Tensor positions_tensor = torch::zeros({nlocal, 3});
torch::Tensor atomic_numbers_tensor = torch::zeros({nlocal}, torch::TensorOptions().dtype(torch::kInt64));
torch::Tensor cell_tensor = torch::zeros({3,3});

auto pos = pos_tensor.accessor<float, 2>();
long edges[2*nedges];
float edge_cell_shifts[3*nedges];
auto tag2type = tag2type_tensor.accessor<long, 1>();
auto periodic_shift = periodic_shift_tensor.accessor<float, 1>();
auto positions = positions_tensor.accessor<float, 2>();
long idx_i[nedges];
long idx_j[nedges];
float offsets[3*nedges];
auto atomic_numbers = atomic_numbers_tensor.accessor<long, 1>();
auto cell = cell_tensor.accessor<float,2>();

// Inverse mapping from tag to "real" atom index
Expand All @@ -216,10 +215,10 @@ void PairSCHNETPACK::compute(int eflag, int vflag){

// Inverse mapping from tag to x/f atom index
tag2i[itag-1] = i; // tag is probably 1-based
tag2type[itag-1] = type_mapper[itype];
pos[itag-1][0] = x[i][0];
pos[itag-1][1] = x[i][1];
pos[itag-1][2] = x[i][2];
atomic_numbers[itag-1] = type_mapper[itype]; //tag to type
positions[itag-1][0] = x[i][0];
positions[itag-1][1] = x[i][1];
positions[itag-1][2] = x[i][2];
}

// Get cell
Expand All @@ -232,15 +231,12 @@ void PairSCHNETPACK::compute(int eflag, int vflag){
cell[2][1] = domain->yz;
cell[2][2] = domain->boxhi[2] - domain->boxlo[2];


auto cell_inv = cell_tensor.inverse().transpose(0,1);

// Loop over atoms and neighbors,
// store edges and _cell_shifts
// store edges and offsets
// ii follows the order of the neighbor lists,
// i follows the order of x, f, etc.
int edge_counter = 0;
if (debug_mode) printf("SchNetPack edges: i j xi[:] xj[:] cell_shift[:] rij\n");
if (debug_mode) printf("SchNetPack edges: i j xi[:] xj[:] offset[:] rij\n");
for(int ii = 0; ii < nlocal; ii++){
int i = ilist[ii];
int itag = tag[i];
Expand All @@ -254,33 +250,25 @@ void PairSCHNETPACK::compute(int eflag, int vflag){
int jtag = tag[j];
int jtype = type[j];

// TODO: check sign
periodic_shift[0] = x[j][0] - pos[jtag-1][0];
periodic_shift[1] = x[j][1] - pos[jtag-1][1];
periodic_shift[2] = x[j][2] - pos[jtag-1][2];

double dx = x[i][0] - x[j][0];
double dy = x[i][1] - x[j][1];
double dz = x[i][2] - x[j][2];

double rsq = dx*dx + dy*dy + dz*dz;
if (rsq < cutoff*cutoff){
torch::Tensor cell_shift_tensor = cell_inv.matmul(periodic_shift_tensor);
auto cell_shift = cell_shift_tensor.accessor<float, 1>();
float * e_vec = &edge_cell_shifts[edge_counter*3];
e_vec[0] = std::round(cell_shift[0]);
e_vec[1] = std::round(cell_shift[1]);
e_vec[2] = std::round(cell_shift[2]);
//std::cout << "cell shift: " << cell_shift_tensor << "\n";
float * e_vec = &offsets[edge_counter*3];
e_vec[0] = x[j][0] - positions[jtag-1][0];
e_vec[1] = x[j][1] - positions[jtag-1][1];
e_vec[2] = x[j][2] - positions[jtag-1][2];

// TODO: double check order
edges[edge_counter*2] = itag - 1; // tag is probably 1-based
edges[edge_counter*2+1] = jtag - 1; // tag is probably 1-based
idx_i[edge_counter] = itag - 1; // tag is probably 1-based
idx_j[edge_counter] = jtag - 1; // tag is probably 1-based
edge_counter++;

if (debug_mode){
printf("%d %d %.10g %.10g %.10g %.10g %.10g %.10g %.10g %.10g %.10g %.10g\n", itag-1, jtag-1,
pos[itag-1][0],pos[itag-1][1],pos[itag-1][2],pos[jtag-1][0],pos[jtag-1][1],pos[jtag-1][2],
positions[itag-1][0],positions[itag-1][1],positions[itag-1][2],positions[jtag-1][0],positions[jtag-1][1],positions[jtag-1][2],
e_vec[0],e_vec[1],e_vec[2],sqrt(rsq));
}

Expand All @@ -289,51 +277,52 @@ void PairSCHNETPACK::compute(int eflag, int vflag){
}
if (debug_mode) printf("end SchNetPack edges\n");

// shorten the list before sending to nequip
torch::Tensor edges_tensor = torch::zeros({2,edge_counter}, torch::TensorOptions().dtype(torch::kInt64));
torch::Tensor edge_cell_shifts_tensor = torch::zeros({edge_counter,3});
auto new_edges = edges_tensor.accessor<long, 2>();
auto new_edge_cell_shifts = edge_cell_shifts_tensor.accessor<float, 2>();
// transform to torch tensors as expected by SchNetPack
torch::Tensor idx_i_tensor = torch::zeros(edge_counter, torch::TensorOptions().dtype(torch::kInt64));
torch::Tensor idx_j_tensor = torch::zeros(edge_counter, torch::TensorOptions().dtype(torch::kInt64));
torch::Tensor offsets_tensor = torch::zeros({edge_counter,3});
auto new_idx_i = idx_i_tensor.accessor<long, 1>();
auto new_idx_j = idx_j_tensor.accessor<long, 1>();
auto new_offsets = offsets_tensor.accessor<float, 2>();
for (int i=0; i<edge_counter; i++){

long *e=&edges[i*2];
new_edges[0][i] = e[0];
new_edges[1][i] = e[1];
new_idx_i[i] = idx_i[i];
new_idx_j[i] = idx_j[i];

float *ev = &edge_cell_shifts[i*3];
new_edge_cell_shifts[i][0] = ev[0];
new_edge_cell_shifts[i][1] = ev[1];
new_edge_cell_shifts[i][2] = ev[2];
float *ev = &offsets[i*3];
new_offsets[i][0] = ev[0];
new_offsets[i][1] = ev[1];
new_offsets[i][2] = ev[2];
}

// define SchNetPack specific inputs
torch::Tensor idx_m = torch::zeros({nlocal}, torch::TensorOptions().dtype(torch::kInt64));
torch::Tensor idx_m_tensor = torch::zeros({nlocal}, torch::TensorOptions().dtype(torch::kInt64));

// define SchNetPack n_atoms input
torch::Tensor n_atoms_tensor = torch::zeros({1}, torch::TensorOptions().dtype(torch::kInt64));
n_atoms_tensor[0] = nlocal;


c10::Dict<std::string, torch::Tensor> input;
input.insert("_positions", pos_tensor.to(device));
input.insert("_idx_i", edges_tensor[0].to(device));
input.insert("_idx_j", edges_tensor[1].to(device));
input.insert("_idx_m", idx_m.to(device));
input.insert("_offsets", edge_cell_shifts_tensor.to(device));
input.insert("_positions", positions_tensor.to(device));
input.insert("_idx_i", idx_i_tensor.to(device));
input.insert("_idx_j", idx_j_tensor.to(device));
input.insert("_idx_m", idx_m_tensor.to(device));
input.insert("_offsets", offsets_tensor.to(device));
input.insert("_cell", cell_tensor.to(device));
input.insert("_n_atoms", n_atoms_tensor.to(device));
input.insert("_atomic_numbers", tag2type_tensor.to(device));
input.insert("_atomic_numbers", atomic_numbers_tensor.to(device));
std::vector<torch::IValue> input_vector(1, input);

if(debug_mode){
std::cout << "SchNetPack model input:\n";
std::cout << "_positions:\n" << pos_tensor << "\n";
std::cout << "_idx_i:\n" << edges_tensor[0] << "\n";
std::cout << "_idx_j:\n" << edges_tensor[1] << "\n";
std::cout << "_idx_m:\n" << idx_m << "\n";
std::cout << "_offsets:\n" << edge_cell_shifts_tensor << "\n";
std::cout << "_positions:\n" << positions_tensor << "\n";
std::cout << "_idx_i:\n" << idx_i_tensor << "\n";
std::cout << "_idx_j:\n" << idx_j_tensor << "\n";
std::cout << "_idx_m:\n" << idx_m_tensor << "\n";
std::cout << "_offsets:\n" << offsets_tensor << "\n";
std::cout << "_cell:\n" << cell_tensor << "\n";
std::cout << "_atomic_numbers:\n" << tag2type_tensor << "\n";
std::cout << "_atomic_numbers:\n" << atomic_numbers_tensor << "\n";
}

auto output = model.forward(input_vector).toGenericDict();
Expand Down

0 comments on commit b443376

Please sign in to comment.