Skip to content

Commit

Permalink
fix ff, update and fix rmsd and reordering
Browse files Browse the repository at this point in the history
Signed-off-by: Conrad Hübler <[email protected]>
  • Loading branch information
conradhuebler committed Mar 26, 2024
1 parent f7c84e8 commit 0b37597
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 78 deletions.
177 changes: 109 additions & 68 deletions src/capabilities/rmsd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ void RMSDDriver::LoadControlJson()

std::string method = Json2KeyWord<std::string>(m_defaults, "method");
m_munkress_cycle = 1;
m_limit = 0;
if (method.compare("template") == 0)
m_method = 2;
else if (method.compare("incr") == 0)
Expand All @@ -210,9 +211,10 @@ void RMSDDriver::LoadControlJson()
m_method = 5;
} else if (method.compare("molalign") == 0)
m_method = 6;
else if (method.compare("dtemplate") == 0)
else if (method.compare("dtemplate") == 0) {
m_method = 7;
else
m_limit = Json2KeyWord<int>(m_defaults, "limit");
} else
m_method = 1;

std::string order = Json2KeyWord<std::string>(m_defaults, "order");
Expand All @@ -229,7 +231,7 @@ void RMSDDriver::start()
{
RunTimer timer(false);
clear();

bool rmsd_calculated = false;
if (m_reference.AtomCount() < m_target.AtomCount()) {
m_swap = true;
Molecule tmp = m_reference;
Expand Down Expand Up @@ -266,6 +268,7 @@ void RMSDDriver::start()
m_target_reordered = ApplyOrder(m_reorder_rules, m_target);
m_target = m_target_reordered;
m_rmsd = m_results.begin()->first;
rmsd_calculated = true;
}
}
Molecule temp_ref, temp_tar;
Expand All @@ -281,8 +284,8 @@ void RMSDDriver::start()
if (consent) {
if (m_fragment_reference != -1 && m_fragment_target != -1) {
m_rmsd = CustomRotation();
} /*else
m_rmsd = BestFitRMSD();*/
} else if (!rmsd_calculated)
m_rmsd = BestFitRMSD();
} else {
if (!m_silent)
fmt::print("Partial RMSD is calculated, only from those atoms, that match each other.\n\n\n");
Expand Down Expand Up @@ -461,6 +464,13 @@ void RMSDDriver::ReorderIncremental()
} else {
ref = reference;
storage_shelf = storage_shelf_next;
int index = 0;
for (const auto& i : *storage_shelf.data()) {
m_intermedia_rules.push_back(i.second.first);
if (index > m_limit)
break;
index++;
}
reference_not_reorordered++;
}
wake_up = 2 * pool->WakeUp();
Expand All @@ -475,7 +485,8 @@ void RMSDDriver::ReorderIncremental()
std::vector<int> rule = FillMissing(m_reference, e.second.first);
if (std::find(m_stored_rules.begin(), m_stored_rules.end(), rule) == m_stored_rules.end()) {
m_stored_rules.push_back(rule);
m_results.insert(std::pair<double, std::vector<int>>(e.first, rule));
m_intermedia_rules.push_back(rule);
m_intermediate_cost_matrices.insert(std::pair<double, std::vector<int>>(e.first, rule));
m_stored_rotations.push_back(e.second.second);
count++;
}
Expand Down Expand Up @@ -778,7 +789,7 @@ void RMSDDriver::DistanceTemplate()
if (!m_silent)
std::cout << "Prepare template structure on atom distances:" << std::endl;

auto pairs = PrepareDistanceTemplate(10);
auto pairs = PrepareDistanceTemplate();
FinaliseTemplate(pairs);

m_target_reordered = ApplyOrder(m_reorder_rules, m_target);
Expand Down Expand Up @@ -841,35 +852,27 @@ void RMSDDriver::TemplateFree()

void RMSDDriver::FinaliseTemplate(std::pair<std::vector<int>, std::vector<int>> pairs)
{
std::vector<int> tmp;
for (int j = 0; j < m_reorder_rules.size(); ++j)
tmp.push_back(j);

Molecule target = m_target;
std::map<double, std::vector<int>> local_results = m_results;
m_results.clear();
std::map<double, Matrix> local_results;
for (const auto& indices : m_intermedia_rules) {
pairs.second = indices;
auto result = MakeCostMatrix(pairs);
local_results.insert(result);
}
std::vector<std::vector<int>> rules = m_stored_rules;
double rmsd_prev = 10;
int eq_counter = 0;
int iter = 0;
RunTimer time;
for (auto permutation : local_results) {
iter++;
pairs.second = permutation.second;
if (pairs.first.size() != pairs.second.size())
continue;
auto result = AlignByVectorPair(pairs);
auto result = SolveCostMatrix(permutation.second);
m_target_reordered = ApplyOrder(result, target);
double rmsd = Rules2RMSD(result);
m_results.insert(std::pair<double, std::vector<int>>(rmsd, result));
std::cout << rmsd << " " << permutation.first << " " << time.Elapsed() << " msecs" << std::endl;
time.Reset();
/*
for(auto i : permutation.second)
std::cout << i << " ";
std::cout << std::endl;*/
/*if (rmsd > rmsd_prev || eq_counter > 1)
break;*/
if (rmsd > rmsd_prev || eq_counter > 3)
break;
eq_counter += std::abs(rmsd - rmsd_prev) < 1e-3;
rmsd_prev = rmsd;
}
Expand Down Expand Up @@ -913,6 +916,7 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareHeavyTemplate()
for (auto index : m_stored_rules[i])
tmp.push_back(target_indicies[index]);
transformed_rules.push_back(tmp);
m_intermedia_rules.push_back(tmp);
}
m_stored_rules = transformed_rules;
std::vector<int> target_indices = m_reorder_rules;
Expand Down Expand Up @@ -963,6 +967,7 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareAtomTemplate(in
for (auto index : m_stored_rules[i])
tmp.push_back(target_indicies[index]);
transformed_rules.push_back(tmp);
m_intermedia_rules.push_back(tmp);
}
m_stored_rules = transformed_rules;
std::vector<int> target_indices = m_reorder_rules;
Expand Down Expand Up @@ -1012,6 +1017,7 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareAtomTemplate(co
for (auto index : m_stored_rules[i])
tmp.push_back(target_indicies[index]);
transformed_rules.push_back(tmp);
m_intermedia_rules.push_back(tmp);
}
m_stored_rules = transformed_rules;
std::vector<int> target_indices = m_reorder_rules;
Expand All @@ -1024,7 +1030,7 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareAtomTemplate(co
return std::pair<std::vector<int>, std::vector<int>>(reference_indicies, target_indices);
}

std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareDistanceTemplate(int number) // const std::vector<int>& templateatom)
std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareDistanceTemplate() // const std::vector<int>& templateatom)
{
std::cout << "Start Prepare Template" << std::endl;
RunTimer time;
Expand Down Expand Up @@ -1060,7 +1066,7 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareDistanceTemplat
auto ref_end = m_distance_reference.cend();
auto tar_end = m_distance_target.cend();

while (reference_indicies.size() < number) {
while (reference_indicies.size() < m_limit) {
ref_end--;
tar_end--;
std::pair<int, Position> atom_r1 = m_reference.Atom(ref_end->second.first);
Expand All @@ -1070,25 +1076,19 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareDistanceTemplat
std::pair<int, Position> atom_t2 = m_target.Atom(tar_end->second.second);

if (atom_r1.first == atom_t1.first && std::find(reference_indicies.begin(), reference_indicies.end(), ref_end->second.first) == reference_indicies.end() && std::find(target_indicies.begin(), target_indicies.end(), tar_end->second.first) == target_indicies.end()) {
// if (std::find(reference_indicies.begin(), reference_indicies.end(), ref_end->second.first) != reference_indicies.end())
{
reference.addPair(atom_r1);
reference_indicies.push_back(ref_end->second.first);
}
// if (std::find(target_indicies.begin(), target_indicies.end(), tar_end->second.first) != target_indicies.end())
{

target.addPair(atom_t1);
target_indicies.push_back(tar_end->second.first);
}
}
if (atom_r2.first == atom_t2.first && std::find(reference_indicies.begin(), reference_indicies.end(), ref_end->second.second) == reference_indicies.end() && std::find(target_indicies.begin(), target_indicies.end(), tar_end->second.second) == target_indicies.end()) {
// if (std::find(reference_indicies.begin(), reference_indicies.end(), ref_end->second.second) != reference_indicies.end())
{
reference.addPair(atom_r2);
reference_indicies.push_back(ref_end->second.second);
}
// if (std::find(target_indicies.begin(), target_indicies.end(), tar_end->second.second) != target_indicies.end())
{

target.addPair(atom_t2);
target_indicies.push_back(tar_end->second.second);
}
Expand All @@ -1105,37 +1105,15 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareDistanceTemplat
m_init_count = m_heavy_init;

ReorderIncremental();
m_results.clear();

std::map<double, std::vector<int>> order_rules;
std::vector<std::vector<int>> transformed_rules;
for (int i = 0; i < m_stored_rules.size(); ++i) {
double sum = 0;
Geometry tmp_geom = target_geometry * m_stored_rotations[i];
auto ref_end = distance_reference.cend();
auto tar_end = distance_target.cend();
ref_end--;
tar_end--;
double d_max = ref_end->first;
while (ref_end != distance_reference.cbegin() && tar_end != distance_target.cbegin()) {
if (std::abs(ref_end->first - d_max) > 1e-1)
break;
int ref = (ref_end->second);

int tar = (tar_end->second);
double d1 = (reference_geometry.row(ref) - tmp_geom.row(tar)).norm();

sum += d1;
// std::cout << sum << " ";
ref_end--;
tar_end--;
}

std::vector<int> tmp;
for (auto index : m_stored_rules[i])
tmp.push_back(target_indicies[index]);

m_results.insert(std::pair<double, std::vector<int>>(sum, tmp));
m_intermedia_rules.push_back(tmp);
}
m_stored_rules.clear();

Expand All @@ -1145,7 +1123,6 @@ std::pair<std::vector<int>, std::vector<int>> RMSDDriver::PrepareDistanceTemplat
m_target = cached_target_mol;
m_reference = cached_reference_mol;
m_target = cached_target_mol;
std::cout << time.Elapsed() << " msecs for init" << std::endl;
return std::pair<std::vector<int>, std::vector<int>>(reference_indicies, target_indices);
}

Expand All @@ -1172,6 +1149,65 @@ std::vector<int> RMSDDriver::AlignByVectorPair(std::vector<int> first, std::vect
return DistanceReorder(ref_mol, tar_mol);
}

std::pair<double, Matrix> RMSDDriver::MakeCostMatrix(const std::pair<std::vector<int>, std::vector<int>>& pair)
{
auto operators = GetOperateVectors(pair.first, pair.second);
Eigen::Matrix3d R = operators.first;

Geometry cached_reference = m_reference.getGeometry(pair.first, m_protons);
Geometry cached_target = m_target.getGeometry(pair.second, m_protons);
Geometry ref = GeometryTools::TranslateMolecule(m_reference, m_reference.Centroid(), Position{ 0, 0, 0 });
Geometry tget = GeometryTools::TranslateMolecule(m_target, m_target.Centroid(), Position{ 0, 0, 0 });

Eigen::MatrixXd tar = tget.transpose();

Geometry rotated = tar.transpose() * R;

Molecule ref_mol = m_reference;
ref_mol.setGeometry(ref);

Molecule tar_mol = m_target;
tar_mol.setGeometry(rotated);

double penalty = 100;

std::vector<int> new_order;

Eigen::MatrixXd distance = Eigen::MatrixXd::Zero(ref_mol.AtomCount(), ref_mol.AtomCount());
std::vector<int> element_reference = ref_mol.Atoms();
std::vector<int> element_target = tar_mol.Atoms();
double min = penalty;
double sum = 0;
for (int i = 0; i < ref_mol.AtomCount(); ++i) {
double min = penalty;
for (int j = 0; j < tar_mol.AtomCount(); ++j) {
distance(i, j) = GeometryTools::Distance(tar_mol.Atom(j).second, ref_mol.Atom(i).second) * GeometryTools::Distance(tar_mol.Atom(j).second, ref_mol.Atom(i).second)
+ penalty * (tar_mol.Atom(j).first != tar_mol.Atom(i).first);
min = std::min(min, GeometryTools::Distance(tar_mol.Atom(j).second, ref_mol.Atom(i).second) * GeometryTools::Distance(tar_mol.Atom(j).second, ref_mol.Atom(i).second));
}
sum += min;
}

return std::pair<double, Matrix>(sum, distance);
}

std::vector<int> RMSDDriver::SolveCostMatrix(const Matrix& distance)
{
std::vector<int> new_order;

auto result = MunkressAssign(distance);

for (int i = 0; i < result.cols(); ++i) {
for (int j = 0; j < result.rows(); ++j) {
if (result(i, j) == 1) {
new_order.push_back(j);
break;
}
}
}
return new_order;
}

Molecule RMSDDriver::ApplyOrder(const std::vector<int>& order, const Molecule& mol)
{
Molecule result;
Expand Down Expand Up @@ -1211,9 +1247,16 @@ std::pair<Matrix, Position> RMSDDriver::GetOperateVectors(const std::vector<int>
{
Molecule reference_mol;
Molecule target_mol;
for (int i = 0; i < reference_atoms.size(); ++i) {
reference_mol.addPair(m_reference.Atom(reference_atoms[i]));
target_mol.addPair(m_target.Atom(target_atoms[i]));
if (reference_atoms.size() == target_atoms.size()) {
for (int i = 0; i < reference_atoms.size(); ++i) {
reference_mol.addPair(m_reference.Atom(reference_atoms[i]));
target_mol.addPair(m_target.Atom(target_atoms[i]));
}
} else {
for (int i = 0; i < target_atoms.size(); ++i) {
reference_mol.addPair(m_reference.Atom(i));
target_mol.addPair(m_target.Atom(target_atoms[i]));
}
}
return GetOperateVectors(reference_mol, target_mol);
}
Expand Down Expand Up @@ -1276,7 +1319,6 @@ std::vector<int> RMSDDriver::DistanceReorder(const Molecule& reference, const Mo
for (int i = 0; i < m_munkress_cycle; ++i) {
std::vector<int> munkress = Munkress(reference, target);
double rmsdM = Rules2RMSD(munkress);
// std::cout << rmsdM << ": ";
if (rmsdM > rmsd)
break;
best = munkress;
Expand All @@ -1285,7 +1327,7 @@ std::vector<int> RMSDDriver::DistanceReorder(const Molecule& reference, const Mo
}
return best;
}

/*
std::vector<int> RMSDDriver::FillOrder(const Molecule& reference, const Molecule& target, const std::vector<int>& order)
{
Molecule ref = reference, tar = target;
Expand Down Expand Up @@ -1350,7 +1392,7 @@ std::vector<int> RMSDDriver::FillOrder(const Molecule& reference, const Molecule
}
return new_order;
}

*/
std::vector<int> RMSDDriver::Munkress(const Molecule& reference, const Molecule& target)
{
double penalty = 100;
Expand Down Expand Up @@ -1379,18 +1421,17 @@ std::vector<int> RMSDDriver::Munkress(const Molecule& reference, const Molecule&
}
sum += min;
}
/*
if (m_dmix <= 1 && 0 < m_dmix) {
Matrix d = target.DistanceMatrix().first;
distance = (1 - m_dmix) * distance + m_dmix * d;
}
std::cout << sum / double(reference.AtomCount() * reference.AtomCount()) << " :: ";
}*/
// std::cout << sum / double(reference.AtomCount() * reference.AtomCount()) << " :: ";
auto result = MunkressAssign(distance);

for (int i = 0; i < result.cols(); ++i) {
for (int j = 0; j < result.rows(); ++j) {
if (result(i, j) == 1) {
if (target.Atom(j).first != reference.Atom(i).first)
std::cout << "hilfe " << distance(i, j) << distance(j, i) << " ";
new_order.push_back(j);
break;
}
Expand Down
Loading

0 comments on commit 0b37597

Please sign in to comment.