From 87ba4dc61080be9d88fecbaba9085096544e3753 Mon Sep 17 00:00:00 2001 From: AndresOrtegaGuerrero <34098967+AndresOrtegaGuerrero@users.noreply.github.com> Date: Mon, 5 Feb 2024 14:53:56 +0100 Subject: [PATCH] New BandsPdosWidget (#581) The old bandsplot widget https://github.com/aiidalab/widget-bandsplot was separately maintained in the repo with complex and not well modulized JS code. In this PR, the new bandspdoswidget is introduced to have it all implemented in python and in ipywidget framework with using plotly as plot engine. The output band structure is in a publish ready level with a flexible control on the groups of orbitals/atoms user want to projected. Co-authored-by: Jusong Yu --- setup.cfg | 1 - src/aiidalab_qe/common/bandpdoswidget.py | 862 ++++++++++++++++++ src/aiidalab_qe/plugins/bands/result.py | 33 +- .../plugins/electronic_structure/result.py | 230 +---- src/aiidalab_qe/plugins/pdos/result.py | 203 +---- tests/test_plugins_bands.py | 30 +- tests/test_plugins_electronic_structure.py | 36 +- tests/test_plugins_pdos.py | 42 +- 8 files changed, 961 insertions(+), 476 deletions(-) create mode 100644 src/aiidalab_qe/common/bandpdoswidget.py diff --git a/setup.cfg b/setup.cfg index 1b2c18642..5f4aea118 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,6 @@ install_requires = aiida-pseudo~=1.4 filelock~=3.8 importlib-resources~=5.2 - widget-bandsplot~=0.5.1 python_requires = >=3.8 [options.packages.find] diff --git a/src/aiidalab_qe/common/bandpdoswidget.py b/src/aiidalab_qe/common/bandpdoswidget.py new file mode 100644 index 000000000..16909e42b --- /dev/null +++ b/src/aiidalab_qe/common/bandpdoswidget.py @@ -0,0 +1,862 @@ +import base64 +import json + +import ipywidgets as ipw +import numpy as np +import plotly.graph_objects as go +from aiida.orm import ProjectionData +from aiidalab_widgets_base.utils import string_range_to_list, StatusHTML +from IPython.display import clear_output, display +from plotly.subplots import make_subplots +import re + + +class BandPdosPlotly: + SETTINGS = { + "axis_linecolor": "#111111", + "bands_linecolor": "#111111", + "bands_up_linecolor": "rgba(205, 0, 0, 0.4)", # Red Opacitiy 40% + "bands_down_linecolor": "rgba(72,118,255, 0.4)", # Blue Opacitiy 40% + "combined_plot_height": 600, + "combined_plot_width": 900, + "combined_column_widths": [0.7, 0.3], + "bands_plot_height": 600, + "bands_plot_width": 850, + "pdos_plot_height": 600, + "pdos_plot_width": 850, + "vertical_linecolor": "#111111", + "horizontal_linecolor": "#111111", + "vertical_range_bands": [-10, 10], + "horizontal_range_pdos": [-10, 10], + } + + def __init__(self, bands_data=None, pdos_data=None): + self.bands_data = bands_data + self.pdos_data = pdos_data + self.fermi_energy = self._get_fermi_energy() + + # Plotly Axis + # Plotly settings + self._bands_xaxis = self._band_xaxis() + self._bands_yaxis = self._band_yaxis() + self._dos_xaxis = self._dos_xaxis() + self._dos_yaxis = self._dos_yaxis() + + def _get_fermi_energy(self): + fermi_energy = ( + self.pdos_data["fermi_energy"] + if self.pdos_data + else self.bands_data["fermi_energy"] + ) + return fermi_energy + + def _band_xaxis(self): + """Function to return the xaxis for the bands plot.""" + + if not self.bands_data: + return None + paths = self.bands_data.get("paths") + slider_bands = go.layout.xaxis.Rangeslider( + thickness=0.08, + range=[0, paths[-1]["x"][-1]], + ) + bandxaxis = go.layout.XAxis( + title="k-points", + range=[0, paths[-1]["x"][-1]], + showgrid=True, + showline=True, + tickmode="array", + rangeslider=slider_bands, + fixedrange=False, + tickvals=self.bands_data["pathlabels"][1], # ,self.band_labels[1], + ticktext=self.bands_data["pathlabels"][0], # self.band_labels[0], + showticklabels=True, + linecolor=self.SETTINGS["axis_linecolor"], + mirror=True, + linewidth=2, + type="linear", + ) + + return bandxaxis + + def _band_yaxis(self): + """Function to return the yaxis for the bands plot.""" + + if not self.bands_data: + return None + + bandyaxis = go.layout.YAxis( + title=dict(text="Electronic Bands (eV)", standoff=1), + side="left", + showgrid=True, + showline=True, + zeroline=True, + range=self.SETTINGS["vertical_range_bands"], + fixedrange=False, + automargin=True, + ticks="inside", + linewidth=2, + linecolor=self.SETTINGS["axis_linecolor"], + tickwidth=2, + zerolinewidth=2, + ) + + return bandyaxis + + def _dos_xaxis(self): + """Function to return the xaxis for the dos plot.""" + + if not self.pdos_data: + return None + + if self.bands_data: + dosxaxis = go.layout.XAxis( + title="Density of states", + side="bottom", + showgrid=True, + showline=True, + linecolor=self.SETTINGS["axis_linecolor"], + mirror="ticks", + ticks="inside", + linewidth=2, + tickwidth=2, + automargin=True, + ) + + else: + dosxaxis = go.layout.XAxis( + title="Density of states (eV)", + showgrid=True, + showline=True, + linecolor=self.SETTINGS["axis_linecolor"], + mirror="ticks", + ticks="inside", + linewidth=2, + tickwidth=2, + range=self.SETTINGS["horizontal_range_pdos"], + ) + + return dosxaxis + + def _dos_yaxis(self): + """Function to return the yaxis for the dos plot.""" + + if not self.pdos_data: + return None + + if self.bands_data: + dosyaxis = go.layout.YAxis( + # title= {"text":"Density of states (eV)", "standoff": 1}, + showgrid=True, + showline=True, + side="right", + mirror="ticks", + ticks="inside", + linewidth=2, + tickwidth=2, + linecolor=self.SETTINGS["axis_linecolor"], + zerolinewidth=2, + ) + + else: + dosyaxis = go.layout.YAxis( + # title="Density of states (eV)", + showgrid=True, + showline=True, + side="left", + mirror="ticks", + ticks="inside", + linewidth=2, + tickwidth=2, + linecolor=self.SETTINGS["axis_linecolor"], + zerolinewidth=2, + ) + + return dosyaxis + + def _get_bandspdos_plot(self): + """Function to return the bands plot widget.""" + conditions = { + (True, False): self._create_bands_only_plot, + (False, True): self._create_dos_only_plot, + (True, True): self._create_combined_plot, + } + + return conditions.get((bool(self.bands_data), bool(self.pdos_data)), None)() + + def _create_bands_only_plot(self): + """Function to return the bands plot widget.""" + + fig = go.Figure() + paths = self.bands_data.get("paths") + + self._add_band_traces(fig, paths, "bands_only") + + band_labels = self.bands_data.get("pathlabels") + for i in band_labels[1]: + fig.add_vline( + x=i, line=dict(color=self.SETTINGS["vertical_linecolor"], width=1) + ) + fig.update_layout( + xaxis=self._bands_xaxis, + yaxis=self._bands_yaxis, + plot_bgcolor="white", + height=self.SETTINGS["bands_plot_height"], + width=self.SETTINGS["bands_plot_width"], + ) + return go.FigureWidget(fig) + + def _create_dos_only_plot(self): + """Function to return the pdos plot widget.""" + + fig = go.Figure() + # Extract DOS data + self._add_dos_traces(fig, plot_type="dos_only") + # Add a vertical line at zero energy + fig.add_vline( + x=0, + line=dict(color=self.SETTINGS["vertical_linecolor"], width=1, dash="dot"), + ) + + # Update the layout of the Figure + fig.update_layout( + xaxis=self._dos_xaxis, + yaxis=self._dos_yaxis, + plot_bgcolor="white", + height=self.SETTINGS["pdos_plot_height"], + width=self.SETTINGS["pdos_plot_width"], + ) + + return go.FigureWidget(fig) + + def _create_combined_plot(self): + fig = make_subplots( + rows=1, + cols=2, + shared_yaxes=True, + column_widths=self.SETTINGS["combined_column_widths"], + horizontal_spacing=0.015, + ) + paths = self.bands_data.get("paths") + self._add_band_traces(fig, paths, plot_type="combined") + self._add_dos_traces(fig, plot_type="combined") + band_labels = self.bands_data.get("pathlabels") + for i in band_labels[1]: + fig.add_vline( + x=i, + line=dict(color=self.SETTINGS["vertical_linecolor"], width=1), + row=1, + col=1, + ) + self._customize_combined_layout(fig) + return go.FigureWidget(fig) + + def _add_band_traces(self, fig, paths, plot_type): + paths = self.bands_data.get("paths") + + # Spin condition: True if spin-polarized False if not + spin_type = paths[0].get("two_band_types") + # Convert paths to a list of Scatter objects + scatter_objects = [] + + for band in paths: + if not spin_type: + # Non-spin-polarized case + for bands in band["values"]: + bands_np = np.array(bands) + scatter_objects.append( + go.Scatter( + x=band["x"], + y=bands_np - self.fermi_energy, + mode="lines", + line=dict( + color=self.SETTINGS["bands_linecolor"], + shape="spline", + smoothing=1.3, + ), + showlegend=False, + ) + ) + else: + half_len = len(band["values"]) // 2 + first_half = band["values"][:half_len] + second_half = band["values"][half_len:] + + # Red line for the Spin up + color_first_half = self.SETTINGS["bands_up_linecolor"] + # Blue line for the Spin down + color_second_half = self.SETTINGS["bands_down_linecolor"] + + for bands, color in zip( + (first_half, second_half), (color_first_half, color_second_half) + ): + for band_values in bands: + bands_np = np.array(band_values) + scatter_objects.append( + go.Scatter( + x=band["x"], + y=bands_np - self.fermi_energy, + mode="lines", + line=dict( + color=color, + shape="spline", + smoothing=1.3, + ), + showlegend=False, + ) + ) + + if plot_type == "bands_only": + fig.add_traces(scatter_objects) + else: + rows = [1] * len(scatter_objects) + cols = [1] * len(scatter_objects) + fig.add_traces(scatter_objects, rows=rows, cols=cols) + + def _add_dos_traces(self, fig, plot_type): + # Extract DOS data + dos_data = self.pdos_data["dos"] + + # Pre-allocate memory for Scatter objects + num_traces = len(dos_data) + scatter_objects = [None] * num_traces + + # Vectorize Scatter object creation + for i, trace in enumerate(dos_data): + dos_np = np.array(trace["x"]) + fill = "tozerox" if plot_type == "combined" else "tozeroy" + x_data = ( + trace["y"] if plot_type == "combined" else dos_np - self.fermi_energy + ) + y_data = ( + dos_np - self.fermi_energy if plot_type == "combined" else trace["y"] + ) + scatter_objects[i] = go.Scatter( + x=x_data, + y=y_data, + fill=fill, + name=trace["label"], + line=dict(color=trace["borderColor"], shape="spline", smoothing=1.0), + ) + if plot_type == "dos_only": + fig.add_traces(scatter_objects) + else: + rows = [1] * len(scatter_objects) + cols = [2] * len(scatter_objects) + fig.add_traces(scatter_objects, rows=rows, cols=cols) + + def _customize_combined_layout(self, fig): + self._customize_layout(fig, self._bands_xaxis, self._bands_yaxis) + self._customize_layout(fig, self._dos_xaxis, self._dos_yaxis, col=2) + fig.update_layout( + legend=dict(xanchor="left", x=1.06), + height=self.SETTINGS["combined_plot_height"], + width=self.SETTINGS["combined_plot_width"], + plot_bgcolor="white", + ) + + def _customize_layout(self, fig, xaxis, yaxis, row=1, col=1): + fig.update_xaxes(patch=xaxis, row=row, col=col) + fig.update_yaxes(patch=yaxis, row=row, col=col, showticklabels=True) + fig.add_hline( + y=0, + line=dict(color=self.SETTINGS["horizontal_linecolor"], width=1, dash="dot"), + row=row, + col=col, + ) + + @property + def bandspdosfigure(self): + return self._get_bandspdos_plot() + + +class BandPdosWidget(ipw.VBox): + """ + A widget for plotting band structure and projected density of states (PDOS) data. + + Parameters: + - bands (optional): A node containing band structure data. + - pdos (optional): A node containing PDOS data. + + Attributes: + - description: HTML description of the widget. + - dos_atoms_group: Dropdown widget to select the grouping of atoms for PDOS plotting. + - dos_plot_group: Dropdown widget to select the type of PDOS contributions to plot. + - selected_atoms: Text widget to select specific atoms for PDOS plotting. + - update_plot_button: Button widget to update the plot. + - download_button: Button widget to download the data. + - dos_data: PDOS data. + - bands_data: Band structure data. + - bandsplot_widget: Plotly widget for band structure and PDOS plot. + - bands_widget: Output widget to display the bandsplot widget. + - pdos_options_out: Output widget to clear specific widgets. + """ + + description = ipw.HTML( + """
+ Select the style of plotting the projected density of states. +
""" + ) + + def __init__(self, bands=None, pdos=None, **kwargs): + if bands is None and pdos is None: + raise ValueError("Either bands or pdos must be provided") + + self.bands = bands # bands node + self.pdos = pdos # pdos node + + self.dos_atoms_group = ipw.Dropdown( + description="Group by:", + options=[ + ("Kinds", "kinds"), + ("Atoms", "atoms"), + ], + value="kinds", + style={"description_width": "initial"}, + ) + self.dos_plot_group = ipw.Dropdown( + description="Plot contributions:", + options=[ + ("Total", "total"), + ("Orbital", "orbital"), + ("Angular momentum", "angular_momentum"), + ], + value="total", + style={"description_width": "initial"}, + ) + self.selected_atoms = ipw.Text( + description="Select atoms:", + value="", + style={"description_width": "initial"}, + ) + self._wrong_syntax = StatusHTML(clear_after=8) + self.update_plot_button = ipw.Button( + description="Update Plot", + icon="pencil", + button_style="primary", + disabled=False, + ) + self.download_button = ipw.Button( + description="Download Data", + icon="download", + button_style="primary", + disabled=False, + layout=ipw.Layout(visibility="hidden"), + ) + + # Information for the plot + self.dos_data = self._get_dos_data() + self.bands_data = self._get_bands_data() + # Plotly widget + self.bandsplot_widget = BandPdosPlotly( + bands_data=self.bands_data, pdos_data=self.dos_data + ).bandspdosfigure + # Output widget to display the bandsplot widget + self.bands_widget = ipw.Output() + # Output widget to clear the specific widgets + self.pdos_options_out = ipw.Output() + + self.pdos_options = ipw.VBox( + [ + self.description, + self.dos_atoms_group, + self.dos_plot_group, + ipw.HBox([self.selected_atoms, self._wrong_syntax]), + self.update_plot_button, + ] + ) + + self._initial_view() + + # Set the event handlers + self.download_button.on_click(self.download_data) + self.update_plot_button.on_click(self._update_plot) + + super().__init__( + children=[ + self.pdos_options_out, + self.download_button, + self.bands_widget, # Add the output widget to the VBox + ], + **kwargs, + ) + if self.pdos: + with self.pdos_options_out: + display(self.pdos_options) + + def download_data(self, _=None): + """Function to download the data.""" + file_name_bands = "bands_data.json" + file_name_dos = "dos_data.json" + if self.bands_data: + json_str = json.dumps(self.bands_data) + b64_str = base64.b64encode(json_str.encode()).decode() + self._download(payload=b64_str, filename=file_name_bands) + if self.dos_data: + json_str = json.dumps(self.dos_data) + b64_str = base64.b64encode(json_str.encode()).decode() + self._download(payload=b64_str, filename=file_name_dos) + + @staticmethod + def _download(payload, filename): + """Download payload as a file named as filename.""" + from IPython.display import Javascript + + javas = Javascript( + """ + var link = document.createElement('a'); + link.href = 'data:text/json;charset=utf-8;base64,{payload}' + link.download = "{filename}" + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + """.format(payload=payload, filename=filename) + ) + display(javas) + + def _get_dos_data(self): + if not self.pdos: + return None + expanded_selection, syntax_ok = string_range_to_list( + self.selected_atoms.value, shift=-1 + ) + if syntax_ok: + dos = get_pdos_data( + self.pdos, + group_tag=self.dos_atoms_group.value, + plot_tag=self.dos_plot_group.value, + selected_atoms=expanded_selection, + ) + return dos + else: + return None + + def _get_bands_data(self): + if not self.bands: + return None + + bands = export_bands_data(self.bands) + return bands + + def _initial_view(self): + with self.bands_widget: + self._clear_output_and_display(self.bandsplot_widget) + self.download_button.layout.visibility = "visible" + + def _update_plot(self, _=None): + with self.bands_widget: + expanded_selection, syntax_ok = string_range_to_list( + self.selected_atoms.value, shift=-1 + ) + if not syntax_ok: + self._wrong_syntax.message = """
ERROR: Invalid syntax for selected atoms
""" + clear_output(wait=True) + else: + self.dos_data = self._get_dos_data() + self.bandsplot_widget = BandPdosPlotly( + bands_data=self.bands_data, pdos_data=self.dos_data + ).bandspdosfigure + self._clear_output_and_display(self.bandsplot_widget) + + def _clear_output_and_display(self, widget=None): + clear_output(wait=True) + if widget: + display(widget) + + +def get_pdos_data(pdos, group_tag, plot_tag, selected_atoms): + dos = [] + + if "output_dos" not in pdos.dos: + return None + + _, energy_dos, _ = pdos.dos.output_dos.get_x() + tdos_values = {f"{n}": v for n, v, _ in pdos.dos.output_dos.get_y()} + + if "projections" in pdos.projwfc: + # Total DOS + tdos = { + "label": "Total DOS", + "x": energy_dos.tolist(), + "y": tdos_values.get("dos").tolist(), + "borderColor": "#8A8A8A", # dark gray + "backgroundColor": "#999999", # light gray + "backgroundAlpha": "40%", + "lineStyle": "solid", + } + dos.append(tdos) + dos += _projections_curated_options( + pdos.projwfc.projections, + spin_type="none", + group_tag=group_tag, + plot_tag=plot_tag, + selected_atoms=selected_atoms, + ) + else: + # Total DOS (↑) and Total DOS (↓) + tdos_up = { + "label": "Total DOS (↑)", + "x": energy_dos.tolist(), + "y": tdos_values.get("dos_spin_up").tolist(), + "borderColor": "#8A8A8A", # dark gray + "backgroundColor": "#999999", # light gray + "backgroundAlpha": "40%", + "lineStyle": "solid", + } + tdos_down = { + "label": "Total DOS (↓)", + "x": energy_dos.tolist(), + "y": (-tdos_values.get("dos_spin_down")).tolist(), + "borderColor": "#8A8A8A", # dark gray + "backgroundColor": "#999999", # light gray + "backgroundAlpha": "40%", + "lineStyle": "dash", + } + dos += [tdos_up, tdos_down] + + # Spin-up (↑) and Spin-down (↓) + dos += _projections_curated_options( + pdos.projwfc.projections_up, + spin_type="up", + group_tag=group_tag, + plot_tag=plot_tag, + selected_atoms=selected_atoms, + ) + dos += _projections_curated_options( + pdos.projwfc.projections_down, + spin_type="down", + line_style="dash", + group_tag=group_tag, + plot_tag=plot_tag, + selected_atoms=selected_atoms, + ) + + data_dict = { + "fermi_energy": pdos.nscf.output_parameters["fermi_energy"], + "dos": dos, + } + + return json.loads(json.dumps(data_dict)) + + +def _projections_curated_options( + projections: ProjectionData, + group_tag, + plot_tag, + selected_atoms, + spin_type="none", + line_style="solid", +): + _pdos = {} + list_positions = [] + + # Constants for HTML tags + HTML_TAGS = { + "s": "s", + "pz": "pz", + "px": "px", + "py": "py", + "dz2": "dz2", + "dxy": "dxy", + "dxz": "dxz", + "dyz": "dyz", + "dx2-y2": "dx2-y2", + "fz3": "fz3", + "fxz2": "fxz2", + "fyz2": "fyz2", + "fxyz": "fxzy", + "fx(x2-3y2)": "fx(x2-3y2)", + "fy(3x2-y2)": "fy(3x2-y2)", + "fy(x2-z2)": "fy(x2-z2)", + 0.5: "+1/2", + -0.5: "-1/2", + 1.5: "+3/2", + -1.5: "-3/2", + 2.5: "+5/2", + -2.5: "-5/2", + } + + # Constants for spin types + SPIN_LABELS = {"up": "(↑)", "down": "(↓)", "none": ""} + + def get_key( + group_tag, + plot_tag, + atom_position, + kind_name, + orbital_name_plotly, + orbital_angular_momentum, + ): + """Generates the key based on group_tag and plot_tag.""" + + key_formats = { + ("atoms", "total"): r"{var1}-{var}", + ("kinds", "total"): r"{var1}", + ("atoms", "orbital"): r"{var1}-{var}
{var2}", + ("kinds", "orbital"): r"{var1}-{var2}", + ("atoms", "angular_momentum"): r"{var1}-{var}
{var3}", + ("kinds", "angular_momentum"): r"{var1}-{var3}", + } + + key = key_formats.get((group_tag, plot_tag)) + if key is not None: + return key.format( + var=atom_position, + var1=kind_name, + var2=orbital_name_plotly, + var3=orbital_angular_momentum, + ) + else: + return None + + for orbital, pdos, energy in projections.get_pdos(): + orbital_data = orbital.get_orbital_dict() + kind_name = orbital_data["kind_name"] + atom_position = [round(i, 2) for i in orbital_data["position"]] + + if atom_position not in list_positions: + list_positions.append(atom_position) + + try: + orbital_name = orbital.get_name_from_quantum_numbers( + orbital_data["angular_momentum"], orbital_data["magnetic_number"] + ).lower() + orbital_name_plotly = HTML_TAGS.get(orbital_name, orbital_name) + orbital_angular_momentum = orbital_name[0] + except AttributeError: + orbital_name = "j {j} l {l} m_j{m_j}".format( + j=orbital_data["total_angular_momentum"], + l=orbital_data["angular_momentum"], + m_j=orbital_data["magnetic_number"], + ) + orbital_name_plotly = "j={j} l={l} mj={m_j}".format( + j=HTML_TAGS.get( + orbital_data["total_angular_momentum"], + orbital_data["total_angular_momentum"], + ), + l=orbital_data["angular_momentum"], + m_j=HTML_TAGS.get( + orbital_data["magnetic_number"], orbital_data["magnetic_number"] + ), + ) + orbital_angular_momentum = "l {l} ".format( + l=orbital_data["angular_momentum"], + ) + + if not selected_atoms: + key = get_key( + group_tag, + plot_tag, + atom_position, + kind_name, + orbital_name_plotly, + orbital_angular_momentum, + ) + + if key: + _pdos.setdefault(key, [energy, 0])[1] += pdos + + else: + try: + index = list_positions.index(atom_position) + if index in selected_atoms: + key = get_key( + group_tag, + plot_tag, + atom_position, + kind_name, + orbital_name_plotly, + orbital_angular_momentum, + ) + + if key: + _pdos.setdefault(key, [energy, 0])[1] += pdos + + except ValueError: + pass + + dos = [] + for label, (energy, pdos) in _pdos.items(): + if spin_type == "down": + pdos = -pdos + label += SPIN_LABELS[spin_type] + + if spin_type == "up": + label += SPIN_LABELS[spin_type] + + orbital_pdos = { + "label": label, + "x": energy.tolist(), + "y": pdos.tolist(), + "borderColor": cmap(label), + "lineStyle": line_style, + } + dos.append(orbital_pdos) + + return dos + + +def export_bands_data(outputs, fermi_energy=None): + if "band_structure" not in outputs: + return None + + data = json.loads(outputs.band_structure._exportcontent("json", comments=False)[0]) + # The fermi energy from band calculation is not robust. + data["fermi_energy"] = outputs.band_parameters["fermi_energy"] or fermi_energy + data["pathlabels"] = get_bands_labeling(data) + return data + + +def get_bands_labeling(bandsdata: dict) -> list: + """Function to return two lists containing the labels and values (kpoint) for plotting. + params: + - bandsdata: dictionary from export_bands_data function + output: update bandsdata with a new key "pathlabels" including (list of str), label_values (list of float) + """ + UNICODE_SYMBOL = { + "GAMMA": "\u0393", + "DELTA": "\u0394", + "LAMBDA": "\u039B", + "SIGMA": "\u03A3", + "EPSILON": "\u0395", + } + paths = bandsdata.get("paths") + labels = [] + for path in paths: # Remove duplicates + label_a = [path["from"], path["x"][0]] + label_b = [path["to"], path["x"][-1]] + if label_a not in labels: + labels.append(label_a) + if label_b not in labels: + labels.append(label_b) + + clean_labels = [] # Format + for i in labels: + if clean_labels: + if (i not in clean_labels) and (clean_labels[-1][-1] == i[1]): + clean_labels[-1][0] = clean_labels[-1][0] + "|" + i[0] + else: + clean_labels.append(i) + else: + clean_labels.append(i) + + path_labels = [label[0] for label in clean_labels] + for i, label in enumerate(path_labels): + path_labels[i] = re.sub( + r"([A-Z]+)", lambda x: UNICODE_SYMBOL.get(x.group(), x.group()), label + ) + path_values = [label[1] for label in clean_labels] + return [path_labels, path_values] + + +def cmap(label: str) -> str: + """Return RGB string of color for given pseudo info + Hardcoded at the momment. + """ + import random + + # if a unknow type generate random color based on ascii sum + ascn = sum([ord(c) for c in label]) + random.seed(ascn) + + return "#%06x" % random.randint(0, 0xFFFFFF) diff --git a/src/aiidalab_qe/plugins/bands/result.py b/src/aiidalab_qe/plugins/bands/result.py index 905aee852..e24c76d7e 100644 --- a/src/aiidalab_qe/plugins/bands/result.py +++ b/src/aiidalab_qe/plugins/bands/result.py @@ -1,30 +1,10 @@ """Bands results view widgets """ - - +from aiidalab_qe.common.bandpdoswidget import BandPdosWidget from aiidalab_qe.common.panel import ResultPanel -def export_bands_data(outputs, fermi_energy=None): - """Export the bands data from the outputs of the calculation.""" - import json - - from monty.json import jsanitize - - if "band_structure" in outputs: - data = json.loads( - outputs.band_structure._exportcontent("json", comments=False)[0] - ) - # The fermi energy from band calculation is not robust. - data["fermi_level"] = fermi_energy or outputs.band_parameters["fermi_energy"] - return [ - jsanitize(data), - ] - else: - return None - - class Result(ResultPanel): """Result panel for the bands calculation.""" @@ -35,12 +15,13 @@ def __init__(self, node=None, **kwargs): super().__init__(node=node, **kwargs) def _update_view(self): - from widget_bandsplot import BandsPlotWidget + # Check if the workchain has the outputs + try: + bands_node = self.node.outputs.bands + except AttributeError: + bands_node = None - bands_data = export_bands_data(self.outputs.bands) - _bands_plot_view = BandsPlotWidget( - bands=bands_data, - ) + _bands_plot_view = BandPdosWidget(bands=bands_node) self.children = [ _bands_plot_view, ] diff --git a/src/aiidalab_qe/plugins/electronic_structure/result.py b/src/aiidalab_qe/plugins/electronic_structure/result.py index 2a5d652a3..1142ae450 100644 --- a/src/aiidalab_qe/plugins/electronic_structure/result.py +++ b/src/aiidalab_qe/plugins/electronic_structure/result.py @@ -1,234 +1,28 @@ """Electronic structure results view widgets""" -import json -import random - -import ipywidgets as ipw -from aiida import orm -from monty.json import jsanitize -from widget_bandsplot import BandsPlotWidget +from aiidalab_qe.common.bandpdoswidget import BandPdosWidget from aiidalab_qe.common.panel import ResultPanel -def export_data(work_chain_node, group_dos_by="atom"): - dos = export_pdos_data(work_chain_node, group_dos_by=group_dos_by) - fermi_energy = dos["fermi_energy"] if dos else None - - bands = export_bands_data(work_chain_node, fermi_energy) - - return dict( - bands=bands, - dos=dos, - ) - - -def export_pdos_data(work_chain_node, group_dos_by="atom"): - if "pdos" in work_chain_node.outputs: - _, energy_dos, _ = work_chain_node.outputs.pdos.dos.output_dos.get_x() - tdos_values = { - f"{n}": v for n, v, _ in work_chain_node.outputs.pdos.dos.output_dos.get_y() - } - - dos = [] - - if "projections" in work_chain_node.outputs.pdos.projwfc: - # The total dos parsed - tdos = { - "label": "Total DOS", - "x": energy_dos.tolist(), - "y": tdos_values.get("dos").tolist(), - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "solid", - } - dos.append(tdos) - - dos += _projections_curated( - work_chain_node.outputs.pdos.projwfc.projections, - group_dos_by=group_dos_by, - spin_type="none", - ) - - else: - # The total dos parsed - tdos_up = { - "label": "Total DOS (↑)", - "x": energy_dos.tolist(), - "y": tdos_values.get("dos_spin_up").tolist(), - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "solid", - } - tdos_down = { - "label": "Total DOS (↓)", - "x": energy_dos.tolist(), - "y": (-tdos_values.get("dos_spin_down")).tolist(), # minus - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "dash", - } - dos += [tdos_up, tdos_down] - - # spin-up (↑) - dos += _projections_curated( - work_chain_node.outputs.pdos.projwfc.projections_up, - group_dos_by=group_dos_by, - spin_type="up", - ) - - # spin-dn (↓) - dos += _projections_curated( - work_chain_node.outputs.pdos.projwfc.projections_down, - group_dos_by=group_dos_by, - spin_type="down", - line_style="dash", - ) - - data_dict = { - "fermi_energy": work_chain_node.outputs.pdos.nscf.output_parameters[ - "fermi_energy" - ], - "dos": dos, - } - - return json.loads(json.dumps(data_dict)) - - else: - return None - - -def export_bands_data(work_chain_node, fermi_energy=None): - if "bands" in work_chain_node.outputs: - data = json.loads( - work_chain_node.outputs.bands.band_structure._exportcontent( - "json", comments=False - )[0] - ) - # The fermi energy from band calculation is not robust. - data["fermi_level"] = ( - fermi_energy - or work_chain_node.outputs.bands.band_parameters["fermi_energy"] - ) - return [ - jsanitize(data), - ] - else: - return None - - -def _projections_curated( - projections: orm.ProjectionData, - group_dos_by="atom", - spin_type="none", - line_style="solid", -): - """Collect the data from ProjectionData and parse it as dos list which can be - understand by bandsplot widget. `group_dos_by` is for which tag to be grouped, by atom or by orbital name. - The spin_type is used to invert all the y values of pdos to be shown as spin down pdos and to set label. - """ - _pdos = {} - - for orbital, pdos, energy in projections.get_pdos(): - orbital_data = orbital.get_orbital_dict() - kind_name = orbital_data["kind_name"] - atom_position = [round(i, 2) for i in orbital_data["position"]] - orbital_name = orbital.get_name_from_quantum_numbers( - orbital_data["angular_momentum"], orbital_data["magnetic_number"] - ).lower() - - if group_dos_by == "atom": - dos_group_name = atom_position - elif group_dos_by == "angular": - # by orbital label - dos_group_name = orbital_name[0] - elif group_dos_by == "angular_and_magnetic": - # by orbital label - dos_group_name = orbital_name - else: - raise Exception(f"Unknow dos type: {group_dos_by}!") - - key = f"{kind_name}-{dos_group_name}" - if key in _pdos: - _pdos[key][1] += pdos - else: - _pdos[key] = [energy, pdos] - - dos = [] - for label, (energy, pdos) in _pdos.items(): - if spin_type == "down": - # invert y-axis - pdos = -pdos - label = f"{label} (↓)" - - if spin_type == "up": - label = f"{label} (↑)" - - orbital_pdos = { - "label": label, - "x": energy.tolist(), - "y": pdos.tolist(), - "borderColor": cmap(label), - "lineStyle": line_style, - } - dos.append(orbital_pdos) - - return dos - - -def cmap(label: str) -> str: - """Return RGB string of color for given pseudo info - Hardcoded at the momment. - """ - # if a unknow type generate random color based on ascii sum - ascn = sum([ord(c) for c in label]) - random.seed(ascn) - - return "#%06x" % random.randint(0, 0xFFFFFF) - - class Result(ResultPanel): title = "Electronic Structure" workchain_labels = ["bands", "pdos"] def __init__(self, node=None, **kwargs): - self.dos_group_label = ipw.Label( - "DOS grouped by:", - layout=ipw.Layout(justify_content="flex-start", width="120px"), - ) - self.group_dos_by = ipw.ToggleButtons( - options=[ - ("Atom", "atom"), - ("Orbital", "angular"), - ], - value="atom", - ) - self.settings = ipw.HBox( - children=[ - self.dos_group_label, - self.group_dos_by, - ], - layout={"margin": "0 0 30px 30px"}, - ) - self.group_dos_by.observe(self._observe_group_dos_by, names="value") super().__init__(node=node, **kwargs) - def _observe_group_dos_by(self, change): - """Update the view of the widget when the group_dos_by value changes.""" - self._update_view() - def _update_view(self): """Update the view of the widget.""" # - data = export_data(self.node, group_dos_by=self.group_dos_by.value) - _bands_plot_view = BandsPlotWidget( - bands=data.get("bands", None), - dos=data.get("dos", None), - ) + try: + pdos_node = self.node.outputs.pdos + except AttributeError: + pdos_node = None + + try: + bands_node = self.node.outputs.bands + except AttributeError: + bands_node = None + _bands_dos_widget = BandPdosWidget(bands=bands_node, pdos=pdos_node) # update the electronic structure tab - self.children = [ - self.settings, - _bands_plot_view, - ] + self.children = [_bands_dos_widget] diff --git a/src/aiidalab_qe/plugins/pdos/result.py b/src/aiidalab_qe/plugins/pdos/result.py index b57be09ef..db46484b8 100644 --- a/src/aiidalab_qe/plugins/pdos/result.py +++ b/src/aiidalab_qe/plugins/pdos/result.py @@ -2,160 +2,10 @@ """ -import ipywidgets as ipw -from aiida.orm import ProjectionData - +from aiidalab_qe.common.bandpdoswidget import BandPdosWidget from aiidalab_qe.common.panel import ResultPanel -def cmap(label: str) -> str: - """Return RGB string of color for given pseudo info - Hardcoded at the momment. - """ - import random - - # if a unknow type generate random color based on ascii sum - ascn = sum([ord(c) for c in label]) - random.seed(ascn) - - return "#%06x" % random.randint(0, 0xFFFFFF) - - -def _projections_curated( - projections: ProjectionData, - group_dos_by="atom", - spin_type="none", - line_style="solid", -): - """Collect the data from ProjectionData and parse it as dos list which can be - understand by bandsplot widget. `group_dos_by` is for which tag to be grouped, by atom or by orbital name. - The spin_type is used to invert all the y values of pdos to be shown as spin down pdos and to set label. - """ - _pdos = {} - - for orbital, pdos, energy in projections.get_pdos(): - orbital_data = orbital.get_orbital_dict() - kind_name = orbital_data["kind_name"] - atom_position = [round(i, 2) for i in orbital_data["position"]] - orbital_name = orbital.get_name_from_quantum_numbers( - orbital_data["angular_momentum"], orbital_data["magnetic_number"] - ).lower() - - if group_dos_by == "atom": - dos_group_name = atom_position - elif group_dos_by == "angular": - # by orbital label - dos_group_name = orbital_name[0] - elif group_dos_by == "angular_and_magnetic": - # by orbital label - dos_group_name = orbital_name - else: - raise Exception(f"Unknow dos type: {group_dos_by}!") - - key = f"{kind_name}-{dos_group_name}" - if key in _pdos: - _pdos[key][1] += pdos - else: - _pdos[key] = [energy, pdos] - - dos = [] - for label, (energy, pdos) in _pdos.items(): - if spin_type == "down": - # invert y-axis - pdos = -pdos - label = f"{label} (↓)" - - if spin_type == "up": - label = f"{label} (↑)" - - orbital_pdos = { - "label": label, - "x": energy.tolist(), - "y": pdos.tolist(), - "borderColor": cmap(label), - "lineStyle": line_style, - } - dos.append(orbital_pdos) - - return dos - - -def export_pdos_data(outputs, group_dos_by="atom"): - import json - - if "output_dos" in outputs.dos: - _, energy_dos, _ = outputs.dos.output_dos.get_x() - tdos_values = {f"{n}": v for n, v, _ in outputs.dos.output_dos.get_y()} - - dos = [] - - if "projections" in outputs.projwfc: - # The total dos parsed - tdos = { - "label": "Total DOS", - "x": energy_dos.tolist(), - "y": tdos_values.get("dos").tolist(), - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "solid", - } - dos.append(tdos) - - dos += _projections_curated( - outputs.projwfc.projections, - group_dos_by=group_dos_by, - spin_type="none", - ) - - else: - # The total dos parsed - tdos_up = { - "label": "Total DOS (↑)", - "x": energy_dos.tolist(), - "y": tdos_values.get("dos_spin_up").tolist(), - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "solid", - } - tdos_down = { - "label": "Total DOS (↓)", - "x": energy_dos.tolist(), - "y": (-tdos_values.get("dos_spin_down")).tolist(), # minus - "borderColor": "#8A8A8A", # dark gray - "backgroundColor": "#999999", # light gray - "backgroundAlpha": "40%", - "lineStyle": "dash", - } - dos += [tdos_up, tdos_down] - - # spin-up (↑) - dos += _projections_curated( - outputs.projwfc.projections_up, - group_dos_by=group_dos_by, - spin_type="up", - ) - - # spin-dn (↓) - dos += _projections_curated( - outputs.projwfc.projections_down, - group_dos_by=group_dos_by, - spin_type="down", - line_style="dash", - ) - - data_dict = { - "fermi_energy": outputs.nscf.output_parameters["fermi_energy"], - "dos": dos, - } - - return json.loads(json.dumps(data_dict)) - - else: - return None - - class Result(ResultPanel): title = "PDOS" workchain_labels = ["pdos"] @@ -165,52 +15,13 @@ def __init__(self, node=None, **kwargs): def _update_view(self): """Update the view of the widget.""" - from widget_bandsplot import BandsPlotWidget - group_dos_by = ipw.ToggleButtons( - options=[ - ("Atom", "atom"), - ("Orbital", "angular"), - ], - value="atom", - ) - settings = ipw.VBox( - children=[ - ipw.HBox( - children=[ - ipw.Label( - "DOS grouped by:", - layout=ipw.Layout( - justify_content="flex-start", width="120px" - ), - ), - group_dos_by, - ] - ), - ], - layout={"margin": "0 0 30px 30px"}, - ) - # - dos_data = export_pdos_data(self.outputs.pdos, group_dos_by=group_dos_by.value) - _bands_plot_view = BandsPlotWidget( - dos=dos_data, - ) + try: + pdos_node = self.node.outputs.pdos + except AttributeError: + pdos_node = None - def response(change): - dos_data = export_pdos_data( - self.outputs.pdos, group_dos_by=group_dos_by.value - ) - _bands_plot_view = BandsPlotWidget( - dos=dos_data, - ) - self.children = [ - settings, - _bands_plot_view, - ] + _pdos_plot_view = BandPdosWidget(pdos=pdos_node) - group_dos_by.observe(response, names="value") # update the electronic structure tab - self.children = [ - settings, - _bands_plot_view, - ] + self.children = [_pdos_plot_view] diff --git a/tests/test_plugins_bands.py b/tests/test_plugins_bands.py index a819f1ea8..a586ca285 100644 --- a/tests/test_plugins_bands.py +++ b/tests/test_plugins_bands.py @@ -3,17 +3,35 @@ @pytest.mark.usefixtures("sssp") def test_result(generate_qeapp_workchain): - from widget_bandsplot import BandsPlotWidget - - from aiidalab_qe.plugins.bands.result import Result, export_bands_data + from aiidalab_qe.common.bandpdoswidget import BandPdosWidget + import plotly.graph_objects as go + from aiidalab_qe.plugins.bands.result import Result wkchain = generate_qeapp_workchain() - data = export_bands_data(wkchain.node.outputs.bands) - assert data is not None # generate structure for scf calculation result = Result(wkchain.node) result._update_view() - assert isinstance(result.children[0], BandsPlotWidget) + assert isinstance(result.children[0], BandPdosWidget) + assert isinstance(result.children[0].bandsplot_widget, go.FigureWidget) + + # Check if data is correct + assert result.children[0].bands_data is not None + assert result.children[0].bands_data["pathlabels"] is not None + assert result.children[0].dos_data is None + + # Check Bands axis + assert result.children[0].bandsplot_widget.layout.xaxis.title.text == "k-points" + assert ( + result.children[0].bandsplot_widget.layout.yaxis.title.text + == "Electronic Bands (eV)" + ) + assert isinstance( + result.children[0].bandsplot_widget.layout.xaxis.rangeslider, + go.layout.xaxis.Rangeslider, + ) + assert result.children[0].bands_data["pathlabels"][0] == list( + result.children[0].bandsplot_widget.layout.xaxis.ticktext + ) @pytest.mark.usefixtures("sssp") diff --git a/tests/test_plugins_electronic_structure.py b/tests/test_plugins_electronic_structure.py index c9c107f43..8e209cede 100644 --- a/tests/test_plugins_electronic_structure.py +++ b/tests/test_plugins_electronic_structure.py @@ -3,6 +3,9 @@ def test_electronic_structure(generate_qeapp_workchain): from aiida import engine from aiidalab_qe.app.result.workchain_viewer import WorkChainViewer + from aiidalab_qe.common.bandpdoswidget import BandPdosWidget + import plotly.graph_objects as go + from aiidalab_qe.plugins.electronic_structure.result import Result wkchain = generate_qeapp_workchain() wkchain.node.set_exit_status(0) @@ -16,5 +19,34 @@ def test_electronic_structure(generate_qeapp_workchain): for tab in wcv.result_tabs.children if getattr(tab, "identifier", "") == "electronic_structure" ][0] - # It should have two children: settings and the _bands_plot_view - assert len(tab.children) == 2 + # It should have one children: the _bands_plot_view + assert len(tab.children) == 1 + + result = Result(node=wkchain.node) + result._update_view() + + assert isinstance(result.children[0], BandPdosWidget) + assert isinstance(result.children[0].bandsplot_widget, go.FigureWidget) + + # Check if data is correct + assert result.children[0].bands_data is not None + assert result.children[0].bands_data["pathlabels"] is not None + assert result.children[0].dos_data is not None + + # Check Bands axis + assert result.children[0].bandsplot_widget.layout.xaxis.title.text == "k-points" + assert ( + result.children[0].bandsplot_widget.layout.xaxis2.title.text + == "Density of states" + ) + assert ( + result.children[0].bandsplot_widget.layout.yaxis.title.text + == "Electronic Bands (eV)" + ) + assert isinstance( + result.children[0].bandsplot_widget.layout.xaxis.rangeslider, + go.layout.xaxis.Rangeslider, + ) + assert result.children[0].bands_data["pathlabels"][0] == list( + result.children[0].bandsplot_widget.layout.xaxis.ticktext + ) diff --git a/tests/test_plugins_pdos.py b/tests/test_plugins_pdos.py index b625143a9..354c67e7f 100644 --- a/tests/test_plugins_pdos.py +++ b/tests/test_plugins_pdos.py @@ -3,38 +3,26 @@ @pytest.mark.usefixtures("sssp") def test_result(generate_qeapp_workchain): - from aiidalab_qe.plugins.pdos.result import Result, export_pdos_data + from aiidalab_qe.common.bandpdoswidget import BandPdosWidget + import plotly.graph_objects as go + from aiidalab_qe.plugins.pdos.result import Result wkchain = generate_qeapp_workchain() - data = export_pdos_data(wkchain.node.outputs.pdos) - assert data is not None # generate structure for scf calculation result = Result(node=wkchain.node) result._update_view() - assert len(result.children) == 2 + assert isinstance(result.children[0], BandPdosWidget) + assert isinstance(result.children[0].bandsplot_widget, go.FigureWidget) + # Check if data is correct + assert result.children[0].bands_data is None + assert result.children[0].dos_data is not None -@pytest.mark.usefixtures("sssp") -def test_result_spin(generate_qeapp_workchain): - from aiidalab_qe.plugins.pdos.result import Result, export_pdos_data - - wkchain = generate_qeapp_workchain(spin_type="collinear") - data = export_pdos_data(wkchain.node.outputs.pdos) - assert data is not None - # generate structure for scf calculation - result = Result(node=wkchain.node) - result._update_view() - assert len(result.children) == 2 - + # Check PDOS settings is not None -@pytest.mark.usefixtures("sssp") -def test_result_group_by(generate_qeapp_workchain): - from aiidalab_qe.plugins.pdos.result import Result, export_pdos_data - - wkchain = generate_qeapp_workchain() - data = export_pdos_data(wkchain.node.outputs.pdos) - assert data is not None - # generate structure for scf calculation - result = Result(node=wkchain.node) - result._update_view() - result.children[0].children[0].children[1].value = "angular" + # Check Bands axis + assert ( + result.children[0].bandsplot_widget.layout.xaxis.title.text + == "Density of states (eV)" + ) + assert result.children[0].bandsplot_widget.layout.yaxis.title.text is None