Skip to content

Commit

Permalink
Merge pull request #22 from fusion-energy/main_function
Browse files Browse the repository at this point in the history
Main function and sidebar
  • Loading branch information
shimwell authored May 18, 2023
2 parents 2196db1 + ba1bdd6 commit f8fff5b
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 55 deletions.
1 change: 1 addition & 0 deletions src/openmc_geometry_plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
__all__ = ["__version__"]

from .core import *
from .app import *
96 changes: 47 additions & 49 deletions src/openmc_geometry_plot/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
import openmc
import streamlit as st
from matplotlib import colors
from pylab import cm, colormaps # *
from pylab import cm, colormaps
import numpy as np

# import dagmc_h5m_file_inspector as di

import openmc_geometry_plot # adds extra functions to openmc.Geometry


Expand Down Expand Up @@ -59,7 +57,6 @@ def header():


def main():
header()

file_label_col1, file_label_col2 = st.columns([1, 1])
file_label_col1.write(
Expand Down Expand Up @@ -99,7 +96,6 @@ def main():

# DAGMC route
elif dagmc_file is not None and geometry_xml_file is not None:

save_uploadedfile(dagmc_file)
save_uploadedfile(geometry_xml_file)

Expand All @@ -125,7 +121,6 @@ def main():
set_cell_names = set(all_cell_names)

elif dagmc_file is not None and geometry_xml_file is None:

save_uploadedfile(dagmc_file)

# make a basic openmc geometry
Expand All @@ -152,6 +147,7 @@ def main():

# CSG route
elif dagmc_file is None and geometry_xml_file is not None:

save_uploadedfile(geometry_xml_file)

tree = ET.parse(geometry_xml_file.name)
Expand Down Expand Up @@ -199,30 +195,28 @@ def main():
print("geometry is set to something so attempting to plot")
bb = my_geometry.bounding_box

col1, col2 = st.columns([1, 3])

view_direction = col1.selectbox(
view_direction = st.sidebar.selectbox(
label="View direction",
options=("z", "x", "y"),
index=0,
key="geometry_view_direction",
help="Setting the direction of view automatically sets the horizontal and vertical axis used for the plot.",
)
backend = col1.selectbox(
backend = st.sidebar.selectbox(
label="Ploting backend",
options=("matplotlib", "plotly"),
index=0,
key="geometry_ploting_backend",
help="Create png images with MatPlotLib or HTML plots with Plotly",
)
outline = col1.selectbox(
outline = st.sidebar.selectbox(
label="Outline",
options=("materials", "cells", None),
index=0,
key="outline",
help="Allows an outline to be drawn around the cells or materials, select None for no outline",
)
color_by = col1.selectbox(
color_by = st.sidebar.selectbox(
label="Color by",
options=("materials", "cells"),
index=0,
Expand All @@ -239,30 +233,42 @@ def main():
slice_index = {"z": 2, "y": 1, "x": 0}[view_direction]

if np.isinf(bb[0][x_index]) or np.isinf(bb[1][x_index]):
x_min = col1.number_input(label="minimum vertical axis value", key="x_min")
x_max = col1.number_input(label="maximum vertical axis value", key="x_max")
x_min = st.sidebar.number_input(
label="minimum vertical axis value", key="x_min"
)
x_max = st.sidebar.number_input(
label="maximum vertical axis value", key="x_max"
)
else:
x_min = float(bb[0][x_index])
x_max = float(bb[1][x_index])

# y axis is y values
if np.isinf(bb[0][y_index]) or np.isinf(bb[1][y_index]):
y_min = col1.number_input(label="minimum vertical axis value", key="y_min")
y_max = col1.number_input(label="maximum vertical axis value", key="y_max")
y_min = st.sidebar.number_input(
label="minimum vertical axis value", key="y_min"
)
y_max = st.sidebar.number_input(
label="maximum vertical axis value", key="y_max"
)
else:
y_min = float(bb[0][y_index])
y_max = float(bb[1][y_index])

# slice axis is z
if np.isinf(bb[0][slice_index]) or np.isinf(bb[1][slice_index]):
slice_min = col1.number_input(label="minimum slice value", key="slice_min")
slice_max = col1.number_input(label="maximum slice value", key="slice_max")
slice_min = st.sidebar.number_input(
label="minimum slice value", key="slice_min"
)
slice_max = st.sidebar.number_input(
label="maximum slice value", key="slice_max"
)
else:
slice_min = float(bb[0][slice_index])
slice_max = float(bb[1][slice_index])

if isinstance(x_min, float) and isinstance(x_max, float):
plot_right, plot_left = col1.slider(
plot_right, plot_left = st.sidebar.slider(
label="Left and right values for the horizontal axis",
min_value=x_min,
max_value=x_max,
Expand All @@ -272,7 +278,7 @@ def main():
)

if isinstance(y_min, float) and isinstance(y_max, float):
plot_bottom, plot_top = col1.slider(
plot_bottom, plot_top = st.sidebar.slider(
label="Bottom and top values for the vertical axis",
min_value=y_min,
max_value=y_max,
Expand All @@ -281,7 +287,7 @@ def main():
help="Set the lowest visible value and highest visible value on the vertical axis",
)
if isinstance(slice_min, float) and isinstance(slice_max, float):
slice_value = col1.slider(
slice_value = st.sidebar.slider(
label="Slice value",
min_value=slice_min,
max_value=slice_max,
Expand All @@ -290,18 +296,17 @@ def main():
help="Set the value of the slice axis",
)

pixels_across = col1.number_input(
pixels_across = st.sidebar.number_input(
label="Number of horizontal pixels",
value=500,
help="Increasing this value increases the image resolution but also requires longer to create the image",
)

selected_color_map = col1.selectbox(
selected_color_map = st.sidebar.selectbox(
label="Color map", options=colormaps(), index=82
) # index 81 is tab20c
) # index 82 is tab20c

if color_by == "materials":

cmap = cm.get_cmap(selected_color_map, len(set_mat_ids))
initial_hex_color = []
for i in range(cmap.N):
Expand All @@ -310,8 +315,7 @@ def main():
initial_hex_color.append(colors.rgb2hex(rgba))

for c, id in enumerate(set_mat_ids):
# todo add
st.color_picker(
st.sidebar.color_picker(
f"Color of material with id {id}",
key=f"mat_{id}",
value=initial_hex_color[c],
Expand All @@ -334,7 +338,7 @@ def main():
for c, (cell_id, cell_name) in enumerate(zip(set_cell_ids, all_cell_names)):
if cell_name in ["", None]:
cell_name = "not set"
st.color_picker(
st.sidebar.color_picker(
f"Color of cell id {cell_id}, cell name {cell_name}",
key=f"cell_{cell_id}",
value=initial_hex_color[c],
Expand All @@ -346,7 +350,7 @@ def main():
RGB = tuple(int(hex_color[i : i + 2], 16) / 255 for i in (0, 2, 4))
my_colors[id] = RGB

title = col1.text_input(
title = st.sidebar.text_input(
"Plot title",
help="Optionally set your own title for the plot",
value=f"Slice through OpenMC geometry with view direction {view_direction}",
Expand Down Expand Up @@ -413,7 +417,6 @@ def main():
)

if backend == "matplotlib":

extent = my_geometry.get_plot_extent(
plot_left,
plot_right,
Expand Down Expand Up @@ -463,18 +466,17 @@ def main():
)

plt.savefig("openmc_plot_geometry_image.png")
col2.pyplot(plt)
# col2.image("openmc_plot_geometry_image.png", use_column_width="always")
st.pyplot(plt)
# st.image("openmc_plot_geometry_image.png", use_column_width="always")

with open("openmc_plot_geometry_image.png", "rb") as file:
col1.download_button(
st.sidebar.download_button(
label="Download image",
data=file,
file_name="openmc_plot_geometry_image.png",
mime="image/png",
)
else:

data = [
go.Heatmap(
z=color_data_slice,
Expand All @@ -498,7 +500,6 @@ def main():
]

if outline is not None:

data.append(
go.Contour(
z=outline_data_slice,
Expand Down Expand Up @@ -533,26 +534,23 @@ def main():
plot.write_html("openmc_plot_geometry_image.html")

with open("openmc_plot_geometry_image.html", "rb") as file:
col1.download_button(
st.sidebar.download_button(
label="Download image",
data=file,
file_name="openmc_plot_geometry_image.html",
mime=None,
)
col2.plotly_chart(plot, use_container_width=True)

col2.write("Model info")
col2.write(f"Material IDS found {set_mat_ids}")
col2.write(f"Material names found {set_mat_names}")
col2.write(f"Cell IDS found {set_cell_ids}")
col2.write(f"Cell names found {set_cell_names}")
col2.write(
f"Bounding box lower left x={bb[0][0]} y={bb[0][1]} z={bb[0][2]}"
)
col2.write(
f"Bounding box upper right x={bb[1][0]} y={bb[1][1]} z={bb[1][2]}"
)
st.plotly_chart(plot, use_container_width=True)

st.write("Model info")
st.write(f"Material IDS found {set_mat_ids}")
st.write(f"Material names found {set_mat_names}")
st.write(f"Cell IDS found {set_cell_ids}")
st.write(f"Cell names found {set_cell_names}")
st.write(f"Bounding box lower left x={bb[0][0]} y={bb[0][1]} z={bb[0][2]}")
st.write(f"Bounding box upper right x={bb[1][0]} y={bb[1][1]} z={bb[1][2]}")


if __name__ == "__main__":
header()
main()
18 changes: 12 additions & 6 deletions src/openmc_geometry_plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,18 @@ def get_slice_of_material_ids(
my_plots = openmc.Plots([my_plot])
my_plots.export_to_xml(tmp_folder)

# TODO unset this afterwards
original_cross_sections = openmc.config["cross_sections"]
if 'cross_sections' in openmc.config.keys():
original_cross_sections = openmc.config["cross_sections"]
else:
original_cross_sections = None

package_dir = Path(__file__).parent
openmc.config["cross_sections"] = package_dir / "cross_sections.xml"

openmc.plot_geometry(cwd=tmp_folder, output=verbose)

openmc.config["cross_sections"] = original_cross_sections
if original_cross_sections:
openmc.config["cross_sections"] = original_cross_sections

if verbose:
print(f"Temporary image and xml files written to {tmp_folder}")
Expand Down Expand Up @@ -402,7 +405,6 @@ def get_slice_of_cell_ids(
bb = dag_universe.bounding_box

else:

original_materials = self.get_all_materials()
mat_ids = original_materials.keys()

Expand Down Expand Up @@ -476,15 +478,19 @@ def get_slice_of_cell_ids(
my_plots = openmc.Plots([my_plot])
my_plots.export_to_xml(tmp_folder)

original_cross_sections = openmc.config["cross_sections"]
if 'cross_sections' in openmc.config.keys():
original_cross_sections = openmc.config["cross_sections"]
else:
original_cross_sections = None

# TODO unset this afterwards
package_dir = Path(__file__).parent
openmc.config["cross_sections"] = package_dir / "cross_sections.xml"

openmc.plot_geometry(cwd=tmp_folder, output=verbose)

openmc.config["cross_sections"] = original_cross_sections
if original_cross_sections:
openmc.config["cross_sections"] = original_cross_sections

if verbose:
print(f"Temporary image and xml files written to {tmp_folder}")
Expand Down

0 comments on commit f8fff5b

Please sign in to comment.