Skip to content

Commit

Permalink
Plot fixes (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Apr 16, 2024
1 parent 1cd26ee commit a791f33
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 139 deletions.
2 changes: 1 addition & 1 deletion queries/polars/q17.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def q() -> None:
part_ds.filter(pl.col("p_brand") == var_1)
.filter(pl.col("p_container") == var_2)
.join(line_item_ds, how="left", left_on="p_partkey", right_on="l_partkey")
).cache()
)

q_final = (
res_1.group_by("p_partkey")
Expand Down
272 changes: 138 additions & 134 deletions scripts/plot_bars.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
To use this script, run:
```shell
.venv/bin/python ./scripts/plot_results.py
.venv/bin/python -m scripts.plot_bars
```
"""

Expand All @@ -19,35 +19,143 @@
if TYPE_CHECKING:
from plotly.graph_objects import Figure

from settings import FileType

settings = Settings()

if settings.run.include_io:
LIMIT = settings.plot.limit_with_io
else:
LIMIT = settings.plot.limit_without_io

# colors for each bar

COLORS = {
"polars": "#f7c5a0",
"duckdb": "#fff000",
"pandas": "#72ccff",
"dask": "#efa9ae",
"pyspark": "#87f7cf",
"polars": "#0075FF",
"polars-eager": "#00B4D8",
"duckdb": "#80B9C8",
"pyspark": "#C29470",
"dask": "#77D487",
"pandas": "#2B8C5D",
"modin": "#50B05F",
}

# default base template for plot's theme
DEFAULT_THEME = "plotly_dark"

# other configuration
BAR_TYPE = "group"
LABEL_UPDATES = {
"x": "query",
"y": "seconds",
"color": "Solution",
"pattern_shape": "Solution",
SOLUTION_NAME_MAP = {
"polars": "Polars",
"polars-eager": "Polars - eager",
"duckdb": "DuckDB",
"pandas": "pandas",
"dask": "Dask",
"modin": "Modin",
"pyspark": "PySpark",
}


def main() -> None:
pl.Config.set_tbl_rows(100)
df = prep_data()
plot(df)


def prep_data() -> pl.DataFrame:
lf = pl.scan_csv(settings.paths.timings / settings.paths.timings_filename)

# Scale factor not used at the moment
lf = lf.drop("scale_factor")

# Select timings either with or without IO
if settings.run.include_io:
io = pl.col("include_io")
else:
io = ~pl.col("include_io")
lf = lf.filter(io).drop("include_io")

# Select relevant queries
lf = lf.filter(pl.col("query_number") <= settings.plot.n_queries)

# Get the last timing entry per solution/version/query combination
lf = lf.group_by("solution", "version", "query_number").last()

# Insert missing query entries
groups = lf.select("solution", "version").unique()
queries = pl.LazyFrame({"query_number": range(1, settings.plot.n_queries + 1)})
groups_queries = groups.join(queries, how="cross")
lf = groups_queries.join(lf, on=["solution", "version", "query_number"], how="left")
lf = lf.with_columns(pl.col("duration[s]").fill_null(0))

# Order the groups
solutions_in_data = lf.select("solution").collect().to_series().unique()
solution = pl.LazyFrame({"solution": [s for s in COLORS if s in solutions_in_data]})
lf = solution.join(lf, on=["solution"], how="left")

# Make query number a string
lf = lf.with_columns(pl.format("Q{}", "query_number").alias("query")).drop(
"query_number"
)

return lf.select("solution", "version", "query", "duration[s]").collect()


def plot(df: pl.DataFrame) -> Figure:
"""Generate a Plotly Figure of a grouped bar chart displaying benchmark results."""
x = df.get_column("query")
y = df.get_column("duration[s]")

group = df.select(
pl.format("{} ({})", pl.col("solution").replace(SOLUTION_NAME_MAP), "version")
).to_series()

# build plotly figure object
color_seq = [c for (s, c) in COLORS.items() if s in df["solution"].unique()]

fig = px.histogram(
x=x,
y=y,
color=group,
barmode="group",
template="plotly_white",
color_discrete_sequence=color_seq,
title=get_title(settings.run.include_io, settings.run.file_type),
)

fig.update_layout(
bargroupgap=0.1,
# paper_bgcolor="rgba(41,52,65,1)",
xaxis_title="Query",
yaxis_title="Seconds",
yaxis_range=[0, LIMIT],
# plot_bgcolor="rgba(41,52,65,1)",
margin={"t": 100},
legend={
"title": "",
"orientation": "h",
"xanchor": "center",
"yanchor": "top",
"x": 0.5,
},
)

add_annotations(fig, LIMIT, df)

write_plot_image(fig)

# display the object using available environment context
if settings.plot.show:
fig.show()


def get_title(include_io: bool, file_type: FileType) -> str:
if not include_io:
title = "Runtime excluding data read from disk"
else:
file_type_map = {"parquet": "Parquet", "csv": "CSV", "feather": "Feather"}
file_type_formatted = file_type_map[file_type]
title = f"Runtime including data read from disk ({file_type_formatted})"

subtitle = "(lower is better)"

return f"{title}<br><i>{subtitle}<i>"


def add_annotations(fig: Any, limit: int, df: pl.DataFrame) -> None:
# order of solutions in the file
# e.g. ['polar', 'pandas']
Expand All @@ -58,39 +166,30 @@ def add_annotations(fig: Any, limit: int, df: pl.DataFrame) -> None:
.with_row_index()
)

# every bar in the plot has a different offset for the text
start_offset = 10
offsets = [start_offset + 12 * i for i in range(bar_order.height)]

# we look for the solutions that surpassed the limit
# and create a text label for them
df = (
df.filter(pl.col("duration[s]") > limit)
.with_columns(
pl.when(pl.col("success"))
.then(
pl.format(
"{} took {} s", "solution", pl.col("duration[s]").cast(pl.Int32)
).alias("labels")
)
.otherwise(pl.format("{} had an internal error", "solution"))
pl.format(
"{} took {}s", "solution", pl.col("duration[s]").cast(pl.Int32)
).alias("labels")
)
.join(bar_order, on="solution")
.group_by("query_no")
.group_by("query")
.agg(pl.col("labels"), pl.col("index").min())
.with_columns(pl.col("labels").list.join(",\n"))
)

# then we create a dictionary similar to something like this:
# anno_data = {
# "q1": (offset, "label"),
# "q3": (offset, "label"),
# "q1": "label",
# "q3": "label",
# }

if df.height > 0:
anno_data = {
v[0]: (offsets[int(v[1])], v[2])
for v in df.select(["query_no", "index", "labels"])
v[0]: v[1]
for v in df.select("query", "labels")
.transpose()
.to_dict(as_series=False)
.values()
Expand All @@ -99,14 +198,14 @@ def add_annotations(fig: Any, limit: int, df: pl.DataFrame) -> None:
# a dummy with no text
anno_data = {"q1": (0, "")}

for q_name, (x_shift, anno_text) in anno_data.items():
for q_name, anno_text in anno_data.items():
fig.add_annotation(
align="right",
x=q_name,
y=LIMIT,
xshift=x_shift,
xshift=0,
yshift=30,
font={"color": "white"},
# font={"color": "white"},
showarrow=False,
text=anno_text,
)
Expand All @@ -122,104 +221,9 @@ def write_plot_image(fig: Any) -> None:
else:
file_name = "plot_without_io.html"

fig.write_html(path / file_name)


def plot(
df: pl.DataFrame,
x: str = "query_no",
y: str = "duration[s]",
group: str = "solution",
limit: int = 120,
) -> Figure:
"""Generate a Plotly Figure of a grouped bar chart displaying benchmark results.
Parameters
----------
df
DataFrame containing `x`, `y`, and `group`.
x
Column for X Axis. Defaults to "query_no".
y
Column for Y Axis. Defaults to "duration[s]".
group
Column for group. Defaults to "solution".
limit
height limit in seconds
Returns
-------
px.Figure: Plotly Figure (histogram)
"""
# build plotly figure object
fig = px.histogram(
x=df[x],
y=df[y],
color=df[group],
barmode=BAR_TYPE,
template=DEFAULT_THEME,
color_discrete_map=COLORS,
pattern_shape=df[group],
labels=LABEL_UPDATES,
)

fig.update_layout(
bargroupgap=0.1,
paper_bgcolor="rgba(41,52,65,1)",
yaxis_range=[0, limit],
plot_bgcolor="rgba(41,52,65,1)",
margin={"t": 100},
legend={
"orientation": "h",
"xanchor": "left",
"yanchor": "top",
"x": 0.37,
"y": -0.1,
},
)

add_annotations(fig, limit, df)
print(path / file_name)

write_plot_image(fig)

# display the object using available environment context
if settings.plot.show:
fig.show()


def main() -> None:
e = pl.lit(True)

if settings.run.include_io:
e = e & pl.col("include_io")
else:
e = e & ~pl.col("include_io")

df = (
pl.scan_csv(settings.paths.timings)
.filter(e)
# filter the max query to plot
.filter(
pl.col("query_no").str.extract(r"q(\d+)", 1).cast(int)
<= settings.plot.n_queries
)
# create a version no
.with_columns(
pl.when(pl.col("success")).then(pl.col("duration[s]")).otherwise(0),
pl.format("{}-{}", "solution", "version").alias("solution-version"),
)
# ensure we get the latest version
.sort("solution", "version")
.group_by("solution", "query_no", maintain_order=True)
.last()
.collect()
)
order = pl.DataFrame(
{"solution": ["polars", "duckdb", "pandas", "dask", "pyspark"]}
)
df = order.join(df, on="solution", how="left")

plot(df, limit=LIMIT, group="solution-version")
fig.write_html(path / file_name)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion scripts/plot_dots.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def parse_queries(s: str) -> list[str]:

def read_csv(filename: str) -> pl.DataFrame:
if filename == "-":
df = pl.read_csv(settings.paths.timings)
df = pl.read_csv(settings.paths.timings / settings.paths.timings_filename)
else:
df = pl.read_csv(filename)
return df
Expand Down
6 changes: 3 additions & 3 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Paths(BaseSettings):


class Run(BaseSettings):
include_io: bool = False
include_io: bool = True
file_type: FileType = "parquet"

log_timings: bool = False
Expand All @@ -46,8 +46,8 @@ class Run(BaseSettings):

class Plot(BaseSettings):
show: bool = False
n_queries: int = 8
limit_with_io: int = 15
n_queries: int = 7
limit_with_io: int = 20
limit_without_io: int = 15

model_config = SettingsConfigDict(
Expand Down

0 comments on commit a791f33

Please sign in to comment.