Skip to content

Commit ba3ccf4

Browse files
Gregory RobertsGregory Roberts
Gregory Roberts
authored and
Gregory Roberts
committed
feat[frontend]: viz spec implementation and testing
1 parent a0c3cfe commit ba3ccf4

File tree

8 files changed

+224
-4
lines changed

8 files changed

+224
-4
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
- `VisualizationSpec` that allows `Medium` instances to specify color and transparency plotting attributes that override default ones.
12+
1013
### Changed
1114
- `ModeMonitor` and `ModeSolverMonitor` now use the default `td.ModeSpec()` with `num_modes=1` when `mode_spec` is not provided.
1215

tests/test_components/test_viz.py

+149
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests visualization operations."""
22

33
import matplotlib.pyplot as plt
4+
import pydantic.v1 as pd
45
import pytest
56
import tidy3d as td
67
from tidy3d.components.viz import Polygon, set_default_labels_and_title
@@ -100,3 +101,151 @@ def test_set_default_labels_title():
100101
ax = set_default_labels_and_title(
101102
axis_labels=axis_labels, axis=2, position=0, ax=ax, plot_length_units="inches"
102103
)
104+
105+
plt.close()
106+
107+
108+
def test_make_viz_spec():
109+
"""
110+
Tests core visualizaton spec creation.
111+
"""
112+
viz_spec = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=0.5)
113+
viz_spec = td.VisualizationSpec(facecolor="red", alpha=0.5)
114+
115+
116+
def test_unallowed_colors():
117+
"""
118+
Tests validator for visualization spec for colors not recognized by matplotlib.
119+
"""
120+
with pytest.raises(pd.ValidationError):
121+
_ = td.VisualizationSpec(facecolor="rr", edgecolor="green", alpha=0.5)
122+
with pytest.raises(pd.ValidationError):
123+
_ = td.VisualizationSpec(facecolor="red", edgecolor="gg", alpha=0.5)
124+
125+
126+
def test_unallowed_alpha():
127+
"""
128+
Tests validator for disallowed alpha values.
129+
"""
130+
with pytest.raises(pd.ValidationError):
131+
_ = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=-0.5)
132+
with pytest.raises(pd.ValidationError):
133+
_ = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=2.5)
134+
135+
136+
def test_plot_from_structure():
137+
"""
138+
Tests visualization spec can be added to medium and structure plotting function can be run.
139+
"""
140+
viz_spec = td.VisualizationSpec(facecolor="blue", edgecolor="pink", alpha=0.5)
141+
medium = td.Medium(permittivity=2.25, viz_spec=viz_spec)
142+
geometry = td.Box(size=(2, 0, 2))
143+
144+
structure = td.Structure(geometry=geometry, medium=medium)
145+
146+
structure.plot(z=0)
147+
plt.close()
148+
149+
150+
def plot_with_viz_spec(alpha, facecolor, edgecolor=None, use_viz_spec=True):
151+
"""
152+
Helper function for locally testing different visualization specs in structures through
153+
structure plotting function.
154+
"""
155+
if edgecolor is None:
156+
viz_spec = td.VisualizationSpec(facecolor=facecolor, alpha=alpha)
157+
else:
158+
viz_spec = td.VisualizationSpec(facecolor=facecolor, edgecolor=edgecolor, alpha=alpha)
159+
160+
medium = td.Medium(permittivity=2.25)
161+
if use_viz_spec:
162+
medium = td.Medium(permittivity=2.25, viz_spec=viz_spec)
163+
164+
geometry = td.Box(size=(2, 4, 2))
165+
166+
structure = td.Structure(geometry=geometry, medium=medium)
167+
168+
structure.plot(z=1)
169+
plt.show()
170+
171+
172+
def plot_with_multi_viz_spec(alphas, facecolors, edgecolors, rng, use_viz_spec=True):
173+
"""
174+
Helper function for plotting simulations with multiple visulation specs via simluation
175+
plotting function.
176+
"""
177+
viz_specs = [
178+
td.VisualizationSpec(
179+
facecolor=facecolors[idx], edgecolor=edgecolors[idx], alpha=alphas[idx]
180+
)
181+
for idx in range(0, len(alphas))
182+
]
183+
media = [td.Medium(permittivity=2.25) for idx in range(0, len(viz_specs))]
184+
if use_viz_spec:
185+
media = [
186+
td.Medium(permittivity=2.25, viz_spec=viz_specs[idx])
187+
for idx in range(0, len(viz_specs))
188+
]
189+
190+
structures = []
191+
for idx in range(0, len(viz_specs)):
192+
center = tuple(list(rng.uniform(-3, 3, 2)) + [0])
193+
size = tuple(rng.uniform(1, 2, 3))
194+
box = td.Box(center=center, size=size)
195+
196+
structures.append(td.Structure(geometry=box, medium=media[idx]))
197+
198+
sim = td.Simulation(
199+
size=(10.0, 10.0, 10.0),
200+
run_time=1e-12,
201+
structures=structures,
202+
grid_spec=td.GridSpec(wavelength=1.0),
203+
)
204+
205+
sim.plot(z=0.0)
206+
plt.show()
207+
208+
209+
@pytest.mark.skip(reason="Skipping test for CI, but useful for debugging locally with graphics.")
210+
def test_plot_from_structure_local():
211+
"""
212+
Local test for visualizing output when specifying visualization spec.
213+
"""
214+
plot_with_viz_spec(alpha=0.5, facecolor="red", edgecolor="blue")
215+
plot_with_viz_spec(alpha=0.1, facecolor="magenta", edgecolor="cyan")
216+
plot_with_viz_spec(alpha=0.9, facecolor="darkgreen", edgecolor="black")
217+
plot_with_viz_spec(alpha=0.8, facecolor="brown", edgecolor="deepskyblue")
218+
plot_with_viz_spec(alpha=0.2, facecolor="brown", edgecolor="deepskyblue")
219+
plot_with_viz_spec(alpha=1.0, facecolor="green")
220+
plot_with_viz_spec(alpha=0.75, facecolor="red", edgecolor="blue")
221+
plot_with_viz_spec(alpha=0.75, facecolor="red", edgecolor="blue", use_viz_spec=False)
222+
223+
with pytest.raises(pd.ValidationError):
224+
plot_with_viz_spec(alpha=0.5, facecolor="dark green", edgecolor="blue")
225+
with pytest.raises(pd.ValidationError):
226+
plot_with_viz_spec(alpha=0.5, facecolor="red", edgecolor="ble")
227+
with pytest.raises(pd.ValidationError):
228+
plot_with_viz_spec(alpha=1.5, facecolor="red", edgecolor="blue")
229+
with pytest.raises(pd.ValidationError):
230+
plot_with_viz_spec(alpha=-0.5, facecolor="red", edgecolor="blue")
231+
232+
233+
@pytest.mark.skip(reason="Skipping test for CI, but useful for debugging locally with graphics.")
234+
def test_plot_multi_from_structure_local(rng):
235+
"""
236+
Local test for visualizing output when creating multiple structures with variety of
237+
visualization specs.
238+
"""
239+
plot_with_multi_viz_spec(
240+
alphas=[0.5, 0.75, 0.25, 0.4],
241+
facecolors=["red", "green", "blue", "orange"],
242+
edgecolors=["black", "cyan", "magenta", "brown"],
243+
rng=rng,
244+
)
245+
plot_with_multi_viz_spec(
246+
alphas=[0.5, 0.75, 0.25, 0.4],
247+
facecolors=["red", "green", "blue", "orange"],
248+
edgecolors=["black", "cyan", "magenta", "brown"],
249+
rng=rng,
250+
use_viz_spec=False,
251+
)

tidy3d/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@
284284
SpaceTimeModulation,
285285
)
286286
from .components.transformation import RotationAroundAxis
287+
from .components.viz import VisualizationSpec
287288

288289
# config
289290
from .config import config
@@ -534,6 +535,7 @@ def set_logging_level(level: str) -> None:
534535
"HeuristicPECStaircasing",
535536
"PECConformal",
536537
"SurfaceImpedance",
538+
"VisualizationSpec",
537539
"EMESimulation",
538540
"EMESimulationData",
539541
"EMEMonitor",

tidy3d/components/geometry/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
ARROW_LENGTH,
4848
PLOT_BUFFER,
4949
PlotParams,
50+
VisualizationSpec,
5051
add_ax_if_none,
5152
arrow_style,
5253
equal_aspect,
@@ -446,6 +447,7 @@ def plot(
446447
z: float = None,
447448
ax: Ax = None,
448449
plot_length_units: LengthUnit = None,
450+
viz_spec: VisualizationSpec = None,
449451
**patch_kwargs,
450452
) -> Ax:
451453
"""Plot geometry cross section at single (x,y,z) coordinate.
@@ -462,6 +464,8 @@ def plot(
462464
Matplotlib axes to plot on, if not specified, one is created.
463465
plot_length_units : LengthUnit = None
464466
Specify units to use for axis labels, tick labels, and the title.
467+
viz_spec : VisualizationSpec = None
468+
Plotting parameters associated with a medium to use instead of defaults.
465469
**patch_kwargs
466470
Optional keyword arguments passed to the matplotlib patch plotting of structure.
467471
For details on accepted values, refer to
@@ -477,7 +481,10 @@ def plot(
477481
axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
478482
shapes_intersect = self.intersections_plane(x=x, y=y, z=z)
479483

480-
plot_params = self.plot_params.include_kwargs(**patch_kwargs)
484+
plot_params = self.plot_params
485+
if viz_spec is not None:
486+
plot_params = plot_params.override_with_viz_spec(viz_spec)
487+
plot_params = plot_params.include_kwargs(**patch_kwargs)
481488

482489
# for each intersection, plot the shape
483490
for shape in shapes_intersect:

tidy3d/components/medium.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
TensorReal,
8383
)
8484
from .validators import _warn_potential_error, validate_name_str, validate_parameter_perturbation
85-
from .viz import add_ax_if_none
85+
from .viz import VisualizationSpec, add_ax_if_none
8686

8787
# evaluate frequency as this number (Hz) if inf
8888
FREQ_EVAL_INF = 1e50
@@ -646,6 +646,12 @@ class AbstractMedium(ABC, Tidy3dBaseModel):
646646
description="Modulation spec applied on top of the base medium properties.",
647647
)
648648

649+
viz_spec: Optional[VisualizationSpec] = pd.Field(
650+
None,
651+
title="Visualization Specification",
652+
description="Plotting specification for visualizing medium.",
653+
)
654+
649655
@cached_property
650656
def _nonlinear_models(self) -> NonlinearSpec:
651657
"""The nonlinear models in the nonlinear_spec."""

tidy3d/components/scene.py

+6
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,8 @@ def _get_structure_plot_params(self, mat_index: int, medium: Medium) -> PlotPara
466466
# regular medium
467467
facecolor = MEDIUM_CMAP[(mat_index - 1) % len(MEDIUM_CMAP)]
468468
plot_params = plot_params.copy(update={"facecolor": facecolor})
469+
if medium.viz_spec is not None:
470+
plot_params = plot_params.override_with_viz_spec(medium.viz_spec)
469471

470472
return plot_params
471473

@@ -1068,6 +1070,8 @@ def _get_structure_eps_plot_params(
10681070
"""Constructs the plot parameters for a given medium in scene.plot_eps()."""
10691071

10701072
plot_params = plot_params_structure.copy(update={"linewidth": 0})
1073+
if medium.viz_spec is not None:
1074+
plot_params = plot_params.override_with_viz_spec(medium.viz_spec)
10711075
if alpha is not None:
10721076
plot_params = plot_params.copy(update={"alpha": alpha})
10731077

@@ -1390,6 +1394,8 @@ def _get_structure_heat_charge_property_plot_params(
13901394
"""
13911395

13921396
plot_params = plot_params_structure.copy(update={"linewidth": 0})
1397+
if medium.viz_spec is not None:
1398+
plot_params = plot_params.override_with_viz_spec(medium.viz_spec)
13931399
if alpha is not None:
13941400
plot_params = plot_params.copy(update={"alpha": alpha})
13951401

tidy3d/components/structure.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def plot(
136136
matplotlib.axes._subplots.Axes
137137
The supplied or created matplotlib axes.
138138
"""
139-
return self.geometry.plot(x=x, y=y, z=z, ax=ax, **patch_kwargs)
139+
return self.geometry.plot(
140+
x=x, y=y, z=z, ax=ax, viz_spec=self.medium.viz_spec, **patch_kwargs
141+
)
140142

141143

142144
class Structure(AbstractStructure):

tidy3d/components/viz.py

+46-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44

55
from functools import wraps
66
from html import escape
7-
from typing import Any
7+
from typing import Any, Dict, Optional
88

99
import matplotlib.pyplot as plt
1010
import matplotlib.ticker as ticker
1111
import pydantic.v1 as pd
12+
from matplotlib.colors import is_color_like
1213
from matplotlib.patches import ArrowStyle, PathPatch
1314
from matplotlib.path import Path
1415
from numpy import array, concatenate, inf, ones
@@ -92,6 +93,10 @@ def include_kwargs(self, **kwargs) -> AbstractPlotParams:
9293
}
9394
return self.copy(update=update_dict)
9495

96+
def override_with_viz_spec(self, viz_spec) -> AbstractPlotParams:
97+
"""Override plot params with supplied VisualizationSpec."""
98+
return self.include_kwargs(**dict(viz_spec))
99+
95100
def to_kwargs(self) -> dict:
96101
"""Export the plot parameters as kwargs dict that can be supplied to plot function."""
97102
kwarg_dict = self.dict()
@@ -165,6 +170,46 @@ class PlotParams(AbstractPlotParams):
165170
arrow_style = ArrowStyle.Simple(head_length=12, head_width=9, tail_width=4)
166171

167172

173+
def is_valid_color(value: str) -> str:
174+
if not is_color_like(value):
175+
raise pd.ValidationError(f"{value} is not a valid plotting color")
176+
177+
return value
178+
179+
180+
class VisualizationSpec(Tidy3dBaseModel):
181+
"""Defines specification for visualization when used with plotting functions."""
182+
183+
facecolor: str = pd.Field(
184+
"",
185+
title="Face color",
186+
description="Color applied to the faces in visualization.",
187+
)
188+
189+
edgecolor: Optional[str] = pd.Field(
190+
"",
191+
title="Edge color",
192+
description="Color applied to the edges in visualization.",
193+
)
194+
195+
alpha: Optional[pd.confloat(ge=0.0, le=1.0)] = pd.Field(
196+
1.0,
197+
title="Opacity",
198+
description="Opacity/alpha value in plotting between 0 and 1.",
199+
)
200+
201+
@pd.validator("facecolor", always=True)
202+
def validate_color(value: str) -> str:
203+
return is_valid_color(value)
204+
205+
@pd.validator("edgecolor", always=True)
206+
def validate_and_copy_color(value: str, values: Dict[str, Any]) -> str:
207+
if (value == "") and "facecolor" in values:
208+
return is_valid_color(values["facecolor"])
209+
210+
return is_valid_color(value)
211+
212+
168213
"""=================================================================================================
169214
Descartes modified from https://pypi.org/project/descartes/ for Shapely >= 1.8.0
170215

0 commit comments

Comments
 (0)