Skip to content

Commit

Permalink
add equiformer, escn; add leaderboard
Browse files Browse the repository at this point in the history
  • Loading branch information
chiang-yuan committed Jul 7, 2024
1 parent 89bc52a commit 221dfe3
Show file tree
Hide file tree
Showing 12 changed files with 131 additions and 87 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ __pycache__/
*$py.class
tests/
*.out
mlip_arena/tasks/*/*/

# C extensions
*.so
Expand Down
16 changes: 8 additions & 8 deletions mlip_arena/models/registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ MACE_MP_Medium:
gpu-tasks:
- diatomics

CHGNet:
module: chgnet
username: cyrusyc
datetime: 2024-03-25T14:30:00
datasets:
- atomind/mptrj
cpu-tasks:
- diatomics
# CHGNet:
# module: chgnet
# username: cyrusyc
# datetime: 2024-03-25T14:30:00
# datasets:
# - atomind/mptrj
# cpu-tasks:
# - diatomics

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

60 changes: 56 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,20 @@ classifiers=[
"Programming Language :: Python :: 3 :: Only",
]
dependencies=[
"torch>=2.0.0",
"torch",
"ase",
"torch_dftd>=0.4.0",
"huggingface_hub",
"torch-geometric>=2.5.2",
"torch-geometric",
"safetensors"
]

[project.optional-dependencies]
m3gnet = ["matgl", "dgl", "torch<=2.2.1"]
mace = ["mace-torch"]
chgnet = ["chgnet"]
fairchem = ["fairchem"]

[project.urls]
Homepage = "https://github.com/atomind-ai/mlip-arena"
Issues = "https://github.com/atomind-ai/mlip-arena/issues"
Expand Down Expand Up @@ -74,7 +80,53 @@ line-length = 88
indent-width = 4

[tool.ruff.lint]
select = ["ALL"]
ignore = []
select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"E", # pycodestyle error
"EXE", # flake8-executable
"F", # pyflakes
"FA", # flake8-future-annotations
"FBT003", # boolean-positional-value-in-call
"FLY", # flynt
"I", # isort
"ICN", # flake8-import-conventions
"PD", # pandas-vet
"PERF", # perflint
"PIE", # flake8-pie
"PL", # pylint
"PT", # flake8-pytest-style
"PYI", # flakes8-pyi
"Q", # flake8-quotes
"RET", # flake8-return
"RSE", # flake8-raise
"RUF", # Ruff-specific rules
"SIM", # flake8-simplify
"SLOT", # flake8-slots
"TCH", # flake8-type-checking
"TID", # tidy imports
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
"C408", # Unnecessary dict call
"PLR", # Design related pylint codes
"E501", # Line too long
"B028", # No explicit stacklevel
"EM101", # Exception must not use a string literal
"EM102", # Exception must not use an f-string literal
"G004", # f-string in Logging statement
"RUF015", # Prefer next(iter())
"RET505", # Unnecessary `elif` after `return`
"PT004", # Fixture does not return anthing
"B017", # pytest.raises
"PT011", # pytest.raises
"PT012", # pytest.raises"
"E741", # ambigous variable naming, i.e. one letter
"FBT003", # boolean positional variable in function call
"PERF203", # `try`-`except` within a loop incurs performance overhead (no overhead in Py 3.11+)
]
fixable = ["ALL"]
pydocstyle.convention = "google"
18 changes: 11 additions & 7 deletions serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
layout="wide",
page_title="MLIP Arena",
page_icon=":shark:",
# initial_sidebar_state="expanded",
menu_items=None
initial_sidebar_state="expanded",
menu_items={
"About": 'https://github.com/atomind-ai/mlip-arena',
"Report a bug": "https://github.com/atomind-ai/mlip-arena/issues/new",
}
)

# if "logged_in" not in st.session_state:
Expand All @@ -24,26 +27,27 @@
# login_page = st.Page(login, title="Log in", icon=":material/login:")
# logout_page = st.Page(logout, title="Log out", icon=":material/logout:")

dashboard = st.Page(
"reports/dashboard.py", title="Dashboard", icon=":material/dashboard:"
leaderboard = st.Page(
"models/leaderboard.py", title="Leaderboard", icon=":material/trophy:"
)
bugs = st.Page("reports/bugs.py", title="Bug reports", icon=":material/bug_report:")
bugs = st.Page("models/bugs.py", title="Bug reports", icon=":material/bug_report:")
alerts = st.Page(
"reports/alerts.py", title="System alerts", icon=":material/notification_important:"
"models/alerts.py", title="System alerts", icon=":material/notification_important:"
)

search = st.Page("tools/search.py", title="Search", icon=":material/search:")
history = st.Page("tools/history.py", title="History", icon=":material/history:")
ptable = st.Page("tools/ptable.py", title="Periodic table", icon=":material/gradient:")

diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon="", default=True)
diatomics = st.Page("tasks/homonuclear-diatomics.py", title="Homonuclear diatomics", icon=":material/target:", default=True)

# if st.session_state.logged_in:
pg = st.navigation(
{
# "Account": [logout_page],
# "Reports": [dashboard, bugs, alerts],
# "Tools": [search, history, ptable],
"Models": [leaderboard],
"Tasks": [diatomics],
"Tools": [ptable],
}
Expand Down
4 changes: 0 additions & 4 deletions serve/reports/alerts.py

This file was deleted.

4 changes: 0 additions & 4 deletions serve/reports/bugs.py

This file was deleted.

23 changes: 0 additions & 23 deletions serve/reports/dashboard.py

This file was deleted.

84 changes: 50 additions & 34 deletions serve/tasks/homonuclear-diatomics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,46 @@
from plotly.subplots import make_subplots
from scipy.interpolate import CubicSpline

color_sequence = pcolors.qualitative.Plotly



st.markdown("# Homonuclear diatomics")

# button to toggle plots
st.markdown("### Methods")
container = st.container(border=True)
energy_plot = container.checkbox("Show energy curves", value=True)
force_plot = container.checkbox("Show force curves", value=True)
methods = container.multiselect("MLIPs", ["MACE-MP", "Equiformer", "CHGNet", "MACE-OFF", "eSCN"], ["MACE-MP", "Equiformer", "CHGNet", "eSCN"])
methods += container.multiselect("DFT Methods", ["GPAW"], [])

ncols = 2
st.markdown("### Settings")
vis = st.container(border=True)
energy_plot = vis.checkbox("Show energy curves", value=True)
force_plot = vis.checkbox("Show force curves", value=True)
ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=3)

DATA_DIR = Path("mlip_arena/tasks/diatomics")
mlips = ["MACE-MP", "CHGNet"]
# Get all attributes from pcolors.qualitative
all_attributes = dir(pcolors.qualitative)
color_palettes = {attr: getattr(pcolors.qualitative, attr) for attr in all_attributes if isinstance(getattr(pcolors.qualitative, attr), list)}
color_palettes.pop("__all__", None)

dfs = [pd.read_json(DATA_DIR / mlip.lower() / "homonuclear-diatomics.json") for mlip in mlips]
df = pd.concat(dfs, ignore_index=True)
palette_names = list(color_palettes.keys())
palette_colors = list(color_palettes.values())

palette_name = vis.selectbox(
"Color sequence",
options=palette_names, index=22
)

color_sequence = color_palettes[palette_name] # type: ignore

DATA_DIR = Path("mlip_arena/tasks/diatomics")
dfs = [pd.read_json(DATA_DIR / method.lower() / "homonuclear-diatomics.json") for method in methods]
df = pd.concat(dfs, ignore_index=True)
df.drop_duplicates(inplace=True, subset=["name", "method"])

method_color_mapping = {method: color_sequence[i % len(color_sequence)] for i, method in enumerate(df["method"].unique())}

for i, symbol in enumerate(chemical_symbols[1:]):

if i % ncols == 0:
cols = st.columns(ncols)


rows = df[df["name"] == symbol + symbol]

if rows.empty:
Expand All @@ -61,57 +72,67 @@

rs = rs[ind]
es = es[ind]
es = es - es[-1]
fs = fs[ind]
if "GPAW" not in method:
es = es - es[-1]
else:
pass

xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2))
if "GPAW" not in method:
fs = fs[ind]

if "GPAW" in method:
xs = np.linspace(rs.min()*0.99, rs.max()*1.01, int(5e2))
else:
xs = rs

if energy_plot:
cs = CubicSpline(rs, es)
ys = cs(xs)
if "GPAW" in method:
cs = CubicSpline(rs, es)
ys = cs(xs)
else:
ys = es

elo = min(elo, ys.min()*1.2, -1)
elo = min(elo, max(ys.min()*1.2, -15), -1)

fig.add_trace(
go.Scatter(
x=xs, y=ys,
mode="lines",
line=dict(
color=color_sequence[j % len(color_sequence)],
color=method_color_mapping[method],
width=2,
),
name=method,
),
secondary_y=False,
)

if force_plot:
cs = CubicSpline(rs, fs)
ys = cs(xs)
if force_plot and "GPAW" not in method:
ys = fs

flo = min(flo, ys.min()*1.2)
flo = min(flo, max(ys.min()*1.2, -50))

fig.add_trace(
go.Scatter(
x=xs, y=ys,
mode="lines",
line=dict(
color=color_sequence[j % len(color_sequence)],
color=method_color_mapping[method],
width=1,
dash="dot",
),
name=method,
showlegend=False if energy_plot else True,
showlegend=not energy_plot,
),
secondary_y=True,
)

name = f"{symbol}-{symbol}"

fig.update_layout(
showlegend=True,
title_text=f"{symbol}-{symbol}",
title_text=f"{name}",
title_x=0.5,
# yaxis_range=[ylo, 2*(abs(ylo))],
)

# Set x-axis title
Expand All @@ -128,21 +149,16 @@
)
)

# fig.update_yaxes(title_text="Energy [eV]", secondary_y=False)

if force_plot:

fig.update_layout(
yaxis2=dict(
title=dict(text="Force [eV/Å]"),
side="right",
range=[flo, 2*(abs(flo))],
range=[flo, 1.5*abs(flo)],
overlaying="y",
tickmode="sync",
),
)

# fig.update_yaxes(title_text="Force [eV/Å]", secondary_y=True)

# cols[i % ncols].title(f"{row['name']}")
cols[i % ncols].plotly_chart(fig, use_container_width=True, height=250)
3 changes: 2 additions & 1 deletion serve/tools/ptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def display_element_details():
if st.button("Back to Periodic Table"):
st.session_state.selected_element = None
st.session_state.selected_name = None
st.experimental_rerun()
st.rerun()
# st.experimental_rerun()


st.title("Periodic Table")
Expand Down

0 comments on commit 221dfe3

Please sign in to comment.