-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add homonuclear diatomic rank * formatting
- Loading branch information
1 parent
92e3c26
commit f9d3b3b
Showing
2 changed files
with
172 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,134 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import plotly.colors as pcolors | ||
import plotly.graph_objects as go | ||
import streamlit as st | ||
from ase.data import chemical_symbols | ||
from plotly.subplots import make_subplots | ||
from scipy.interpolate import CubicSpline | ||
|
||
from mlip_arena.models import REGISTRY as MODELS | ||
|
||
valid_models = [ | ||
model | ||
for model, metadata in MODELS.items() | ||
if Path(__file__).stem in metadata.get("gpu-tasks", []) | ||
] | ||
|
||
DATA_DIR = Path("mlip_arena/tasks/diatomics") | ||
|
||
dfs = [ | ||
pd.read_json(DATA_DIR / MODELS[model].get("family") / "homonuclear-diatomics.json") | ||
for model in valid_models | ||
] | ||
df = pd.concat(dfs, ignore_index=True) | ||
|
||
table = pd.DataFrame() | ||
|
||
for model in valid_models: | ||
rows = df[df["method"] == model] | ||
metadata = MODELS.get(model, {}) | ||
|
||
new_row = { | ||
"Model": model, | ||
"Conservation deviation [eV/Å]": rows["conservation-deviation"].mean(), | ||
"Spearman's coeff. (Energy - repulsion)": rows[ | ||
"spearman-repulsion-energy" | ||
].mean(), | ||
"Spearman's coeff. (Force - descending)": rows[ | ||
"spearman-descending-force" | ||
].mean(), | ||
"Tortuosity": rows["tortuosity"].mean(), | ||
"Energy jump [eV]": rows["energy-jump"].mean(), | ||
"Force flips": rows["force-flip-times"].mean(), | ||
"Spearman's coeff. (Energy - attraction)": rows[ | ||
"spearman-attraction-energy" | ||
].mean(), | ||
"Spearman's coeff. (Force - ascending)": rows[ | ||
"spearman-ascending-force" | ||
].mean(), | ||
} | ||
|
||
table = pd.concat([table, pd.DataFrame([new_row])], ignore_index=True) | ||
|
||
table.set_index("Model", inplace=True) | ||
|
||
table.sort_values("Conservation deviation [eV/Å]", ascending=True, inplace=True) | ||
table["Rank"] = np.argsort(table["Conservation deviation [eV/Å]"].to_numpy()) | ||
|
||
table.sort_values( | ||
"Spearman's coeff. (Energy - repulsion)", ascending=True, inplace=True | ||
) | ||
table["Rank"] += np.argsort(table["Spearman's coeff. (Energy - repulsion)"].to_numpy()) | ||
|
||
table.sort_values( | ||
"Spearman's coeff. (Force - descending)", ascending=True, inplace=True | ||
) | ||
table["Rank"] += np.argsort(table["Spearman's coeff. (Force - descending)"].to_numpy()) | ||
|
||
table.sort_values("Tortuosity", ascending=True, inplace=True) | ||
table["Rank"] += np.argsort(table["Tortuosity"].to_numpy()) | ||
|
||
table.sort_values("Energy jump [eV]", ascending=True, inplace=True) | ||
table["Rank"] += np.argsort(table["Energy jump [eV]"].to_numpy()) | ||
|
||
table.sort_values("Force flips", ascending=True, inplace=True) | ||
table["Rank"] += np.argsort(table["Force flips"].to_numpy()) | ||
|
||
table.sort_values("Rank", ascending=True, inplace=True) | ||
|
||
table["Rank aggr."] = table["Rank"] | ||
|
||
table["Rank"] = np.argsort(table["Rank"].to_numpy()) + 1 | ||
|
||
# table.drop(columns=["rank"], inplace=True) | ||
# table = table.rename(columns={"Rank": "Rank Aggr."}) | ||
|
||
table = table.reindex( | ||
columns=[ | ||
"Rank", | ||
"Rank aggr.", | ||
"Conservation deviation [eV/Å]", | ||
"Spearman's coeff. (Energy - repulsion)", | ||
"Spearman's coeff. (Force - descending)", | ||
"Tortuosity", | ||
"Energy jump [eV]", | ||
"Force flips", | ||
"Spearman's coeff. (Energy - attraction)", | ||
"Spearman's coeff. (Force - ascending)", | ||
] | ||
) | ||
|
||
def get_rank_page(): | ||
# st.markdown("""HIHI""") | ||
pass | ||
s = ( | ||
table.style.background_gradient( | ||
cmap="viridis_r", | ||
subset=["Conservation deviation [eV/Å]"], | ||
gmap=np.log(table["Conservation deviation [eV/Å]"].to_numpy()), | ||
) | ||
.background_gradient( | ||
cmap="Reds", | ||
subset=[ | ||
"Spearman's coeff. (Energy - repulsion)", | ||
"Spearman's coeff. (Force - descending)", | ||
], | ||
# vmin=-1, vmax=-0.5 | ||
) | ||
.background_gradient( | ||
cmap="RdPu", | ||
subset=["Tortuosity", "Energy jump [eV]", "Force flips"], | ||
) | ||
.background_gradient( | ||
cmap="Blues", | ||
subset=["Rank", "Rank aggr."], | ||
) | ||
) | ||
|
||
# def plot(): | ||
# pass | ||
|
||
# if __name__ == '__main__': | ||
# pass | ||
def render(): | ||
st.dataframe( | ||
s, | ||
use_container_width=True, | ||
) | ||
# return table |