Skip to content

Commit

Permalink
[vis] Fix annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jun 4, 2023
1 parent 38ae29a commit 1c14b08
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions phi/vis/_matplotlib/_matplotlib_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def create_figure(self,
y_range = (1e-3 * y_range[1], y_range[1])
axis.set_ylim(y_range)
# --- Equal aspect ---
if None not in x_range and None not in y_range and '_' not in bounds.vector.item_names and all([n in ['x', 'y', 'z'] for n in bounds.vector.item_names]):
max_aspect = 4 if all([n in ['x', 'y', 'z'] for n in bounds.vector.item_names]) else 1.5
if None not in x_range and None not in y_range and '_' not in bounds.vector.item_names:
x_size, y_size = x_range[1] - x_range[0], y_range[1] - y_range[0]
if not x_log and not y_log and x_size > 0 and y_size > 0 and max(x_size/y_size/subplot_aspect, y_size/x_size*subplot_aspect) < 4:
if not x_log and not y_log and x_size > 0 and y_size > 0 and max(x_size/y_size/subplot_aspect, y_size/x_size*subplot_aspect) < max_aspect:
axis.set_aspect('equal', adjustable='box')
# --- Remove labels if axes shared ---
for left_col in range(col):
Expand Down Expand Up @@ -513,7 +514,7 @@ def _annotate_points(axis, points: math.Tensor, color: Tensor, alpha: Tensor):
y_view = axis.get_ylim()[1] - axis.get_ylim()[0]
x_c = .95 * axis.get_xlim()[1] + .1 * axis.get_xlim()[0]
y_c = .95 * axis.get_ylim()[1] + .1 * axis.get_ylim()[0]
for x, y, idx in zip(xs, ys, labelled_dims.meshgrid()):
for x, y, idx, idx_n in zip(xs, ys, labelled_dims.meshgrid(), labelled_dims.meshgrid(names=True)):
if axis.get_xscale() == 'log':
offset_x = x * (1 + .0003 * x_view) if x < x_c else x * (1 - .0003 * x_view)
else:
Expand All @@ -522,7 +523,7 @@ def _annotate_points(axis, points: math.Tensor, color: Tensor, alpha: Tensor):
offset_y = y * (1 + .0003 * y_view) if y < y_c else y * (1 - .0003 * y_view)
else:
offset_y = y + .01 * y_view if y < y_c else y - .01 * y_view
axis.text(offset_x, offset_y, index_label(idx), color=_plt_col(color[idx]), alpha=float(alpha[idx]))
axis.text(offset_x, offset_y, index_label(idx_n), color=_plt_col(color[idx]), alpha=float(alpha[idx]))


class PointCloud3D(Recipe):
Expand Down

0 comments on commit 1c14b08

Please sign in to comment.