Skip to content

Commit

Permalink
Adding test case for previous commit (5ffa26d)
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Feb 22, 2024
1 parent 5ffa26d commit fc4aae7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
29 changes: 28 additions & 1 deletion tests/test_single_table_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
get_length,
select_first_sorted,
get_table_size,
select_first_option,
)


Expand All @@ -18,7 +19,13 @@ def db() -> SQLiteDBConnector:

@pytest.fixture
def ingredients() -> set:
return {starts_with, get_length, select_first_sorted, get_table_size}
return {
starts_with,
get_length,
select_first_sorted,
get_table_size,
select_first_option,
}


def test_simple_exec(db, ingredients):
Expand Down Expand Up @@ -457,5 +464,25 @@ def test_exists_isolated_qa_call(db, ingredients):
assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item()


def test_query_options_arg(db, ingredients):
# commit 5ffa26d
blendsql = """
{{
select_first_option(
'I hope this test works',
(SELECT * FROM transactions),
options=(SELECT DISTINCT merchant FROM transactions WHERE merchant = 'Paypal')
)
}}
"""
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
)
assert len(smoothie.df) == 1
assert smoothie.df.values.flat[0] == "Paypal"


if __name__ == "__main__":
pytest.main()
14 changes: 12 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
from typing import Iterable, Any, List, Union
from blendsql.ingredients import MapIngredient, QAIngredient, JoinIngredient
from blendsql.db.utils import single_quote_escape


class starts_with(MapIngredient):
Expand All @@ -24,7 +25,7 @@ def run(self, question: str, values: List[str], **kwargs) -> Iterable[int]:


class select_first_sorted(QAIngredient):
def run(self, question: str, options: List[str], **kwargs) -> Iterable[Any]:
def run(self, question: str, options: set, **kwargs) -> Iterable[Any]:
"""Simple test function, equivalent to the following in SQL:
`ORDER BY {colname} LIMIT 1`
"""
Expand All @@ -42,12 +43,21 @@ def run(

class get_table_size(QAIngredient):
def run(
self, question: str, context: pd.DataFrame, options: str = None, **kwargs
self, question: str, context: pd.DataFrame, options: set = None, **kwargs
) -> Union[str, int, float]:
"""Returns the length of the context subtable passed to it."""
return len(context)


class select_first_option(QAIngredient):
def run(
self, question: str, context: pd.DataFrame, options: set = None, **kwargs
) -> Union[str, int, float]:
"""Returns the first item in the (ordered) options set"""
assert options is not None
return f"'{single_quote_escape(sorted(list(options))[0])}'"


class do_join(JoinIngredient):
"""A very silly, overcomplicated way to do a traditional SQL join.
But useful for testing.
Expand Down

0 comments on commit fc4aae7

Please sign in to comment.