-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
add mlip diatomic curves
1 parent
f6fdc0c
commit d9ed521
Showing
11 changed files
with
342 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,127 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import numpy.linalg as LA | ||
import plotly.express as px | ||
import pandas as pd | ||
import plotly.colors as pcolors | ||
import plotly.graph_objects as go | ||
import streamlit as st | ||
from ase.data import chemical_symbols | ||
from ase.io import read | ||
from plotly.subplots import make_subplots | ||
from scipy.interpolate import CubicSpline | ||
|
||
st.markdown("# Homonuclear diatomics") | ||
color_sequence = pcolors.qualitative.Plotly | ||
|
||
DATA_DIR = Path("mlip_arena/tasks/diatomics") | ||
|
||
|
||
for i, symbol in enumerate(chemical_symbols[1:10]): | ||
st.markdown("# Homonuclear diatomics") | ||
|
||
if i % 3 == 0: | ||
cols = st.columns(3) | ||
# button to toggle plots | ||
container = st.container(border=True) | ||
energy_plot = container.checkbox("Show energy curves", value=True) | ||
force_plot = container.checkbox("Show force curves", value=False) | ||
|
||
fpath = DATA_DIR / "gpaw" / f"{symbol+symbol}_AFM" / "traj.extxyz" | ||
ncols = 2 | ||
|
||
if not fpath.exists(): | ||
continue | ||
|
||
trj = read(fpath, index=":") | ||
DATA_DIR = Path("mlip_arena/tasks/diatomics") | ||
mlips = ["MACE-MP", "CHGNet"] | ||
|
||
rs, es, s2s = [], [], [] | ||
dfs = [pd.read_json(DATA_DIR / mlip.lower() / "homonuclear-diatomics.json") for mlip in mlips] | ||
df = pd.concat(dfs, ignore_index=True) | ||
|
||
for atoms in trj: | ||
rs.append(LA.norm(atoms.positions[1] - atoms.positions[0])) | ||
es.append(atoms.get_potential_energy()) | ||
s2s.append(np.power(atoms.get_magnetic_moments(), 2).mean()) | ||
|
||
rs = np.array(rs) | ||
ind = np.argsort(rs) | ||
es = np.array(es) | ||
s2s = np.array(s2s) | ||
|
||
rs = rs[ind] | ||
es = es[ind] | ||
s2s = s2s[ind] | ||
df.drop_duplicates(inplace=True, subset=["name", "method"]) | ||
|
||
es = es - es[-1] | ||
for i, symbol in enumerate(chemical_symbols[1:]): | ||
|
||
xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2)) | ||
if i % ncols == 0: | ||
cols = st.columns(ncols) | ||
|
||
cs = CubicSpline(rs, es) | ||
ys = cs(xs) | ||
|
||
cs = CubicSpline(rs, s2s) | ||
s2s = cs(xs) | ||
rows = df[df["name"] == symbol + symbol] | ||
|
||
ylo = min(ys.min()*1.5, -1) | ||
if rows.empty: | ||
continue | ||
|
||
fig = px.scatter( | ||
x=xs, y=ys, | ||
render_mode="webgl", | ||
color=s2s, | ||
range_color=[0, s2s.max()], | ||
width=500, | ||
range_y=[ylo, 1.2*(abs(ylo))], | ||
# title=f"{atoms.get_chemical_formula()}", | ||
labels={"x": "Bond length (Å)", "y": "Energy", "color": "Magnetic moment"}, | ||
# fig = go.Figure() | ||
fig = make_subplots(specs=[[{"secondary_y": True}]]) | ||
|
||
ylo = float("inf") | ||
|
||
for j, method in enumerate(rows["method"].unique()): | ||
row = rows[rows["method"] == method].iloc[0] | ||
|
||
rs = np.array(row["R"]) | ||
es = np.array(row["E"]) | ||
fs = np.array(row["F"]) | ||
|
||
rs = np.array(rs) | ||
ind = np.argsort(rs) | ||
es = np.array(es) | ||
fs = np.array(fs) | ||
|
||
rs = rs[ind] | ||
es = es[ind] | ||
es = es - es[-1] | ||
fs = fs[ind] | ||
|
||
xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2)) | ||
|
||
if energy_plot: | ||
cs = CubicSpline(rs, es) | ||
ys = cs(xs) | ||
|
||
ylo = min(ylo, ys.min()*1.2, -1) | ||
|
||
fig.add_trace( | ||
go.Scatter( | ||
x=xs, y=ys, | ||
mode="lines", | ||
line=dict( | ||
color=color_sequence[j % len(color_sequence)], | ||
width=2, | ||
), | ||
name=method, | ||
), | ||
secondary_y=False, | ||
) | ||
|
||
if force_plot: | ||
cs = CubicSpline(rs, fs) | ||
ys = cs(xs) | ||
|
||
fig.add_trace( | ||
go.Scatter( | ||
x=xs, y=ys, | ||
mode="lines", | ||
line=dict( | ||
color=color_sequence[j % len(color_sequence)], | ||
width=1, | ||
dash="dot", | ||
), | ||
name=method, | ||
showlegend=False if energy_plot else True, | ||
), | ||
secondary_y=True, | ||
) | ||
|
||
|
||
fig.update_layout( | ||
showlegend=True, | ||
title_text=f"{symbol}-{symbol}", | ||
title_x=0.5, | ||
# yaxis_range=[ylo, 2*(abs(ylo))], | ||
) | ||
|
||
cols[i % 3].title(f"{symbol+symbol}") | ||
cols[i % 3].plotly_chart(fig, use_container_width=False) | ||
# Set x-axis title | ||
fig.update_xaxes(title_text="Bond length (Å)") | ||
|
||
# st.latex(r"\frac{d^2E}{dr^2} = \frac{d^2E}{dr^2}") | ||
# Set y-axes titles | ||
if energy_plot: | ||
fig.update_yaxes(title_text="Energy [eV]", secondary_y=False) | ||
|
||
# st.components.v1.html(fig.to_html(include_mathjax='cdn'),height=500) | ||
if force_plot: | ||
fig.update_yaxes(title_text="Force [eV/Å]", secondary_y=True) | ||
|
||
# cols[i % ncols].title(f"{row['name']}") | ||
cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import streamlit as st | ||
from ase.data import chemical_symbols | ||
from pymatgen.core import Element | ||
|
||
elements = [Element.from_Z(z) for z in range(1, 119)] | ||
|
||
# Define the number of rows and columns in the periodic table | ||
rows = 9 # There are 7 rows in the conventional periodic table | ||
columns = 18 | ||
|
||
# Define a function to display the periodic table | ||
def display_periodic_table(): | ||
# elements = [ | ||
# (element, element) for element in chemical_symbols[1:] | ||
# ] | ||
|
||
# cols = st.columns(18, gap='small', vertical_alignment='bottom') # Create 18 columns for the periodic table layout | ||
|
||
row = 0 | ||
for element in elements: | ||
symbol = element.symbol | ||
atomic_number = element.Z | ||
group = element.group | ||
|
||
if element.row > row: | ||
cols = st.columns(columns, gap='small', vertical_alignment='bottom') | ||
row = element.row | ||
|
||
if element.block == 'f': | ||
continue | ||
|
||
with cols[group - 1]: | ||
if st.button(symbol, use_container_width=True): | ||
st.session_state.selected_element = symbol | ||
st.session_state.selected_name = symbol | ||
st.rerun() | ||
# st.experimental_rerun() | ||
|
||
for element in elements: | ||
symbol = element.symbol | ||
atomic_number = element.Z | ||
group = element.group | ||
|
||
if element.row > row: | ||
cols = st.columns(columns, gap='small', vertical_alignment='bottom') | ||
row = element.row | ||
|
||
if element.block == 'f': | ||
noble = Element.from_row_and_group(row-1, 18) | ||
row += 2 | ||
group += atomic_number - noble.Z - 2 | ||
else: | ||
continue | ||
|
||
with cols[group - 1]: | ||
if st.button(symbol, use_container_width=True): | ||
st.session_state.selected_element = symbol | ||
st.session_state.selected_name = symbol | ||
st.rerun() | ||
# st.experimental_rerun() | ||
|
||
|
||
# for idx, (symbol, name) in enumerate(elements): | ||
# with cols[idx % 18]: # Place each element in the correct column | ||
# if st.button(symbol, use_container_width=True): | ||
# st.session_state.selected_element = symbol | ||
# st.session_state.selected_name = name | ||
# st.experimental_rerun() | ||
|
||
# Define a function to display the details of an element | ||
def display_element_details(): | ||
symbol = st.session_state.selected_element | ||
name = st.session_state.selected_name | ||
st.write(f"### {name} ({symbol})") | ||
st.write(f"Details about {name} ({symbol}) will be displayed here.") | ||
if st.button("Back to Periodic Table"): | ||
st.session_state.selected_element = None | ||
st.session_state.selected_name = None | ||
st.experimental_rerun() | ||
|
||
|
||
st.title("Periodic Table") | ||
|
||
# st.balloons() | ||
if 'selected_element' not in st.session_state: | ||
st.session_state.selected_element = None | ||
|
||
if st.session_state.selected_element: | ||
display_element_details() | ||
else: | ||
display_periodic_table() | ||
|