Skip to content

Commit

Permalink
Fix plot dots
Browse files (browse the repository at this point in the history)
  • Loading branch information
stinodego committed Mar 7, 2024
1 parent 44b5870 commit cee56b4
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions scripts/plot_dots.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,10 +19,10 @@
def get_styles(exclude_solutions: list[str]) -> pl.DataFrame:
all_styles = pl.DataFrame(
data=[
- ["duckdb", "DuckDB", "#fff000", "o", 5.0],
- ["pandas", "pandas", "#e70488", "s", 5.0],
- ["polars", "Polars", "#adbac7", "p", 6.0],
- ["pyspark", "PySpark", "#e25a1c", "*", 7.0],
+ ["polars", "Polars", "#0075ff", "o", 7.0],
+ ["duckdb", "DuckDB", "#73BFB8", "^", 5.5],
+ ["pandas", "pandas", "#26413C", "s", 5.0],
+ ["pyspark", "PySpark", "#EFA9AE", "D", 4.5],
],
schema=["solution", "name", "color", "shape", "size"],
)
Expand Down Expand Up @@ -108,18 +108,18 @@ def formulate_caption(
)

notes = []
- for name, group in notes_df.group_by("name"):
+ for name, group in notes_df.group_by(["name"]):
texts = group.get_column("text")
join_char = ", " if len(texts) >= 3 else " "

if len(texts) >= 2:
texts[-1] = "and " + texts[-1]

- notes.append(f"{name} {join_char.join(texts)}.")
+ notes.append(f"{name[0]} {join_char.join(texts)}.")

if notes:
caption += f"Note: {' '.join(notes)} "
- caption += "More information: https://www.pola.rs/benchmarks.html"
+ caption += "More information: https://www.pola.rs/benchmarks"
return "\n".join(textwrap.wrap(caption, int(width * 15 - 20)))


Expand Down Expand Up @@ -150,6 +150,15 @@ def create_plot(

styles = styles.join(timings, on="solution", how="semi")

+ # Cast columns to Enum to make sure the order is correct
+ timings = timings.with_columns(
+ pl.col("solution").cast(pl.Enum(styles.get_column("solution")))
+ )
+ name_versions = timings.select(
+ pl.col("name_version").sort_by("solution").unique(maintain_order=True)
+ ).to_series()
+ timings = timings.with_columns(pl.col("name_version").cast(pl.Enum(name_versions)))
+
plot = (
p9.ggplot(
timings,
Expand All @@ -168,7 +177,7 @@ def create_plot(
+ p9.scale_shape_manual(values=styles.get_column("shape"))
+ p9.scale_size_manual(values=styles.get_column("size"))
+ p9.labs(
- title="TPCH Benchmark",
+ title="TPC-H Benchmark",
subtitle=subtitle,
caption=caption,
x="duration (s)",
Expand Down

0 comments on commit cee56b4

Please sign in to comment.