feat(benchmarks): aggregation benchmarks #84

Merged
merged 29 commits on Sep 25, 2024

Commits (showing changes from all commits)
6d66b00
Adding aggregation handling for SqlAlchemyBaseView extending quicksta…
PatrykWyzgowski Jul 12, 2024
be12e6e
Applying initial review feedback. Adding both filters and aggregation…
PatrykWyzgowski Jul 15, 2024
33d5b2e
Renaming subquery attribute and method argument to filtered_query
PatrykWyzgowski Jul 16, 2024
2e77fbc
Simplified question to the model.
PatrykWyzgowski Jul 17, 2024
41e88ed
Fixing unnecessary-pass.
PatrykWyzgowski Jul 17, 2024
dfe3e13
Continuation of review feedback application.
PatrykWyzgowski Jul 17, 2024
c09e68e
Adjusting filter prompt not to mix IQL with 'UNSUPPORTED QUERY'. Furt…
PatrykWyzgowski Jul 17, 2024
c6bbf90
Applied changes suggested in a comment related to Aggregations not ge…
PatrykWyzgowski Jul 18, 2024
4765dde
Applying pre-commit hooks.
PatrykWyzgowski Jul 18, 2024
2918ba5
Mocking AggregationFormat in tests.
PatrykWyzgowski Jul 18, 2024
0b7e50a
Merge branch 'main' into pw/add-single-aggregation
PatrykWyzgowski Jul 19, 2024
5511729
Mocking methods of the view related to aggregations to make them comp…
PatrykWyzgowski Jul 19, 2024
ae26c8b
Pre-commit fixes.
PatrykWyzgowski Jul 19, 2024
63c3adc
merge main
micpst Aug 12, 2024
a2169f2
revert to prev approach
micpst Aug 16, 2024
013cb69
fix tests
micpst Aug 16, 2024
f0a2f6e
add more tests
micpst Aug 16, 2024
aeb6295
trying to fix tests (localy working)
micpst Aug 16, 2024
d21f4e1
fix tests for python 3.8
micpst Aug 17, 2024
e473cc4
update views
micpst Aug 19, 2024
9e37c82
fix bench
micpst Aug 30, 2024
db368ca
merge main
micpst Sep 2, 2024
d3cbc37
fixes
micpst Sep 2, 2024
21b7502
fix metrics and ag view
micpst Sep 2, 2024
e7646ad
improve metrics
micpst Sep 2, 2024
3beff6d
small fix
micpst Sep 2, 2024
e69fc07
fix sql ex metric
micpst Sep 2, 2024
f351b29
Merge branch 'main' into mp/aggregation-benchmarks
micpst Sep 23, 2024
0193e4b
Merge branch 'main' into mp/aggregation-benchmarks
micpst Sep 24, 2024
9 changes: 9 additions & 0 deletions benchmarks/sql/bench.py
@@ -8,10 +8,13 @@
from bench.evaluator import Evaluator
from bench.loaders import CollectionDataLoader, IQLViewDataLoader, SQLViewDataLoader
from bench.metrics import (
AggregationAccuracy,
ExecutionAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationCorrectness,
IQLAggregationParseability,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
@@ -57,9 +60,12 @@ class EvaluationType(Enum):

EVALUATION_METRICS = {
EvaluationType.IQL.value: MetricSet(
AggregationAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationParseability,
IQLAggregationCorrectness,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
@@ -72,9 +78,12 @@ class EvaluationType(Enum):
ExecutionAccuracy,
),
EvaluationType.E2E.value: MetricSet(
AggregationAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationParseability,
IQLAggregationCorrectness,
IQLFiltersAccuracy,
IQLFiltersPrecision,
IQLFiltersRecall,
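
Note: the registration above wires the new aggregation metrics into both the IQL and E2E metric sets, next to the existing filtering metrics. For quick reference, a short listing of the score keys they contribute (key names taken from the metric classes later in this diff; collecting them into a list is purely illustrative):

# Score keys added by this PR's aggregation metrics (illustrative listing only).
NEW_AGGREGATION_METRIC_KEYS = [
    "DM/AGG/ACC",            # AggregationAccuracy: was the aggregate/don't-aggregate decision correct?
    "IQL/AGG/PARSEABILITY",  # IQLAggregationParseability: does the generated aggregation IQL parse?
    "IQL/AGG/CORRECTNESS",   # IQLAggregationCorrectness: does parseable aggregation IQL match the reference?
]
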
6 changes: 6 additions & 0 deletions benchmarks/sql/bench/metrics/__init__.py
@@ -1,8 +1,11 @@
from .base import Metric, MetricSet
from .iql import (
AggregationAccuracy,
FilteringAccuracy,
FilteringPrecision,
FilteringRecall,
IQLAggregationCorrectness,
IQLAggregationParseability,
IQLFiltersAccuracy,
IQLFiltersCorrectness,
IQLFiltersParseability,
@@ -15,14 +18,17 @@
__all__ = [
"Metric",
"MetricSet",
"AggregationAccuracy",
"FilteringAccuracy",
"FilteringPrecision",
"FilteringRecall",
"IQLAggregationParseability",
"IQLFiltersAccuracy",
"IQLFiltersPrecision",
"IQLFiltersRecall",
"IQLFiltersParseability",
"IQLFiltersCorrectness",
"IQLAggregationCorrectness",
"SQLExactMatch",
"ViewSelectionAccuracy",
"ViewSelectionPrecision",
128 changes: 106 additions & 22 deletions benchmarks/sql/bench/metrics/iql.py
@@ -1,30 +1,49 @@
from abc import ABC
from typing import Any, Dict, List

from ..pipelines import EvaluationResult
from .base import Metric


class FilteringAccuracy(Metric):
class AssessingAccuracy(Metric, ABC):
"""
Filtering accuracy is proportion of correct decisions (to filter or not) out of all decisions made.
Assessing accuracy is proportion of correct decisions out of all decisions made.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the filtering accuracy.
Computes the assessing accuracy.

Args:
results: List of evaluation results.

Returns:
Filtering accuracy.
Assessing accuracy.
"""
results = [result for result in results if result.reference.iql and result.prediction.iql]
results = [
result
for result in results
if result.reference.iql
and result.prediction.iql
and result.reference.view_name
and result.prediction.view_name
and getattr(result.reference.iql, self.iql).generated
and getattr(result.prediction.iql, self.iql).generated
]
return {
"DM/FLT/ACC": (
f"DM/{self.prefix}/ACC": (
sum(
isinstance(result.prediction.iql.filters.source, type(result.reference.iql.filters.source))
and result.prediction.iql.filters.unsupported == result.reference.iql.filters.unsupported
(
getattr(result.reference.iql, self.iql).source is not None
or getattr(result.reference.iql, self.iql).unsupported
)
== (
getattr(result.prediction.iql, self.iql).source is not None
or getattr(result.prediction.iql, self.iql).unsupported
)
for result in results
)
/ len(results)
@@ -34,6 +53,24 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
}


class FilteringAccuracy(AssessingAccuracy):
"""
Filtering accuracy is proportion of correct decisions (to filter or not) out of all decisions made.
"""

prefix: str = "FLT"
iql: str = "filters"


class AggregationAccuracy(AssessingAccuracy):
"""
Aggregation accuracy is proportion of correct decisions (to aggregate or not) out of all decisions made.
"""

prefix: str = "AGG"
iql: str = "aggregation"


class FilteringPrecision(Metric):
"""
Filtering precision is proportion of correct decisions to filter out of all decisions to filter.
@@ -222,11 +259,14 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
}


class IQLFiltersParseability(Metric):
class IQLParseability(Metric, ABC):
"""
IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters parseability.
@@ -241,46 +281,90 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (result.reference.iql.filters and result.prediction.iql.filters)
and (result.reference.iql.filters.source and result.prediction.iql.filters.source)
and (getattr(result.reference.iql, self.iql) and getattr(result.prediction.iql, self.iql))
and (getattr(result.reference.iql, self.iql).source and getattr(result.prediction.iql, self.iql).source)
]
return {
"IQL/FLT/PARSEABILITY": (
sum(result.prediction.iql.filters.valid for result in results) / len(results) if results else None
f"IQL/{self.prefix}/PARSEABILITY": (
sum(getattr(result.prediction.iql, self.iql).valid for result in results) / len(results)
if results
else None
)
}


class IQLFiltersCorrectness(Metric):
class IQLFiltersParseability(IQLParseability):
"""
IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
IQL filters parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str = "FLT"
iql: str = "filters"


class IQLAggregationParseability(IQLParseability):
"""
IQL aggregation parseability is proportion of syntactically correct (parseable) IQLs out of all generated IQLs.
"""

prefix: str = "AGG"
iql: str = "aggregation"


class IQLCorrectness(Metric, ABC):
"""
IQL correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str
iql: str

def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
"""
Computes the IQL filters correctness.
Computes the IQL correctness.

Args:
results: List of evaluation results.

Returns:
IQL filters correctness.
IQL correctness.
"""
results = [
result
for result in results
if (result.reference.iql and result.prediction.iql)
and (
result.reference.iql.filters.source
and result.prediction.iql.filters.source
and result.prediction.iql.filters.valid
getattr(result.reference.iql, self.iql).source
and getattr(result.prediction.iql, self.iql).source
and getattr(result.prediction.iql, self.iql).valid
)
]
return {
"IQL/FLT/CORRECTNESS": (
sum(result.prediction.iql.filters.source == result.reference.iql.filters.source for result in results)
f"IQL/{self.prefix}/CORRECTNESS": (
sum(
getattr(result.prediction.iql, self.iql).source == getattr(result.reference.iql, self.iql).source
for result in results
)
/ len(results)
if results
else None
)
}


class IQLFiltersCorrectness(IQLCorrectness):
"""
IQL filters correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str = "FLT"
iql: str = "filters"


class IQLAggregationCorrectness(IQLCorrectness):
"""
IQL aggregation correctness is proportion of IQLs that produce correct results out of all parseable IQLs.
"""

prefix: str = "AGG"
iql: str = "aggregation"
9 changes: 6 additions & 3 deletions benchmarks/sql/bench/metrics/sql.py
@@ -25,6 +25,7 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
The exact match ratio.
"""
results = [result for result in results if result.reference.sql and result.prediction.sql]
return {
"SQL/EM": (
sum(result.prediction.sql == result.reference.sql for result in results) / len(results)
@@ -95,6 +96,7 @@ def compute(self, results: List[EvaluationResult]) -> Dict[str, Any]:
Returns:
Execution accuracy score and valid efficiency score.
"""
results = [result for result in results if result.reference.sql and result.prediction.sql]
accurate_results = [result for result in results if self._execution_accuracy(result)]
return {
"EX": len(accurate_results) / len(results) if results else None,
@@ -121,9 +123,6 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
Returns:
True if the execution results are identical, False otherwise.
"""
if result.prediction.sql is None:
return False

try:
ref_results = self._execute_query(result.reference.sql, result.db_id)
pred_results = self._execute_query(result.prediction.sql, result.db_id)
@@ -138,6 +137,10 @@ def _execution_accuracy(self, result: EvaluationResult) -> bool:
if reference.shape[0] != prediction.shape[0]:
return False

# If both dataframes have only one column, compare the values directly
if reference.shape[1] == prediction.shape[1] == 1:
return reference.iloc[:, 0].equals(prediction.iloc[:, 0])

# Returned view may have the same columns, or more columns than the ground truth
if not reference.columns.isin(prediction.columns).all():
return False
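
The new single-column shortcut in _execution_accuracy compares values positionally, so the reference and predicted result sets can match even when their only columns are named differently. A runnable pandas illustration of that check (toy data, not taken from the benchmark):

import pandas as pd

reference = pd.DataFrame({"avg_price": [12.5, 40.0]})
prediction = pd.DataFrame({"AVG(price)": [12.5, 40.0]})  # same values, different column label

single_column = reference.shape[1] == prediction.shape[1] == 1
# Series.equals compares values (and dtypes) but ignores the column name,
# so differently named single-column results still count as a match.
match = single_column and reference.iloc[:, 0].equals(prediction.iloc[:, 0])
print(match)  # True
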