Skip to content

Commit

Permalink
Change colorbar labels
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthak-dv committed Jun 27, 2024
1 parent 7dd4d00 commit 7f55bcb
Showing 1 changed file with 23 additions and 91 deletions.
114 changes: 23 additions & 91 deletions tardis/visualization/tools/interaction_radius_plot.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as clr
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import astropy.units as u

from tardis.util.base import (
atomic_number2element_symbol,
element_symbol2atomic_number,
species_string_to_tuple,
roman_to_int,
int_to_roman,
)
import tardis.visualization.tools.sdec_plot as sdec
Expand Down Expand Up @@ -129,9 +125,29 @@ def _make_colorbar_labels(self):
If a species list is provided, uses that to generate labels.
Otherwise, generates labels from the species in the model.
"""
self.sdec_plotter.species = self.species
self.sdec_plotter._make_colorbar_labels()
self._species_name = self.sdec_plotter._species_name
if self._species_list is None:
# If species_list is none then the labels are just elements
species_name = [
atomic_number2element_symbol(atomic_num)
for atomic_num in self.species
]
else:
species_name = []
for species_key in self._species_mapped.keys():
species_ids = self._species_mapped[species_key]
if any(species in self.species for species in species_ids):
# If the key is for an element, label by element symbol
if species_key % 100 == 0:
label = atomic_number2element_symbol(species_key // 100)
else:
# If the key is for a specific ion, label by element and ion
atomic_number = species_key // 100
ion_number = species_key % 100
ion_numeral = int_to_roman(ion_number + 1)
label = f"{atomic_number2element_symbol(atomic_number)} {ion_numeral}"
species_name.append(label)

self._species_name = species_name

def _make_colorbar_colors(self):
"""
Expand Down Expand Up @@ -257,8 +273,6 @@ def generate_plot_mpl(
self.cmap = cm.get_cmap(cmapname, len(self._species_name))
# Get the number of unique colors
self._make_colorbar_colors()
# Show colorbar
# self._show_colorbar_mpl()

plot_data, plot_colors = self._generate_plot_data(packets_mode)
bin_edges = (self.velocity).to("km/s")
Expand Down Expand Up @@ -289,26 +303,6 @@ def generate_plot_mpl(

return self.ax

# def _show_colorbar_mpl(self):
# """Show matplotlib colorbar with labels of elements mapped to colors."""

# color_values = [
# self.cmap(species_counter / len(self._species_name))
# for species_counter in range(len(self._species_name))
# ]

# custcmap = clr.ListedColormap(color_values)
# norm = clr.Normalize(vmin=0, vmax=len(self._species_name))
# mappable = cm.ScalarMappable(norm=norm, cmap=custcmap)
# mappable.set_array(np.linspace(1, len(self._species_name) + 1, 256))
# cbar = plt.colorbar(mappable, ax=self.ax)

# bounds = np.arange(len(self._species_name)) + 0.5
# cbar.set_ticks(bounds)

# cbar.set_ticklabels(self._species_name)
# return

def generate_plot_ply(
self,
packets_mode="virtual",
Expand Down Expand Up @@ -383,66 +377,4 @@ def generate_plot_ply(
xaxis=dict(tickformat=".0f", tickmode="auto"),
)

# fig = self._show_colorbar_ply(fig)
return fig

# def _show_colorbar_ply(self, fig):
# """
# Show plotly colorbar with labels of elements mapped to colors.

# Parameters
# ----------
# fig : plotly.graph_objects.Figure
# Plotly figure object to add the colorbar to.

# Returns
# -------
# plotly.graph_objects.Figure
# Plotly figure object with the colorbar added.
# """
# # Interpolate [0, 1] range to create bins equal to number of elements
# colorscale_bins = np.linspace(0, 1, num=len(self._species_name) + 1)

# # Create a categorical colorscale [a list of (reference point, color)]
# # by mapping same reference points (excluding 1st and last bin edge)
# # twice in a row (https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale)
# categorical_colorscale = []
# for species_counter in range(len(self._species_name)):
# color = pu.to_rgb255_string(
# self.cmap(colorscale_bins[species_counter])
# )
# categorical_colorscale.append(
# (colorscale_bins[species_counter], color)
# )
# categorical_colorscale.append(
# (colorscale_bins[species_counter + 1], color)
# )

# coloraxis_options = {
# "colorscale": categorical_colorscale,
# "showscale": True,
# "cmin": 0,
# "cmax": len(self._species_name),
# "colorbar": {
# "title": "Elements",
# "tickvals": np.arange(0, len(self._species_name)) + 0.5,
# "ticktext": self._species_name,
# # to change length and position of colorbar
# "len": 1,
# "yanchor": "top",
# "y": 1,
# },
# }

# colorbar_trace = go.Scatter(
# x=[None],
# y=[0],
# mode="markers",
# name="Colorbar",
# showlegend=False,
# hoverinfo="skip",
# marker=dict(color=[0], opacity=0, **coloraxis_options),
# )

# fig.add_trace(colorbar_trace)
# return fig

0 comments on commit 7f55bcb

Please sign in to comment.