Skip to content

Commit

Permalink
RFC: refactor ParticleProjectionPlot for consistency with ProjectionPlot
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jun 24, 2023
1 parent ac55102 commit c4c73af
Showing 1 changed file with 167 additions and 57 deletions.
224 changes: 167 additions & 57 deletions yt/visualization/particle_plots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
import warnings
from typing import List, Union

import numpy as np

Expand Down Expand Up @@ -176,7 +177,7 @@ def __init__(
)


class ParticleProjectionPlot(PWViewerMPL, NormalPlot):
class ParticleProjectionPlot(NormalPlot):
r"""Creates a particle plot from a dataset
Given a ds object, a normal to project along, and a field name
Expand Down Expand Up @@ -330,6 +331,27 @@ class ParticleProjectionPlot(PWViewerMPL, NormalPlot):
_plot_type = "Particle"
_frb_generator = ParticleImageBuffer

# ignoring type check here, because mypy doesn't allow __new__ methods to
# return instances of subclasses. The design we use here is however based
# on the pathlib.Path class from the standard library
# https://github.com/python/mypy/issues/1020
def __new__( # type: ignore
cls, ds, normal, fields, *args, **kwargs
) -> Union["AxisAlignedParticleProjectionPlot", "OffAxisParticleProjectionPlot"]:
if cls is ParticleProjectionPlot:
normal = cls.sanitize_normal_vector(ds, normal)
if isinstance(normal, str):
cls = AxisAlignedParticleProjectionPlot
else:
cls = OffAxisParticleProjectionPlot
self = object.__new__(cls)
return self # type: ignore [return-value]


class AxisAlignedParticleProjectionPlot(ParticleProjectionPlot, PWViewerMPL):
_plot_type = "Particle"
_frb_generator = ParticleImageBuffer

def __init__(
self,
ds,
Expand Down Expand Up @@ -387,67 +409,155 @@ def __init__(
self._use_cbar = False
splat_color = color

if isinstance(normal, str):
axis = fix_axis(normal, ds)
(bounds, center, display_center) = get_window_parameters(
axis, center, width, ds
axis = fix_axis(normal, ds)
(bounds, center, display_center) = get_window_parameters(
axis, center, width, ds
)
x_coord = ds.coordinates.x_axis[axis]
y_coord = ds.coordinates.y_axis[axis]

depth = ds.coordinates.sanitize_depth(depth)

width = np.zeros_like(center)
width[x_coord] = bounds[1] - bounds[0]
width[y_coord] = bounds[3] - bounds[2]
width[axis] = depth[0].in_units(width[x_coord].units)

ParticleSource = ParticleAxisAlignedDummyDataSource(
center,
ds,
axis,
width,
fields,
weight_field=weight_field,
field_parameters=field_parameters,
data_source=data_source,
deposition=deposition,
density=density,
)

oblique = False
plt_origin = origin
periodic = True

self.projected = weight_field is None

PWViewerMPL.__init__(
self,
ParticleSource,
bounds,
origin=plt_origin,
fontsize=fontsize,
fields=fields,
window_size=window_size,
aspect=aspect,
splat_color=splat_color,
geometry=ds.geometry,
periodic=periodic,
oblique=oblique,
)

self.set_axes_unit(axes_unit)

if not self._use_cbar:
self.hide_colorbar()


class OffAxisParticleProjectionPlot(ParticleProjectionPlot, PWViewerMPL):
_plot_type = "Particle"
_frb_generator = ParticleImageBuffer

def __init__(
self,
ds,
normal=None,
fields=None,
color="b",
center="center",
width=None,
depth=(1, "1"),
weight_field=None,
axes_unit=None,
origin="center-window",
fontsize=18,
field_parameters=None,
window_size=8.0,
aspect=None,
data_source=None,
deposition="ngp",
density=False,
*,
north_vector=None,
axis=None,
):
if axis is not None:
issue_deprecation_warning(
"The 'axis' argument is a deprecated alias for the 'normal' argument. ",
stacklevel=3,
since="4.2",
)
normal = axis
if normal is None:
raise TypeError(
"ParticleProjectionPlot() missing 1 required positional argument: 'normal'"
)
x_coord = ds.coordinates.x_axis[axis]
y_coord = ds.coordinates.y_axis[axis]

depth = ds.coordinates.sanitize_depth(depth)

width = np.zeros_like(center)
width[x_coord] = bounds[1] - bounds[0]
width[y_coord] = bounds[3] - bounds[2]
width[axis] = depth[0].in_units(width[x_coord].units)

ParticleSource = ParticleAxisAlignedDummyDataSource(
center,
ds,
axis,
width,
fields,
weight_field=weight_field,
field_parameters=field_parameters,
data_source=data_source,
deposition=deposition,
density=density,
if data_source is not None:
warnings.warn(
"data_source argument has no effect with an "
"off-axis particle projection plot (not implemented)",
stacklevel=3, # TODO: check
)
# this will handle time series data and controllers
ts = self._initialize_dataset(ds)
self.ts = ts
ds = self.ds = ts[0]
normal = self.sanitize_normal_vector(ds, normal)
if field_parameters is None:
field_parameters = {}

oblique = False
plt_origin = origin
periodic = True
if axes_unit is None:
axes_unit = get_axes_unit(width, ds)

else:
(bounds, center_rot) = get_oblique_window_parameters(
normal, center, width, ds, depth=depth
)
# if no fields are passed in, we simply mark the x and
# y fields using a given color. Use the 'particle_ones'
# field to do this. We also turn off the colorbar in
# this case.
self._use_cbar = True
splat_color = None
if fields is None:
fields = [("all", "particle_ones")]
weight_field = ("all", "particle_ones")
self._use_cbar = False
splat_color = color

width = ds.coordinates.sanitize_width(normal, width, depth)

ParticleSource = ParticleOffAxisDummyDataSource(
center_rot,
ds,
normal,
width,
fields,
weight_field=weight_field,
field_parameters=field_parameters,
data_source=None,
deposition=deposition,
density=density,
north_vector=north_vector,
)
(bounds, center_rot) = get_oblique_window_parameters(
normal, center, width, ds, depth=depth
)

width = ds.coordinates.sanitize_width(normal, width, depth)

ParticleSource = ParticleOffAxisDummyDataSource(
center_rot,
ds,
normal,
width,
fields,
weight_field=weight_field,
field_parameters=field_parameters,
data_source=None,
deposition=deposition,
density=density,
north_vector=north_vector,
)

oblique = True
periodic = False
if origin != "center-window":
mylog.warning(
"The 'origin' keyword is ignored for off-axis "
"particle projections, it is always 'center-window'"
)
plt_origin = "center-window"
oblique = True
periodic = False
if origin != "center-window":
mylog.warning(
"The 'origin' keyword is ignored for off-axis "
"particle projections, it is always 'center-window'"
)
plt_origin = "center-window"

self.projected = weight_field is None

Expand Down

0 comments on commit c4c73af

Please sign in to comment.