diff --git a/docs/io/output/vpacket_logging.rst b/docs/io/output/vpacket_logging.rst index 308ce6af409..12c9f0cb18d 100644 --- a/docs/io/output/vpacket_logging.rst +++ b/docs/io/output/vpacket_logging.rst @@ -42,6 +42,9 @@ After running the simulation, the following information can be retrieved: * - ``transport.virt_packet_last_interaction_in_nu`` - Numpy array - Frequencies of the r-packets which spawned the virtual packet + * - ``transport.virt_packet_last_interaction_in_r`` + - Numpy array + - Radii of the r-packets which spawned the virtual packet * - ``transport.virt_packet_last_line_interaction_in_id`` - Numpy array - | If the last interaction was a line interaction, the diff --git a/docs/io/visualization/how_to_liv_plot.ipynb b/docs/io/visualization/how_to_liv_plot.ipynb new file mode 100644 index 00000000000..795b9090277 --- /dev/null +++ b/docs/io/visualization/how_to_liv_plot.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to Generate a Last Interaction Velocity (LIV) Plot\n", + "The Last Interaction Velocity Plot tracks and display the velocities at which different elements (or species) last interacted with packets in the simulation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, create and run a simulation for which you want to generate this plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tardis import run_tardis\n", + "from tardis.io.atom_data.util import download_atom_data\n", + "\n", + "# We download the atomic data needed to run the simulation\n", + "download_atom_data('kurucz_cd23_chianti_H_He')\n", + "\n", + "sim = run_tardis(\"tardis_example.yml\", virtual_packet_logging=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "Note\n", + "\n", + "The virtual packet logging capability must be active in order to produce the Last Interaction Velocity Plot for virtual packets population. Thus, make sure to set `virtual_packet_logging: True` in your configuration file if you want to generate the Last Interaction Velocity Plot with virtual packets. It should be added under the `virtual` property of the `spectrum` property, as described in the [configuration schema](https://tardis-sn.github.io/tardis/io/configuration/components/spectrum.html).\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, import the plotting interface for Last Interaction Velocity Plot, i.e. the `LIVPlotter` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tardis.visualization.tools.liv_plot import LIVPlotter" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And create a plotter object to process the data of simulation object `sim` for generating the Last Interaction Velocity plot." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter = LIVPlotter.from_simulation(sim)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Static Plot (in matplotlib)\n", + "You can now call the `generate_plot_mpl()` method on your plotter object to create a highly informative and visually appealing Last Interaction Velocity plot using matplotlib." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Virtual packets mode\n", + "By default, a Last Interaction Velocity plot is produced for the virtual packet population of the simulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Real packets mode\n", + "You can produce a Last Interaction Velocity plot for the real packet population of the simulation by setting `packets_mode=\"real\"` which is `\"virtual\"` by default." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(packets_mode=\"real\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting only the top contributing elements" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `nelements` option allows you to plot the top contributing elements to the spectrum. Only the top elements are shown in the plot. Please note this works only for elements and not for ions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(nelements=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Choosing what elements/ions to plot\n", + "\n", + "You can pass a `species_list` for the species you want plotted in the Last Interaction Velocity Plot. Valid options include elements (e.g., Si), ions (specified in Roman numeral format, e.g., Si II), a range of ions (e.g., Si I-III), or any combination of these." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using both the `nelements` and the `species_list` options, `species_list` takes precedence. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"Ca\", \"S\"], nelements=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific number of bins\n", + "You can regroup the bins with broader or narrower widths within the same velocity range using `num_bins`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"], num_bins=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting on the Log Scale\n", + "You can plot on the log scale on x-axis using `xlog_scale=True` and on y-axis using `ylog_scale=True` by default both are set to `False` which plots on a linear scale." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"], xlog_scale=True, ylog_scale=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific velocity range\n", + "You can restrict the range of bins to plot in the Last Interaction Velocity Plot by specifying a valid `velocity_range`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_mpl(species_list = [\"Si I-III\", \"O\", \"Ca\", \"S\"], velocity_range=(12500, 15050))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional plotting options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To list all available options (or parameters) with their description\n", + "help(plotter.generate_plot_mpl)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `generate_plot_mpl` method also has options specific to the matplotlib API, thereby providing you with more control over how your last interaction velocity looks. Possible cases where you may use them are:\n", + "\n", + "- `ax`: To plot on an Axis of a plot you're already working with, e.g. for subplots.\n", + "\n", + "- `figsize`: To resize the plot as per your requirements.\n", + "\n", + "- `cmapname`: To use a colormap of your preference, instead of \"jet\"." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Interactive Plot (in plotly)\n", + "If you're using the Last Interaction Velocity plot for exploration, consider creating an interactive version with `generate_plot_ply()`. This allows you to zoom, pan, inspect data values by hovering, resize the scale, and more conveniently.\n", + "\n", + "\n", + "\n", + "**This method takes the same arguments as `generate_plot_mpl` except for a few specific to the Plotly library.** You can produce all the plots shown above in Plotly by passing the same arguments." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Virtual packets mode\n", + "By default, a Last Interaction Velocity plot is produced for the virtual packet population of the simulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Real packets mode\n", + "You can produce a Last Interaction Velocity plot for the real packet population of the simulation by setting `packets_mode=\"real\"` which is `\"virtual\"` by default." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(packets_mode=\"real\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting only the top contributing elements" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `nelements` option allows you to plot the top contributing elements to the spectrum. Only the top elements are shown in the plot. Please note this works only for elements and not for ions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(nelements=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Choosing what elements/ions to plot\n", + "\n", + "You can pass a `species_list` for the species you want plotted in the Last Interaction Velocity Plot. Valid options include elements (e.g., Si), ions (specified in Roman numeral format, e.g., Si II), a range of ions (e.g., Si I-III), or any combination of these." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using both the `nelements` and the `species_list` options, `species_list` takes precedence. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], nelements=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific number of bins\n", + "You can regroup the bins with broader and narrower widths within the same velocity range using `num_bins`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], num_bins=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting on the Log Scale\n", + "You can plot on the log scale on x-axis using `xlog_scale=True` and on y-axis using `ylog_scale=True` by default both are set to `False`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], xlog_scale=True, ylog_scale=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plotting a specific velocity range\n", + "You can restrict the range of bins to plot in the Last Interaction Velocity Plot by specifying a valid `velocity_range`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotter.generate_plot_ply(species_list = [\"Si I-III\", \"Ca\", \"S\"], velocity_range=(12500, 15050))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Additional plotting options" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To list all available options (or parameters) with their description\n", + "help(plotter.generate_plot_ply)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The `generate_plot_ply` method also has options specific to the plotly API, thereby providing you with more control over how your last interaction velocity plot looks. Possible cases where you may use them are:\n", + "\n", + " - `fig`: To plot the last interaction velocity plot on a figure you are already using e.g. for subplots.\n", + "\n", + " - `graph_height`: To specify the height of the graph as needed.\n", + " \n", + " - `cmapname`: To use a colormap of your preference instead of \"jet\"." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using simulation saved as HDF\n", + "Other than producing the Last Interaction Velocity Plot for simulation objects in runtime, you can also produce it for saved TARDIS simulations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# hdf_plotter = LIVPlotter.from_hdf(\"demo.h5\") ## Files is too large - just as an example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This `hdf_plotter` object is similar to the `plotter` object we used above, **so you can use each plotting method demonstrated above with this too.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Static plot with virtual packets mode\n", + "# hdf_plotter.generate_plot_mpl()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Static plot with real packets mode\n", + "#hdf_plotter.generate_plot_mpl(packets_mode=\"real\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive plot with virtual packets mode and specific list of species\n", + "# hdf_plotter.generate_plot_ply(species_list=[\"Si I-III\", \"Ca\", \"O\", \"S\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interactive plot with virtual packets mode and regrouped bins\n", + "# hdf_plotter.generate_plot_ply(num_bins=10)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tardis", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/io/visualization/index.rst b/docs/io/visualization/index.rst index b0ff58a6f3b..c98a32caae0 100644 --- a/docs/io/visualization/index.rst +++ b/docs/io/visualization/index.rst @@ -13,6 +13,7 @@ diagnostic visualizations. :maxdepth: 2 how_to_sdec_plot + how_to_liv_plot tutorial_convergence_plot tutorial_montecarlo_packet_visualization diff --git a/tardis/transport/montecarlo/base.py b/tardis/transport/montecarlo/base.py index 909ae64b19d..fc580ac2de6 100644 --- a/tardis/transport/montecarlo/base.py +++ b/tardis/transport/montecarlo/base.py @@ -192,6 +192,7 @@ def run( ] = v_packets_energy_hist transport_state.last_interaction_type = last_interaction_tracker.types transport_state.last_interaction_in_nu = last_interaction_tracker.in_nus + transport_state.last_interaction_in_r = last_interaction_tracker.in_rs transport_state.last_line_interaction_in_id = ( last_interaction_tracker.in_ids ) diff --git a/tardis/transport/montecarlo/montecarlo_transport_state.py b/tardis/transport/montecarlo/montecarlo_transport_state.py index 98fc5bf976c..a2fd0455190 100644 --- a/tardis/transport/montecarlo/montecarlo_transport_state.py +++ b/tardis/transport/montecarlo/montecarlo_transport_state.py @@ -28,6 +28,7 @@ class MonteCarloTransportState(HDFWriterMixin): "emitted_packet_mask", "last_interaction_type", "last_interaction_in_nu", + "last_interaction_in_r", "last_line_interaction_out_id", "last_line_interaction_in_id", "last_line_interaction_shell_id", @@ -39,6 +40,7 @@ class MonteCarloTransportState(HDFWriterMixin): "virt_packet_initial_rs", "virt_packet_initial_mus", "virt_packet_last_interaction_in_nu", + "virt_packet_last_interaction_in_r", "virt_packet_last_interaction_type", "virt_packet_last_line_interaction_in_id", "virt_packet_last_line_interaction_out_id", @@ -49,6 +51,7 @@ class MonteCarloTransportState(HDFWriterMixin): last_interaction_type = None last_interaction_in_nu = None + last_interaction_in_r = None last_line_interaction_out_id = None last_line_interaction_in_id = None last_line_interaction_shell_id = None @@ -399,6 +402,20 @@ def virt_packet_last_interaction_in_nu(self): ) return None + @property + def virt_packet_last_interaction_in_r(self): + try: + return u.Quantity(self.vpacket_tracker.last_interaction_in_r, u.cm) + except AttributeError: + warnings.warn( + "MontecarloTransport.virt_packet_last_interaction_in_r:" + "Set 'virtual_packet_logging: True' in the configuration file" + "to access this property" + "It should be added under 'virtual' property of 'spectrum' property", + UserWarning, + ) + return None + @property def virt_packet_last_interaction_type(self): try: diff --git a/tardis/transport/montecarlo/packet_collections.py b/tardis/transport/montecarlo/packet_collections.py index 9745726d7bb..dde92f6f4f5 100644 --- a/tardis/transport/montecarlo/packet_collections.py +++ b/tardis/transport/montecarlo/packet_collections.py @@ -54,10 +54,12 @@ def initialize_last_interaction_tracker(no_of_packets): ) last_interaction_types = -1 * np.ones(no_of_packets, dtype=np.int64) last_interaction_in_nus = np.zeros(no_of_packets, dtype=np.float64) + last_interaction_in_rs = np.zeros(no_of_packets, dtype=np.float64) return LastInteractionTracker( last_interaction_types, last_interaction_in_nus, + last_interaction_in_rs, last_line_interaction_in_ids, last_line_interaction_out_ids, last_line_interaction_shell_ids, @@ -67,6 +69,7 @@ def initialize_last_interaction_tracker(no_of_packets): last_interaction_tracker_spec = [ ("types", int64[:]), ("in_nus", float64[:]), + ("in_rs", float64[:]), ("in_ids", int64[:]), ("out_ids", int64[:]), ("shell_ids", int64[:]), @@ -79,12 +82,14 @@ def __init__( self, types, in_nus, + in_rs, in_ids, out_ids, shell_ids, ): self.types = types self.in_nus = in_nus + self.in_rs = in_rs self.in_ids = in_ids self.out_ids = out_ids self.shell_ids = shell_ids @@ -92,6 +97,7 @@ def __init__( def update_last_interaction(self, r_packet, i): self.types[i] = r_packet.last_interaction_type self.in_nus[i] = r_packet.last_interaction_in_nu + self.in_rs[i] = r_packet.last_interaction_in_r self.in_ids[i] = r_packet.last_line_interaction_in_id self.out_ids[i] = r_packet.last_line_interaction_out_id self.shell_ids[i] = r_packet.last_line_interaction_shell_id @@ -110,6 +116,7 @@ def update_last_interaction(self, r_packet, i): ("number_of_vpackets", int64), ("length", int64), ("last_interaction_in_nu", float64[:]), + ("last_interaction_in_r", float64[:]), ("last_interaction_type", int64[:]), ("last_interaction_in_id", int64[:]), ("last_interaction_out_id", int64[:]), @@ -139,6 +146,9 @@ def __init__( self.last_interaction_in_nu = np.zeros( temporary_v_packet_bins, dtype=np.float64 ) + self.last_interaction_in_r = np.zeros( + temporary_v_packet_bins, dtype=np.float64 + ) self.last_interaction_type = -1 * np.ones( temporary_v_packet_bins, dtype=np.int64 ) @@ -162,6 +172,7 @@ def add_packet( initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -182,6 +193,8 @@ def add_packet( Initial r of the packet. last_interaction_in_nu : float Frequency of the last interaction of the packet. + last_interaction_in_r : float + Radius of the last interaction of the packet. last_interaction_type : int Type of the last interaction of the packet. last_interaction_in_id : int @@ -205,6 +218,7 @@ def add_packet( temp_last_interaction_in_nu = np.empty( temp_length, dtype=np.float64 ) + temp_last_interaction_in_r = np.empty(temp_length, dtype=np.float64) temp_last_interaction_type = np.empty(temp_length, dtype=np.int64) temp_last_interaction_in_id = np.empty(temp_length, dtype=np.int64) temp_last_interaction_out_id = np.empty(temp_length, dtype=np.int64) @@ -219,6 +233,9 @@ def add_packet( temp_last_interaction_in_nu[ : self.length ] = self.last_interaction_in_nu + temp_last_interaction_in_r[ + : self.length + ] = self.last_interaction_in_r temp_last_interaction_type[ : self.length ] = self.last_interaction_type @@ -237,6 +254,7 @@ def add_packet( self.initial_mus = temp_initial_mus self.initial_rs = temp_initial_rs self.last_interaction_in_nu = temp_last_interaction_in_nu + self.last_interaction_in_r = temp_last_interaction_in_r self.last_interaction_type = temp_last_interaction_type self.last_interaction_in_id = temp_last_interaction_in_id self.last_interaction_out_id = temp_last_interaction_out_id @@ -248,6 +266,7 @@ def add_packet( self.initial_mus[self.idx] = initial_mu self.initial_rs[self.idx] = initial_r self.last_interaction_in_nu[self.idx] = last_interaction_in_nu + self.last_interaction_in_r[self.idx] = last_interaction_in_r self.last_interaction_type[self.idx] = last_interaction_type self.last_interaction_in_id[self.idx] = last_interaction_in_id self.last_interaction_out_id[self.idx] = last_interaction_out_id @@ -268,6 +287,7 @@ def finalize_arrays(self): self.initial_mus = self.initial_mus[: self.idx] self.initial_rs = self.initial_rs[: self.idx] self.last_interaction_in_nu = self.last_interaction_in_nu[: self.idx] + self.last_interaction_in_r = self.last_interaction_in_r[: self.idx] self.last_interaction_type = self.last_interaction_type[: self.idx] self.last_interaction_in_id = self.last_interaction_in_id[: self.idx] self.last_interaction_out_id = self.last_interaction_out_id[: self.idx] @@ -328,6 +348,9 @@ def consolidate_vpacket_tracker( vpacket_tracker.last_interaction_in_nu[ current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx ] = vpacket_collection.last_interaction_in_nu + vpacket_tracker.last_interaction_in_r[ + current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx + ] = vpacket_collection.last_interaction_in_r vpacket_tracker.last_interaction_type[ current_start_vpacket_tracker_idx:current_end_vpacket_tracker_idx diff --git a/tardis/transport/montecarlo/r_packet.py b/tardis/transport/montecarlo/r_packet.py index 85d6bfcd87c..038468afcac 100644 --- a/tardis/transport/montecarlo/r_packet.py +++ b/tardis/transport/montecarlo/r_packet.py @@ -40,6 +40,7 @@ class PacketStatus(IntEnum): ("index", int64), ("last_interaction_type", int64), ("last_interaction_in_nu", float64), + ("last_interaction_in_r", float64), ("last_line_interaction_in_id", int64), ("last_line_interaction_out_id", int64), ("last_line_interaction_shell_id", int64), @@ -59,6 +60,7 @@ def __init__(self, r, mu, nu, energy, seed, index=0): self.index = index self.last_interaction_type = -1 self.last_interaction_in_nu = 0.0 + self.last_interaction_in_r = 0.0 self.last_line_interaction_in_id = -1 self.last_line_interaction_out_id = -1 self.last_line_interaction_shell_id = -1 diff --git a/tardis/transport/montecarlo/r_packet_transport.py b/tardis/transport/montecarlo/r_packet_transport.py index c33b0e21be5..a7a9aa01cbd 100644 --- a/tardis/transport/montecarlo/r_packet_transport.py +++ b/tardis/transport/montecarlo/r_packet_transport.py @@ -140,6 +140,7 @@ def trace_packet( if tau_trace_combined > tau_event and not disable_line_scattering: interaction_type = InteractionType.LINE # Line r_packet.last_interaction_in_nu = r_packet.nu + r_packet.last_interaction_in_r = r_packet.r r_packet.last_line_interaction_in_id = cur_line_id r_packet.last_line_interaction_shell_id = r_packet.current_shell_id r_packet.next_line_id = cur_line_id diff --git a/tardis/transport/montecarlo/tests/test_base.py b/tardis/transport/montecarlo/tests/test_base.py index 9c7ec3bffbe..3e0e0e450c8 100644 --- a/tardis/transport/montecarlo/tests/test_base.py +++ b/tardis/transport/montecarlo/tests/test_base.py @@ -53,6 +53,7 @@ def test_hdf_transport( "emitted_packet_mask", "last_interaction_type", "last_interaction_in_nu", + "last_interaction_in_r", "last_line_interaction_out_id", "last_line_interaction_in_id", "last_line_interaction_shell_id", @@ -61,6 +62,7 @@ def test_hdf_transport( "virt_packet_initial_rs", "virt_packet_initial_mus", "virt_packet_last_interaction_in_nu", + "virt_packet_last_interaction_in_r", "virt_packet_last_interaction_type", "virt_packet_last_line_interaction_in_id", "virt_packet_last_line_interaction_out_id", diff --git a/tardis/transport/montecarlo/tests/test_numba_interface.py b/tardis/transport/montecarlo/tests/test_numba_interface.py index c3d5a3f6bc4..25d907049f6 100644 --- a/tardis/transport/montecarlo/tests/test_numba_interface.py +++ b/tardis/transport/montecarlo/tests/test_numba_interface.py @@ -71,6 +71,9 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): last_interaction_in_nus = np.array( [3.0e15, 0.0, 1e15, 1e5], dtype=np.float64 ) + last_interaction_in_rs = np.array( + [3e42, 4.5e45, 0, 9.0e40], dtype=np.float64 + ) last_interaction_types = np.array([1, 1, 3, 2], dtype=np.int64) last_interaction_in_ids = np.array([100, 0, 1, 1000], dtype=np.int64) last_interaction_out_ids = np.array([1201, 123, 545, 1232], dtype=np.int64) @@ -82,6 +85,7 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -92,6 +96,7 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): initial_mus, initial_rs, last_interaction_in_nus, + last_interaction_in_rs, last_interaction_types, last_interaction_in_ids, last_interaction_out_ids, @@ -103,6 +108,7 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): initial_mu, initial_r, last_interaction_in_nu, + last_interaction_in_r, last_interaction_type, last_interaction_in_id, last_interaction_out_id, @@ -139,6 +145,12 @@ def test_VPacketCollection_add_packet(verysimple_3vpacket_collection): ], last_interaction_in_nus, ) + npt.assert_array_equal( + verysimple_3vpacket_collection.last_interaction_in_r[ + : verysimple_3vpacket_collection.idx + ], + last_interaction_in_rs, + ) npt.assert_array_equal( verysimple_3vpacket_collection.last_interaction_type[ : verysimple_3vpacket_collection.idx diff --git a/tardis/transport/montecarlo/vpacket.py b/tardis/transport/montecarlo/vpacket.py index b4ce47ff2cc..657e1c016d0 100644 --- a/tardis/transport/montecarlo/vpacket.py +++ b/tardis/transport/montecarlo/vpacket.py @@ -359,6 +359,7 @@ def trace_vpacket_volley( v_packet_mu, r_packet.r, r_packet.last_interaction_in_nu, + r_packet.last_interaction_in_r, r_packet.last_interaction_type, r_packet.last_line_interaction_in_id, r_packet.last_line_interaction_out_id, diff --git a/tardis/visualization/__init__.py b/tardis/visualization/__init__.py index 73ccae5ce78..cadfff0cdc3 100644 --- a/tardis/visualization/__init__.py +++ b/tardis/visualization/__init__.py @@ -11,3 +11,4 @@ from tardis.visualization.widgets.custom_abundance import CustomAbundanceWidget from tardis.visualization.tools.sdec_plot import SDECPlotter from tardis.visualization.tools.rpacket_plot import RPacketPlotter +from tardis.visualization.tools.liv_plot import LIVPlotter diff --git a/tardis/visualization/plot_util.py b/tardis/visualization/plot_util.py index 7d3b81186f1..54b740411dc 100644 --- a/tardis/visualization/plot_util.py +++ b/tardis/visualization/plot_util.py @@ -60,3 +60,22 @@ def get_mid_point_idx(arr): """ mid_value = (arr[0] + arr[-1]) / 2 return np.abs(arr - mid_value).argmin() + + +def to_rgb255_string(color_tuple): + """ + Convert a matplotlib RGBA tuple to a generic RGB 255 string. + + Parameters + ---------- + color_tuple : tuple + Matplotlib RGBA tuple of float values in closed interval [0, 1] + + Returns + ------- + str + RGB string of format rgb(r,g,b) where r,g,b are integers between + 0 and 255 (both inclusive) + """ + color_tuple_255 = tuple([int(x * 255) for x in color_tuple[:3]]) + return f"rgb{color_tuple_255}" diff --git a/tardis/visualization/tools/liv_plot.py b/tardis/visualization/tools/liv_plot.py new file mode 100644 index 00000000000..eef3a4e0cf4 --- /dev/null +++ b/tardis/visualization/tools/liv_plot.py @@ -0,0 +1,532 @@ +import logging +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import plotly.graph_objects as go +import numpy as np +import pandas as pd +import astropy.units as u + +from tardis.util.base import ( + atomic_number2element_symbol, + int_to_roman, +) +import tardis.visualization.tools.sdec_plot as sdec +from tardis.visualization import plot_util as pu + +logger = logging.getLogger(__name__) + + +class LIVPlotter: + """ + Plotting interface for the last interaction velocity plot. + """ + + def __init__(self, data, time_explosion, velocity): + """ + Initialize the plotter with required data from the simulation. + + Parameters + ---------- + data : dict of SDECData + Dictionary to store data required for last interaction velocity plot, + for both packet modes (real, virtual). + + time_explosion : astropy.units.Quantity + Time of the explosion. + + velocity : astropy.units.Quantity + Velocity array from the simulation. + """ + + self.data = data + self.time_explosion = time_explosion + self.velocity = velocity + self.sdec_plotter = sdec.SDECPlotter(data) + + @classmethod + def from_simulation(cls, sim): + """ + Create an instance of the plotter from a TARDIS simulation object. + + Parameters + ---------- + sim : tardis.simulation.Simulation + TARDIS simulation object produced by running a simulation. + + Returns + ------- + LIVPlotter + """ + + return cls( + dict( + virtual=sdec.SDECData.from_simulation(sim, "virtual"), + real=sdec.SDECData.from_simulation(sim, "real"), + ), + sim.plasma.time_explosion, + sim.simulation_state.velocity, + ) + + @classmethod + def from_hdf(cls, hdf_fpath): + """ + Create an instance of the Plotter from a simulation HDF file. + + Parameters + ---------- + hdf_fpath : str + Valid path to the HDF file where simulation is saved. + + Returns + ------- + LIVPlotter + """ + with pd.HDFStore(hdf_fpath, "r") as hdf: + time_explosion = ( + hdf["/simulation/plasma/scalars"]["time_explosion"] * u.s + ) + v_inner = hdf["/simulation/simulation_state/v_inner"] * (u.cm / u.s) + v_outer = hdf["/simulation/simulation_state/v_outer"] * (u.cm / u.s) + velocity = pd.concat( + [v_inner, pd.Series([v_outer.iloc[-1]])], ignore_index=True + ).tolist() * (u.cm / u.s) + return cls( + dict( + virtual=sdec.SDECData.from_hdf(hdf_fpath, "virtual"), + real=sdec.SDECData.from_hdf(hdf_fpath, "real"), + ), + time_explosion, + velocity, + ) + + def _parse_species_list(self, species_list, packets_mode, nelements=None): + """ + Parse user requested species list and create list of species ids to be used. + + Parameters + ---------- + species_list : list of species to plot + List of species (e.g. Si II, Ca II, etc.) that the user wants to show as unique colours. + Species can be given as an ion (e.g. Si II), an element (e.g. Si), a range of ions + (e.g. Si I - V), or any combination of these (e.g. species_list = [Si II, Fe I-V, Ca]) + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + + Raises + ------ + ValueError + If species list contains invalid entries. + + """ + self.sdec_plotter._parse_species_list(species_list) + self._species_list = self.sdec_plotter._species_list + self._species_mapped = self.sdec_plotter._species_mapped + self._keep_colour = self.sdec_plotter._keep_colour + + if nelements: + interaction_counts = ( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .value_counts() + ) + interaction_counts.index = interaction_counts.index // 100 + element_counts = interaction_counts.groupby( + interaction_counts.index + ).sum() + top_elements = element_counts.nlargest(nelements).index + top_species_list = [ + atomic_number2element_symbol(element) + for element in top_elements + ] + self._parse_species_list(top_species_list, packets_mode) + + def _make_colorbar_labels(self): + """ + Generate labels for the colorbar based on species. + + If a species list is provided, uses that to generate labels. + Otherwise, generates labels from the species in the model. + """ + if self._species_list is None: + species_name = [ + atomic_number2element_symbol(atomic_num) + for atomic_num in self.species + ] + else: + species_name = [] + for species_key, species_ids in self._species_mapped.items(): + if any(species in self.species for species in species_ids): + if species_key % 100 == 0: + label = atomic_number2element_symbol(species_key // 100) + else: + atomic_number = species_key // 100 + ion_number = species_key % 100 + ion_numeral = int_to_roman(ion_number + 1) + label = f"{atomic_number2element_symbol(atomic_number)} {ion_numeral}" + species_name.append(label) + + self._species_name = species_name + + def _make_colorbar_colors(self): + """ + Generate colors for the species to be plotted. + + This method creates a list of colors corresponding to the species names. + The colors are generated based on the species present in the model and + the requested species list. + """ + color_list = [] + species_keys = list(self._species_mapped.keys()) + num_species = len(species_keys) + + for species_counter, species_key in enumerate(species_keys): + if any( + species in self.species + for species in self._species_mapped[species_key] + ): + color = self.cmap(species_counter / num_species) + color_list.append(color) + + self._color_list = color_list + + def _generate_plot_data(self, packets_mode): + """ + Generate plot data and colors for species in the model. + + Parameters + ---------- + packets_mode : str + Packet mode, either 'virtual' or 'real'. + + Returns + ------- + plot_data : list + List of velocity data for each species. + + plot_colors : list + List of colors corresponding to each species. + """ + groups = self.data[packets_mode].packets_df_line_interaction.groupby( + by="last_line_interaction_species" + ) + + plot_colors = [] + plot_data = [] + species_counter = 0 + + for specie_list in self._species_mapped.values(): + full_v_last = [] + for specie in specie_list: + if specie in self.species: + g_df = groups.get_group(specie) + r_last_interaction = ( + g_df["last_interaction_in_r"].values * u.cm + ) + v_last_interaction = ( + r_last_interaction / self.time_explosion + ).to("km/s") + full_v_last.extend(v_last_interaction) + if full_v_last: + plot_data.append(full_v_last) + plot_colors.append(self._color_list[species_counter]) + species_counter += 1 + + return plot_data, plot_colors + + def _prepare_plot_data( + self, packets_mode, species_list, cmapname, num_bins, nelements + ): + """ + Prepare data and settings required for generating a plot. + + This method handles the common logic for preparing data and settings + needed to generate both matplotlib and plotly plots. It parses the species + list, generates color labels and colormap, and bins the velocity data. + + Parameters + ---------- + packets_mode : str + Packet mode, either 'virtual' or 'real'. + species_list : list of str + List of species to plot. Species can be specified as an ion + (e.g., Si II), an element (e.g., Si), a range of ions (e.g., Si I-V), + or any combination of these. + cmapname : str + Name of the colormap to use. A specific colormap can be chosen, such + as "jet", "viridis", "plasma", etc. + num_bins : int, optional + Number of bins for regrouping within the same range. If None, + no regrouping is done. + + Raises + ------ + ValueError + If no species are provided for plotting, or if no valid species are + found in the model. + + Returns + ------- + plot_data : list + List of velocity data for each species. + plot_colors : list + List of colors corresponding to each species. + new_bin_edges : np.ndarray + Array of bin edges for the velocity data. + """ + if species_list is None: + # Extract all unique elements from the packets data + species_in_model = np.unique( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .values + ) + species_list = [ + f"{atomic_number2element_symbol(specie // 100)}" + for specie in species_in_model + ] + self._parse_species_list(species_list, packets_mode, nelements) + species_in_model = np.unique( + self.data[packets_mode] + .packets_df_line_interaction["last_line_interaction_species"] + .values + ) + if self._species_list is None or not self._species_list: + raise ValueError("No species provided for plotting.") + msk = np.isin(self._species_list, species_in_model) + self.species = np.array(self._species_list)[msk] + + if len(self.species) == 0: + raise ValueError("No valid species found for plotting.") + + self._make_colorbar_labels() + self.cmap = cm.get_cmap(cmapname, len(self._species_name)) + self._make_colorbar_colors() + plot_data, plot_colors = self._generate_plot_data(packets_mode) + bin_edges = (self.velocity).to("km/s") + + if num_bins: + if num_bins < 1: + raise ValueError("Number of bins must be positive") + elif num_bins > len(bin_edges) - 1: + logger.warn( + "Number of bins must be less than or equal to number of shells. Plotting with number of bins equals to number of shells." + ) + new_bin_edges = bin_edges + else: + new_bin_edges = np.linspace( + bin_edges[0], bin_edges[-1], num_bins + 1 + ) + else: + new_bin_edges = bin_edges + + return plot_data, plot_colors, new_bin_edges + + def _get_step_plot_data(self, data, bin_edges): + """ + Generate step plot data from histogram data. + + Parameters + ---------- + data : array-like + Data to be binned into a histogram. + bin_edges : array-like + Edges of the bins for the histogram. + + Returns + ------- + step_x : np.ndarray + x-coordinates for the step plot. + step_y : np.ndarray + y-coordinates for the step plot. + """ + hist, _ = np.histogram(data, bins=bin_edges) + step_x = np.repeat(bin_edges, 2)[1:-1] + step_y = np.repeat(hist, 2) + return step_x, step_y + + def generate_plot_mpl( + self, + species_list=None, + nelements=None, + packets_mode="virtual", + ax=None, + figsize=(11, 5), + cmapname="jet", + xlog_scale=False, + ylog_scale=False, + num_bins=None, + velocity_range=None, + ): + """ + Generate the last interaction velocity distribution plot using matplotlib. + + Parameters + ---------- + species_list : list of str, optional + List of species to plot. Default is None which plots all species in the model. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + ax : matplotlib.axes.Axes, optional + Axes object to plot on. If None, creates a new figure. + figsize : tuple, optional + Size of the figure. Default is (11, 5). + cmapname : str, optional + Colormap name. Default is 'jet'. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. + xlog_scale : bool, optional + If True, x-axis is scaled logarithmically. Default is False. + ylog_scale : bool, optional + If True, y-axis is scaled logarithmically. Default is False. + num_bins : int, optional + Number of bins for regrouping within the same range. Default is None. + velocity_range : tuple, optional + Limits for the x-axis. If specified, overrides any automatically determined limits. + + Returns + ------- + matplotlib.axes.Axes + Axes object with the plot. + """ + # If species_list and nelements requested, tell user that nelements is ignored + if species_list is not None and nelements is not None: + logger.info( + "Both nelements and species_list were requested. Species_list takes priority; nelements is ignored" + ) + nelements = None + + plot_data, plot_colors, bin_edges = self._prepare_plot_data( + packets_mode, species_list, cmapname, num_bins, nelements + ) + + if ax is None: + self.ax = plt.figure(figsize=figsize).add_subplot(111) + else: + self.ax = ax + + for data, color, name in zip( + plot_data, plot_colors, self._species_name + ): + step_x, step_y = self._get_step_plot_data(data, bin_edges) + self.ax.plot( + step_x, + step_y, + label=name, + color=color, + linewidth=2.5, + drawstyle="steps-post", + alpha=0.75, + ) + + self.ax.ticklabel_format(axis="y", scilimits=(0, 0)) + self.ax.tick_params("both", labelsize=15) + self.ax.set_xlabel("Last Interaction Velocity (km/s)", fontsize=14) + self.ax.set_ylabel("Packet Count", fontsize=15) + self.ax.legend(fontsize=15, bbox_to_anchor=(1.0, 1.0), loc="upper left") + self.ax.figure.tight_layout() + if xlog_scale: + self.ax.set_xscale("log") + if ylog_scale: + self.ax.set_yscale("log") + if velocity_range: + self.ax.set_xlim(velocity_range[0], velocity_range[1]) + + return self.ax + + def generate_plot_ply( + self, + species_list=None, + nelements=None, + packets_mode="virtual", + fig=None, + graph_height=600, + cmapname="jet", + xlog_scale=False, + ylog_scale=False, + num_bins=None, + velocity_range=None, + ): + """ + Generate the last interaction velocity distribution plot using plotly. + + Parameters + ---------- + species_list : list of str, optional + List of species to plot. Default is None which plots all species in the model. + nelements : int, optional + Number of elements to include in plot. The most interacting elements are included. If None, displays all elements. + packets_mode : str, optional + Packet mode, either 'virtual' or 'real'. Default is 'virtual'. + fig : plotly.graph_objects.Figure, optional + Plotly figure object to add the plot to. If None, creates a new figure. + graph_height : int, optional + Height (in px) of the plotly graph to display. Default value is 600. + cmapname : str, optional + Colormap name. Default is 'jet'. A specific colormap can be chosen, such as "jet", "viridis", "plasma", etc. + xlog_scale : bool, optional + If True, x-axis is scaled logarithmically. Default is False. + ylog_scale : bool, optional + If True, y-axis is scaled logarithmically. Default is False. + num_bins : int, optional + Number of bins for regrouping within the same range. Default is None. + velocity_range : tuple, optional + Limits for the x-axis. If specified, overrides any automatically determined limits. + + Returns + ------- + plotly.graph_objects.Figure + Plotly figure object with the plot. + """ + # If species_list and nelements requested, tell user that nelements is ignored + if species_list is not None and nelements is not None: + logger.info( + "Both nelements and species_list were requested. Species_list takes priority; nelements is ignored" + ) + nelements = None + + plot_data, plot_colors, bin_edges = self._prepare_plot_data( + packets_mode, species_list, cmapname, num_bins, nelements + ) + + if fig is None: + self.fig = go.Figure() + else: + self.fig = fig + + for data, color, name in zip( + plot_data, plot_colors, self._species_name + ): + step_x, step_y = self._get_step_plot_data(data, bin_edges) + self.fig.add_trace( + go.Scatter( + x=step_x, + y=step_y, + mode="lines", + line=dict( + color=pu.to_rgb255_string(color), + width=2.5, + shape="hv", + ), + name=name, + opacity=0.75, + ) + ) + self.fig.update_layout( + height=graph_height, + xaxis_title="Last Interaction Velocity (km/s)", + yaxis_title="Packet Count", + font=dict(size=15), + yaxis=dict(exponentformat="power" if ylog_scale else "e"), + xaxis=dict(exponentformat="power" if xlog_scale else "none"), + ) + if xlog_scale: + self.fig.update_xaxes(type="log") + if ylog_scale: + self.fig.update_yaxes(type="log", dtick=1) + + if velocity_range: + self.fig.update_xaxes(range=velocity_range) + + return self.fig diff --git a/tardis/visualization/tools/sdec_plot.py b/tardis/visualization/tools/sdec_plot.py index 75b01aa4f39..3a68a19ae4a 100644 --- a/tardis/visualization/tools/sdec_plot.py +++ b/tardis/visualization/tools/sdec_plot.py @@ -4,6 +4,7 @@ This plot is a spectral diagnostics plot similar to those originally proposed by M. Kromer (see, for example, Kromer et al. 2013, figure 4). """ + import logging import astropy.units as u @@ -40,6 +41,7 @@ def __init__( last_line_interaction_in_id, last_line_interaction_out_id, last_line_interaction_in_nu, + last_interaction_in_r, lines_df, packet_nus, packet_energies, @@ -67,6 +69,8 @@ def __init__( emission (interaction out) last_line_interaction_in_nu : np.array Frequency values of the last absorption of emitted packets + last_line_interaction_in_r : np.array + Radius of the last interaction experienced by emitted packets lines_df : pd.DataFrame Data about the atomic lines present in simulation model's plasma packet_nus : astropy.Quantity @@ -98,6 +102,7 @@ def __init__( "last_line_interaction_out_id": last_line_interaction_out_id, "last_line_interaction_in_id": last_line_interaction_in_id, "last_line_interaction_in_nu": last_line_interaction_in_nu, + "last_interaction_in_r": last_interaction_in_r, } ) @@ -177,6 +182,7 @@ def from_simulation(cls, sim, packets_mode): last_line_interaction_in_id=transport_state.vpacket_tracker.last_interaction_in_id, last_line_interaction_out_id=transport_state.vpacket_tracker.last_interaction_out_id, last_line_interaction_in_nu=transport_state.vpacket_tracker.last_interaction_in_nu, + last_interaction_in_r=transport_state.vpacket_tracker.last_interaction_in_r, lines_df=lines_df, packet_nus=u.Quantity( transport_state.vpacket_tracker.nus, "Hz" @@ -210,6 +216,9 @@ def from_simulation(cls, sim, packets_mode): last_line_interaction_in_nu=transport_state.last_interaction_in_nu[ transport_state.emitted_packet_mask ], + last_interaction_in_r=transport_state.last_interaction_in_r[ + transport_state.emitted_packet_mask + ], lines_df=lines_df, packet_nus=transport_state.packet_collection.output_nus[ transport_state.emitted_packet_mask @@ -283,6 +292,12 @@ def from_hdf(cls, hdf_fpath, packets_mode): ].to_numpy(), "Hz", ), + last_interaction_in_r=u.Quantity( + hdf[ + "/simulation/transport/transport_state/virt_packet_last_interaction_in_r" + ].to_numpy(), + "cm", + ), lines_df=lines_df, packet_nus=u.Quantity( hdf[ @@ -347,6 +362,12 @@ def from_hdf(cls, hdf_fpath, packets_mode): ].to_numpy()[emitted_packet_mask], "Hz", ), + last_interaction_in_r=u.Quantity( + hdf[ + "/simulation/transport/transport_state/last_interaction_in_r" + ].to_numpy()[emitted_packet_mask], + "cm", + ), lines_df=lines_df, packet_nus=u.Quantity( hdf[ @@ -509,6 +530,7 @@ def _parse_species_list(self, species_list): ) else: full_species_list = [] + species_mapped = {} for species in species_list: # check if a hyphen is present. If it is, then it indicates a # range of ions. Add each ion in that range to the list as a new entry @@ -546,20 +568,20 @@ def _parse_species_list(self, species_list): # the requested ion for species in full_species_list: if " " in species: - requested_species_ids.append( - [ - species_string_to_tuple(species)[0] * 100 - + species_string_to_tuple(species)[1] - ] + species_id = ( + species_string_to_tuple(species)[0] * 100 + + species_string_to_tuple(species)[1] ) + requested_species_ids.append([species_id]) + species_mapped[species_id] = [species_id] else: atomic_number = element_symbol2atomic_number(species) - requested_species_ids.append( - [ - atomic_number * 100 + ion_number - for ion_number in np.arange(atomic_number) - ] - ) + species_ids = [ + atomic_number * 100 + ion_number + for ion_number in np.arange(atomic_number) + ] + requested_species_ids.append(species_ids) + species_mapped[atomic_number * 100] = species_ids # add the atomic number to a list so you know that this element should # have all species in the same colour, i.e. it was requested like # species_list = [Si] @@ -570,6 +592,7 @@ def _parse_species_list(self, species_list): for species_id in temp_list ] + self._species_mapped = species_mapped self._species_list = requested_species_ids self._keep_colour = keep_colour else: @@ -1692,25 +1715,6 @@ def generate_plot_ply( return self.fig - @staticmethod - def to_rgb255_string(color_tuple): - """ - Convert a matplotlib RGBA tuple to a generic RGB 255 string. - - Parameters - ---------- - color_tuple : tuple - Matplotlib RGBA tuple of float values in closed interval [0, 1] - - Returns - ------- - str - RGB string of format rgb(r,g,b) where r,g,b are integers between - 0 and 255 (both inclusive) - """ - color_tuple_255 = tuple([int(x * 255) for x in color_tuple[:3]]) - return f"rgb{color_tuple_255}" - def _plot_emission_ply(self): """Plot emission part of the SDEC Plot using plotly.""" # By specifying a common stackgroup, plotly will itself add up @@ -1767,7 +1771,7 @@ def _plot_emission_ply(self): name=species_name + " Emission", hovertemplate=f"{species_name:s} Emission
" # noqa: ISC003 + "(%{x:.2f}, %{y:.3g})", - fillcolor=self.to_rgb255_string( + fillcolor=pu.to_rgb255_string( self._color_list[species_counter] ), stackgroup="emission", @@ -1826,7 +1830,7 @@ def _plot_absorption_ply(self): name=species_name + " Absorption", hovertemplate=f"{species_name:s} Absorption
" # noqa: ISC003 + "(%{x:.2f}, %{y:.3g})", - fillcolor=self.to_rgb255_string( + fillcolor=pu.to_rgb255_string( self._color_list[species_counter] ), stackgroup="absorption", @@ -1865,7 +1869,7 @@ def _show_colorbar_ply(self): # twice in a row (https://plotly.com/python/colorscales/#constructing-a-discrete-or-discontinuous-color-scale) categorical_colorscale = [] for species_counter in range(len(self._species_name)): - color = self.to_rgb255_string( + color = pu.to_rgb255_string( self.cmap(colorscale_bins[species_counter]) ) categorical_colorscale.append(