Skip to content
This repository has been archived by the owner on Mar 11, 2024. It is now read-only.

Commit

Permalink
Add legend functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
a-maliarov committed Apr 26, 2022
1 parent 91685fd commit d25e511
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 1 deletion.
50 changes: 49 additions & 1 deletion simpleplots/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .base import Theme, Axes, Size
from .utils import (get_indices_of_values_in_list, smartrange, normalize_values,
get_font, choose_locator, choose_formatter)
get_font, choose_locator, choose_formatter, get_text_dimensions)
from .visuals import Spines, PointsGrid, CustomImageDraw
from .themes import StandardTheme
from .ticker import Locator, Formatter
Expand Down Expand Up @@ -316,4 +316,52 @@ def title(self, text: str) -> None:
self.draw.text(xy=coords, text=text, font=title_font, anchor="mm",
fill=self.theme.title_color)

def legend(self, spacing: int = 4) -> None:
legend_font = get_font('legend', self.theme, self.width)
section = self.grid.get_legend_bbox(self.axes, legend_font)

labels = '\n'.join(['bbbb' + ax.label for ax in self.axes])
bbox = self.draw.multiline_textbbox(
xy=section['point'],
text=labels,
font=legend_font,
anchor=section['anchor'],
spacing=spacing
)

letter_size = get_text_dimensions('b', legend_font)
new_bbox = list(bbox)
new_bbox[0] -= letter_size[0] / 2
new_bbox[1] -= letter_size[1] / 2
new_bbox[2] += letter_size[0] / 2
new_bbox[3] += letter_size[1]

self.draw.rounded_rectangle(
xy=new_bbox,
radius=15,
fill=self.theme.figure_background_color,
outline=self.theme.grid_line_color,
width=self.theme.grid_line_width // 2
)

for i, axes in enumerate(self.axes):
line_coords = (
bbox[0],
bbox[1] + letter_size[1] / 2 + letter_size[1] * i + spacing * i,
bbox[0] + letter_size[0] * 3,
bbox[1] + letter_size[1] / 2 + letter_size[1] * i + spacing * i
)
self.draw.line(line_coords, width=axes.linewidth, fill=axes.color)

text_coords = (
bbox[0] + letter_size[0] * 4,
bbox[1] + letter_size[1] * i + spacing * i
)
self.draw.text(
xy=text_coords,
text=axes.label,
font=legend_font,
fill=self.theme.legend_color
)

#-------------------------------------------------------------------------------
4 changes: 4 additions & 0 deletions simpleplots/themes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,8 @@ class StandardTheme(Theme):
title_size_perc = 0.033
title_color = (0, 0, 0)

legend_font = 'arial.ttf'
legend_size_perc = 0.02
legend_color = (0, 0, 0)

#-------------------------------------------------------------------------------
6 changes: 6 additions & 0 deletions simpleplots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def get_font(type_: str, theme: Theme, image_width: int) -> ImageFont:
int(image_width * theme.title_size_perc)
)

elif type_ == 'legend':
return ImageFont.truetype(
os.path.join(fonts_folder, theme.legend_font),
int(image_width * theme.legend_size_perc)
)

def get_text_dimensions(text_string: str, font: ImageFont) -> Size:
"""Calculates size of a given text string using given font."""
ascent, descent = font.getmetrics()
Expand Down
62 changes: 62 additions & 0 deletions simpleplots/visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@

#-------------------------------------------------------------------------------

def _point_in_bbox(point: Point, bbox: Coords) -> bool:
if (bbox.x0 <= point.x and point.x <= bbox.x1 and
bbox.y0 <= point.y and point.y <= bbox.y1):

return True

else:
return False

#-------------------------------------------------------------------------------

class CustomImageDraw(ImageDraw.ImageDraw):

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -237,4 +248,55 @@ def get_axes_points_coords(self, axes: Axes) -> List[Point]:
xy_indices = np.dstack(np.asarray([px, py]))[0]
return np.asarray([self.get_point_coords(x, y) for x, y in xy_indices])

def get_legend_bbox(self, axes: List[Axes], font: ImageFont) -> dict:
"""Returns coordinates of legend mask."""
priority = {
'00': 0, '10': 6, '20': 1,
'01': 4, '11': 8, '21': 5,
'02': 2, '12': 7, '22': 3
}

sections = [
{'anchor': 'la'}, {'anchor': 'ra'}, {'anchor': 'ld'},
{'anchor': 'rd'}, {'anchor': 'lm'}, {'anchor': 'rm'},
{'anchor': 'ma'}, {'anchor': 'md'}, {'anchor': 'mm'},
]

for x in range(3):
for y in range(3):
w = (self.spines.width - self.horizontal_offset * 1.5)
h = (self.spines.height - self.vertical_offset * 1.5)

ow = self.spines.horizontal_offset + self.horizontal_offset * 0.75
oh = self.spines.vertical_offset + self.vertical_offset * 0.75

coords = Coords(
x0 = ow + w / 3 * x,
y0 = oh + h / 3 * y,
x1 = ow + w / 3 * (x + 1),
y1 = oh + h / 3 * (y + 1)
)

point = Point(
x = ow + w / 2 * x,
y = oh + h / 2 * y
)

sections[priority[f'{x}{y}']]['bbox'] = coords
sections[priority[f'{x}{y}']]['point'] = point

pts = [self.get_axes_points_coords(ax) for ax in axes]
points = np.concatenate(pts)

for p in points:
point = Point(p[0], p[1])
for sk, sv in enumerate(sections):
if 'hits' not in sections[sk]:
sections[sk]['hits'] = 0
if _point_in_bbox(point, sv['bbox']):
sections[sk]['hits'] += 1

section = sorted(sections, key=lambda d: d['hits'])[0]
return section

#-------------------------------------------------------------------------------

0 comments on commit d25e511

Please sign in to comment.