Skip to content

Commit

Permalink
Bandswidget bump to 0.5.0, make the data format compatible (#313)
Browse files Browse the repository at this point in the history
bump to use the new `bandsplot-widget`, the data format is change to compatible with its new API.
  • Loading branch information
unkcpz authored Nov 29, 2022
1 parent c21a7ed commit ec1e4a4
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 50 deletions.
158 changes: 119 additions & 39 deletions aiidalab_qe/node_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import json
import random
import shutil
import typing
from importlib import resources
Expand All @@ -24,7 +25,7 @@
from filelock import FileLock, Timeout
from IPython.display import HTML, display
from jinja2 import Environment
from monty.json import MontyEncoder, jsanitize
from monty.json import jsanitize
from traitlets import Instance, Int, List, Unicode, Union, default, observe, validate
from widget_bandsplot import BandsPlotWidget

Expand Down Expand Up @@ -126,55 +127,134 @@ def export_bands_data(work_chain_node, fermi_energy=None):
return None


def cmap(label: str) -> str:
"""Return RGB string of color for given pseudo info
Hardcoded at the momment.
"""
# if a unknow type generate random color based on ascii sum
ascn = sum([ord(c) for c in label])
random.seed(ascn)

return "#%06x" % random.randint(0, 0xFFFFFF)


def _projections_curated(
projections: ProjectionData,
curated_tag="atom",
spin_type="none",
line_style="solid",
):
"""Collect the data from ProjectionData and parse it as dos list which can be
understand by bandsplot widget. `curated_tag` is for which tag to be grouped, by atom or by orbital name.
The spin_type is used to invert all the y values of pdos to be shown as spin down pdos and to set label."""
_pdos = {}

for orbital, pdos, energy in projections.get_pdos():
orbital_data = orbital.get_orbital_dict()
kind_name = orbital_data["kind_name"]
atom_position = [round(i, 2) for i in orbital_data["position"]]
orbital_name = orbital.get_name_from_quantum_numbers(
orbital_data["angular_momentum"], orbital_data["magnetic_number"]
).lower()

if curated_tag == "atom":
curated_tag_var = atom_position
else:
# by orbital label
curated_tag_var = orbital_name

key = f"{kind_name}-{curated_tag_var}"
if key in _pdos:
_pdos[key][1] += pdos
else:
_pdos[key] = [energy, pdos]

dos = []
for label, (energy, pdos) in _pdos.items():
if spin_type == "down":
# invert y-axis
pdos = -pdos
label = f"{label} (↓)"

if spin_type == "up":
label = f"{label} (↑)"

orbital_pdos = {
"label": label,
"x": energy.tolist(),
"y": pdos.tolist(),
"borderColor": cmap(label),
"lineStyle": line_style,
}
dos.append(orbital_pdos)

return dos


def export_pdos_data(work_chain_node):
if "dos" in work_chain_node.outputs:
_, energy_dos, energy_units = work_chain_node.outputs.dos.get_x()
tdos_values = {
f"{n} | {u}": v for n, v, u in work_chain_node.outputs.dos.get_y()
}
_, energy_dos, _ = work_chain_node.outputs.dos.get_x()
tdos_values = {f"{n}": v for n, v, _ in work_chain_node.outputs.dos.get_y()}

pdos_orbitals = []
dos = []

if "projections" in work_chain_node.outputs:
projection_list = [
(work_chain_node.outputs.projections, None),
]
# The total dos parsed
tdos = {
"label": "Total DOS",
"x": energy_dos.tolist(),
"y": tdos_values.get("dos").tolist(),
"borderColor": "#8A8A8A", # dark gray
"backgroundColor": "#999999", # light gray
"backgroundAlpha": "40%",
"lineStyle": "solid",
}
dos.append(tdos)

dos += _projections_curated(
work_chain_node.outputs.projections, spin_type="none"
)

else:
projection_list = [
(work_chain_node.outputs.projections_up, "up"),
(work_chain_node.outputs.projections_down, "dn"),
]
tdos_values["dos | states/eV"] = tdos_values.pop(
"dos_spin_up | states/eV"
) + tdos_values.pop("dos_spin_down | states/eV")

for projections, suffix in projection_list: # type: ProjectionData, str
for orbital, pdos, energy in projections.get_pdos():
orbital_data = orbital.get_orbital_dict()
kind_name = orbital_data["kind_name"]
orbital_name = orbital.get_name_from_quantum_numbers(
orbital_data["angular_momentum"], orbital_data["magnetic_number"]
)
if suffix is not None:
orbital_name += f"-{suffix}"

pdos_orbitals.append(
{
"kind": kind_name,
"orbital": orbital_name,
"energy | eV": energy,
"pdos | states/eV": pdos,
}
)
# The total dos parsed
tdos_up = {
"label": "Total DOS (↑)",
"x": energy_dos.tolist(),
"y": tdos_values.get("dos_spin_up").tolist(),
"borderColor": "#8A8A8A", # dark gray
"backgroundColor": "#999999", # light gray
"backgroundAlpha": "40%",
"lineStyle": "solid",
}
tdos_down = {
"label": "Total DOS (↓)",
"x": energy_dos.tolist(),
"y": (-tdos_values.get("dos_spin_down")).tolist(), # minus
"borderColor": "#8A8A8A", # dark gray
"backgroundColor": "#999999", # light gray
"backgroundAlpha": "40%",
"lineStyle": "dash",
}
dos += [tdos_up, tdos_down]

# spin-up (↑)
dos += _projections_curated(
work_chain_node.outputs.projections_up, spin_type="up"
)

# spin-dn (↓)
dos += _projections_curated(
work_chain_node.outputs.projections_down,
spin_type="down",
line_style="dash",
)

data_dict = {
"fermi_energy": work_chain_node.outputs.nscf_parameters["fermi_energy"],
"tdos": {f"energy | {energy_units}": energy_dos, "values": tdos_values},
"pdos": pdos_orbitals,
"dos": dos,
}

# And this is why we shouldn't use special encoders...
return json.loads(json.dumps(data_dict, cls=MontyEncoder))
return json.loads(json.dumps(data_dict))

else:
return None
Expand Down
10 changes: 0 additions & 10 deletions qe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,6 @@
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"AIIDA_WARN_v3\"] = \"1\""
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ install_requires =
filelock~=3.8
importlib-resources~=5.2.2
numpy~=1.23
widget-bandsplot~=0.2.8
widget-bandsplot~=0.5.0
python_requires = >=3.8

[options.extras_require]
Expand Down

0 comments on commit ec1e4a4

Please sign in to comment.