Skip to content

Commit aaae37a

Browse files
Gregory RobertsGregory Roberts
Gregory Roberts
authored and
Gregory Roberts
committed
feat[frontend]: VisualizationSpec implementation and testing
1 parent 55938f2 commit aaae37a

File tree

8 files changed

+226
-4
lines changed

8 files changed

+226
-4
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- `VisualizationSpec` that allows `Medium` instances to specify color and transparency plotting attributes that override default ones.
12+
1013
### Changed
1114
- `ModeMonitor` and `ModeSolverMonitor` now use the default `td.ModeSpec()` with `num_modes=1` when `mode_spec` is not provided.
1215

tests/test_components/test_viz.py

+149
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests visualization operations."""
22

33
import matplotlib.pyplot as plt
4+
import pydantic.v1 as pd
45
import pytest
56
import tidy3d as td
67
from tidy3d.components.viz import Polygon, set_default_labels_and_title
@@ -100,3 +101,151 @@ def test_set_default_labels_title():
100101
ax = set_default_labels_and_title(
101102
axis_labels=axis_labels, axis=2, position=0, ax=ax, plot_length_units="inches"
102103
)
104+
105+
plt.close()
106+
107+
108+
def test_make_viz_spec():
109+
"""
110+
Tests core visualizaton spec creation.
111+
"""
112+
viz_spec = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=0.5)
113+
viz_spec = td.VisualizationSpec(facecolor="red", alpha=0.5)
114+
115+
116+
def test_unallowed_colors():
117+
"""
118+
Tests validator for visualization spec for colors not recognized by matplotlib.
119+
"""
120+
with pytest.raises(pd.ValidationError):
121+
_ = td.VisualizationSpec(facecolor="rr", edgecolor="green", alpha=0.5)
122+
with pytest.raises(pd.ValidationError):
123+
_ = td.VisualizationSpec(facecolor="red", edgecolor="gg", alpha=0.5)
124+
125+
126+
def test_unallowed_alpha():
127+
"""
128+
Tests validator for disallowed alpha values.
129+
"""
130+
with pytest.raises(pd.ValidationError):
131+
_ = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=-0.5)
132+
with pytest.raises(pd.ValidationError):
133+
_ = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=2.5)
134+
135+
136+
def test_plot_from_structure():
137+
"""
138+
Tests visualization spec can be added to medium and structure plotting function can be run.
139+
"""
140+
viz_spec = td.VisualizationSpec(facecolor="blue", edgecolor="pink", alpha=0.5)
141+
medium = td.Medium(permittivity=2.25, viz_spec=viz_spec)
142+
geometry = td.Box(size=(2, 0, 2))
143+
144+
structure = td.Structure(geometry=geometry, medium=medium)
145+
146+
structure.plot(z=0)
147+
plt.close()
148+
149+
150+
def plot_with_viz_spec(alpha, facecolor, edgecolor=None, use_viz_spec=True):
151+
"""
152+
Helper function for locally testing different visualization specs in structures through
153+
structure plotting function.
154+
"""
155+
if edgecolor is None:
156+
viz_spec = td.VisualizationSpec(facecolor=facecolor, alpha=alpha)
157+
else:
158+
viz_spec = td.VisualizationSpec(facecolor=facecolor, edgecolor=edgecolor, alpha=alpha)
159+
160+
medium = td.Medium(permittivity=2.25)
161+
if use_viz_spec:
162+
medium = td.Medium(permittivity=2.25, viz_spec=viz_spec)
163+
164+
geometry = td.Box(size=(2, 4, 2))
165+
166+
structure = td.Structure(geometry=geometry, medium=medium)
167+
168+
structure.plot(z=1)
169+
plt.show()
170+
171+
172+
def plot_with_multi_viz_spec(alphas, facecolors, edgecolors, rng, use_viz_spec=True):
173+
"""
174+
Helper function for plotting simulations with multiple visulation specs via simluation
175+
plotting function.
176+
"""
177+
viz_specs = [
178+
td.VisualizationSpec(
179+
facecolor=facecolors[idx], edgecolor=edgecolors[idx], alpha=alphas[idx]
180+
)
181+
for idx in range(0, len(alphas))
182+
]
183+
media = [td.Medium(permittivity=2.25) for idx in range(0, len(viz_specs))]
184+
if use_viz_spec:
185+
media = [
186+
td.Medium(permittivity=2.25, viz_spec=viz_specs[idx])
187+
for idx in range(0, len(viz_specs))
188+
]
189+
190+
structures = []
191+
for idx in range(0, len(viz_specs)):
192+
center = tuple(list(rng.uniform(-3, 3, 2)) + [0])
193+
size = tuple(rng.uniform(1, 2, 3))
194+
box = td.Box(center=center, size=size)
195+
196+
structures.append(td.Structure(geometry=box, medium=media[idx]))
197+
198+
sim = td.Simulation(
199+
size=(10.0, 10.0, 10.0),
200+
run_time=1e-12,
201+
structures=structures,
202+
grid_spec=td.GridSpec(wavelength=1.0),
203+
)
204+
205+
sim.plot(z=0.0)
206+
plt.show()
207+
208+
209+
@pytest.mark.skip(reason="Skipping test for CI, but useful for debugging locally with graphics.")
210+
def test_plot_from_structure_local():
211+
"""
212+
Local test for visualizing output when specifying visualization spec.
213+
"""
214+
plot_with_viz_spec(alpha=0.5, facecolor="red", edgecolor="blue")
215+
plot_with_viz_spec(alpha=0.1, facecolor="magenta", edgecolor="cyan")
216+
plot_with_viz_spec(alpha=0.9, facecolor="darkgreen", edgecolor="black")
217+
plot_with_viz_spec(alpha=0.8, facecolor="brown", edgecolor="deepskyblue")
218+
plot_with_viz_spec(alpha=0.2, facecolor="brown", edgecolor="deepskyblue")
219+
plot_with_viz_spec(alpha=1.0, facecolor="green")
220+
plot_with_viz_spec(alpha=0.75, facecolor="red", edgecolor="blue")
221+
plot_with_viz_spec(alpha=0.75, facecolor="red", edgecolor="blue", use_viz_spec=False)
222+
223+
with pytest.raises(pd.ValidationError):
224+
plot_with_viz_spec(alpha=0.5, facecolor="dark green", edgecolor="blue")
225+
with pytest.raises(pd.ValidationError):
226+
plot_with_viz_spec(alpha=0.5, facecolor="red", edgecolor="ble")
227+
with pytest.raises(pd.ValidationError):
228+
plot_with_viz_spec(alpha=1.5, facecolor="red", edgecolor="blue")
229+
with pytest.raises(pd.ValidationError):
230+
plot_with_viz_spec(alpha=-0.5, facecolor="red", edgecolor="blue")
231+
232+
233+
@pytest.mark.skip(reason="Skipping test for CI, but useful for debugging locally with graphics.")
234+
def test_plot_multi_from_structure_local(rng):
235+
"""
236+
Local test for visualizing output when creating multiple structures with variety of
237+
visualization specs.
238+
"""
239+
plot_with_multi_viz_spec(
240+
alphas=[0.5, 0.75, 0.25, 0.4],
241+
facecolors=["red", "green", "blue", "orange"],
242+
edgecolors=["black", "cyan", "magenta", "brown"],
243+
rng=rng,
244+
)
245+
plot_with_multi_viz_spec(
246+
alphas=[0.5, 0.75, 0.25, 0.4],
247+
facecolors=["red", "green", "blue", "orange"],
248+
edgecolors=["black", "cyan", "magenta", "brown"],
249+
rng=rng,
250+
use_viz_spec=False,
251+
)

tidy3d/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@
284284
SpaceTimeModulation,
285285
)
286286
from .components.transformation import RotationAroundAxis
287+
from .components.viz import VisualizationSpec
287288

288289
# config
289290
from .config import config
@@ -534,6 +535,7 @@ def set_logging_level(level: str) -> None:
534535
"HeuristicPECStaircasing",
535536
"PECConformal",
536537
"SurfaceImpedance",
538+
"VisualizationSpec",
537539
"EMESimulation",
538540
"EMESimulationData",
539541
"EMEMonitor",

tidy3d/components/geometry/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ARROW_LENGTH,
5252
PLOT_BUFFER,
5353
PlotParams,
54+
VisualizationSpec,
5455
add_ax_if_none,
5556
arrow_style,
5657
equal_aspect,
@@ -450,6 +451,7 @@ def plot(
450451
z: float = None,
451452
ax: Ax = None,
452453
plot_length_units: LengthUnit = None,
454+
viz_spec: VisualizationSpec = None,
453455
**patch_kwargs,
454456
) -> Ax:
455457
"""Plot geometry cross section at single (x,y,z) coordinate.
@@ -466,6 +468,8 @@ def plot(
466468
Matplotlib axes to plot on, if not specified, one is created.
467469
plot_length_units : LengthUnit = None
468470
Specify units to use for axis labels, tick labels, and the title.
471+
viz_spec : VisualizationSpec = None
472+
Plotting parameters associated with a medium to use instead of defaults.
469473
**patch_kwargs
470474
Optional keyword arguments passed to the matplotlib patch plotting of structure.
471475
For details on accepted values, refer to
@@ -481,7 +485,10 @@ def plot(
481485
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
482486
shapes_intersect = self.intersections_plane(x=x, y=y, z=z)
483487

484-
plot_params = self.plot_params.include_kwargs(**patch_kwargs)
488+
plot_params = self.plot_params
489+
if viz_spec is not None:
490+
plot_params = plot_params.override_with_viz_spec(viz_spec)
491+
plot_params = plot_params.include_kwargs(**patch_kwargs)
485492

486493
# for each intersection, plot the shape
487494
for shape in shapes_intersect:

tidy3d/components/medium.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
TensorReal,
8383
)
8484
from .validators import _warn_potential_error, validate_name_str, validate_parameter_perturbation
85-
from .viz import add_ax_if_none
85+
from .viz import VisualizationSpec, add_ax_if_none
8686

8787
# evaluate frequency as this number (Hz) if inf
8888
FREQ_EVAL_INF = 1e50
@@ -646,6 +646,12 @@ class AbstractMedium(ABC, Tidy3dBaseModel):
646646
description="Modulation spec applied on top of the base medium properties.",
647647
)
648648

649+
viz_spec: Optional[VisualizationSpec] = pd.Field(
650+
None,
651+
title="Visualization Specification",
652+
description="Plotting specification for visualizing medium.",
653+
)
654+
649655
@cached_property
650656
def _nonlinear_models(self) -> NonlinearSpec:
651657
"""The nonlinear models in the nonlinear_spec."""

tidy3d/components/scene.py

+6
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,8 @@ def _get_structure_plot_params(self, mat_index: int, medium: Medium) -> PlotPara
470470
# regular medium
471471
facecolor = MEDIUM_CMAP[(mat_index - 1) % len(MEDIUM_CMAP)]
472472
plot_params = plot_params.copy(update={"facecolor": facecolor})
473+
if medium.viz_spec is not None:
474+
plot_params = plot_params.override_with_viz_spec(medium.viz_spec)
473475

474476
return plot_params
475477

@@ -1072,6 +1074,8 @@ def _get_structure_eps_plot_params(
10721074
"""Constructs the plot parameters for a given medium in scene.plot_eps()."""
10731075

10741076
plot_params = plot_params_structure.copy(update={"linewidth": 0})
1077+
if medium.viz_spec is not None:
1078+
plot_params = plot_params.override_with_viz_spec(medium.viz_spec)
10751079
if alpha is not None:
10761080
plot_params = plot_params.copy(update={"alpha": alpha})
10771081

@@ -1394,6 +1398,8 @@ def _get_structure_heat_charge_property_plot_params(
13941398
"""
13951399

13961400
plot_params = plot_params_structure.copy(update={"linewidth": 0})
1401+
if medium.viz_spec is not None:
1402+
plot_params = plot_params.override_with_viz_spec(medium.viz_spec)
13971403
if alpha is not None:
13981404
plot_params = plot_params.copy(update={"alpha": alpha})
13991405

tidy3d/components/structure.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def plot(
136136
matplotlib.axes._subplots.Axes
137137
The supplied or created matplotlib axes.
138138
"""
139-
return self.geometry.plot(x=x, y=y, z=z, ax=ax, **patch_kwargs)
139+
return self.geometry.plot(
140+
x=x, y=y, z=z, ax=ax, viz_spec=self.medium.viz_spec, **patch_kwargs
141+
)
140142

141143

142144
class Structure(AbstractStructure):

tidy3d/components/viz.py

+48-1
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,22 @@
44

55
from functools import wraps
66
from html import escape
7-
from typing import Any
7+
from typing import Any, Dict, Optional
88

99
import pydantic.v1 as pd
1010

1111
try:
1212
import matplotlib.pyplot as plt
1313
import matplotlib.ticker as ticker
14+
from matplotlib.colors import is_color_like
1415
from matplotlib.patches import ArrowStyle, PathPatch
1516
from matplotlib.path import Path
1617

1718
# default arrow style
1819
arrow_style = ArrowStyle.Simple(head_length=12, head_width=9, tail_width=4)
1920
except ImportError:
2021
arrow_style = None
22+
2123
from numpy import array, concatenate, inf, ones
2224

2325
from ..constants import UnitScaling
@@ -99,6 +101,10 @@ def include_kwargs(self, **kwargs) -> AbstractPlotParams:
99101
}
100102
return self.copy(update=update_dict)
101103

104+
def override_with_viz_spec(self, viz_spec) -> AbstractPlotParams:
105+
"""Override plot params with supplied VisualizationSpec."""
106+
return self.include_kwargs(**dict(viz_spec))
107+
102108
def to_kwargs(self) -> dict:
103109
"""Export the plot parameters as kwargs dict that can be supplied to plot function."""
104110
kwarg_dict = self.dict()
@@ -168,6 +174,47 @@ class PlotParams(AbstractPlotParams):
168174
STRUCTURE_EPS_CMAP = "gist_yarg"
169175
STRUCTURE_HEAT_COND_CMAP = "gist_yarg"
170176

177+
178+
def is_valid_color(value: str) -> str:
179+
if not is_color_like(value):
180+
raise pd.ValidationError(f"{value} is not a valid plotting color")
181+
182+
return value
183+
184+
185+
class VisualizationSpec(Tidy3dBaseModel):
186+
"""Defines specification for visualization when used with plotting functions."""
187+
188+
facecolor: str = pd.Field(
189+
"",
190+
title="Face color",
191+
description="Color applied to the faces in visualization.",
192+
)
193+
194+
edgecolor: Optional[str] = pd.Field(
195+
"",
196+
title="Edge color",
197+
description="Color applied to the edges in visualization.",
198+
)
199+
200+
alpha: Optional[pd.confloat(ge=0.0, le=1.0)] = pd.Field(
201+
1.0,
202+
title="Opacity",
203+
description="Opacity/alpha value in plotting between 0 and 1.",
204+
)
205+
206+
@pd.validator("facecolor", always=True)
207+
def validate_color(value: str) -> str:
208+
return is_valid_color(value)
209+
210+
@pd.validator("edgecolor", always=True)
211+
def validate_and_copy_color(value: str, values: Dict[str, Any]) -> str:
212+
if (value == "") and "facecolor" in values:
213+
return is_valid_color(values["facecolor"])
214+
215+
return is_valid_color(value)
216+
217+
171218
"""=================================================================================================
172219
Descartes modified from https://pypi.org/project/descartes/ for Shapely >= 1.8.0
173220

0 commit comments

Comments
 (0)