Skip to content

Commit 11a7b21

Browse files
authored
feat(views): optional filtering for structured views (#78)
1 parent eac0515 commit 11a7b21

10 files changed

+261
-45
lines changed

src/dbally/iql/_exceptions.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ def __init__(self, source: str) -> None:
2020
super().__init__(message, source)
2121

2222

23-
class IQLEmptyExpressionError(IQLError):
24-
"""Raised when IQL expression is empty."""
23+
class IQLNoStatementError(IQLError):
24+
"""Raised when IQL does not have any statement."""
2525

2626
def __init__(self, source: str) -> None:
27-
message = "Empty IQL expression"
27+
message = "Empty IQL"
2828
super().__init__(message, source)
2929

3030

31-
class IQLMultipleExpressionsError(IQLError):
32-
"""Raised when IQL contains multiple expressions."""
31+
class IQLMultipleStatementsError(IQLError):
32+
"""Raised when IQL contains multiple statements."""
3333

3434
def __init__(self, nodes: List[ast.stmt], source: str) -> None:
35-
message = "Multiple expressions or statements in IQL are not supported"
35+
message = "Multiple statements in IQL are not supported"
3636
super().__init__(message, source)
3737
self.nodes = nodes
3838

src/dbally/iql/_processor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from dbally.iql._exceptions import (
77
IQLArgumentParsingError,
88
IQLArgumentValidationError,
9-
IQLEmptyExpressionError,
109
IQLFunctionNotExists,
1110
IQLIncorrectNumberArgumentsError,
12-
IQLMultipleExpressionsError,
11+
IQLMultipleStatementsError,
1312
IQLNoExpressionError,
13+
IQLNoStatementError,
1414
IQLSyntaxError,
1515
IQLUnsupportedSyntaxError,
1616
)
@@ -50,10 +50,10 @@ async def process(self) -> syntax.Node:
5050
raise IQLSyntaxError(self.source) from exc
5151

5252
if not ast_tree.body:
53-
raise IQLEmptyExpressionError(self.source)
53+
raise IQLNoStatementError(self.source)
5454

5555
if len(ast_tree.body) > 1:
56-
raise IQLMultipleExpressionsError(ast_tree.body, self.source)
56+
raise IQLMultipleStatementsError(ast_tree.body, self.source)
5757

5858
if not isinstance(ast_tree.body[0], ast.Expr):
5959
raise IQLNoExpressionError(ast_tree.body[0], self.source)

src/dbally/iql_generator/iql_generator.py

+104-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
from dbally.audit.event_tracker import EventTracker
44
from dbally.iql import IQLError, IQLQuery
5-
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
5+
from dbally.iql_generator.prompt import (
6+
FILTERING_DECISION_TEMPLATE,
7+
IQL_GENERATION_TEMPLATE,
8+
FilteringDecisionPromptFormat,
9+
IQLGenerationPromptFormat,
10+
)
611
from dbally.llms.base import LLM
712
from dbally.llms.clients.base import LLMOptions
813
from dbally.llms.clients.exceptions import LLMError
@@ -25,17 +30,110 @@ class IQLGenerator:
2530
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.
2631
"""
2732

28-
def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None:
33+
def __init__(
34+
self,
35+
llm: LLM,
36+
*,
37+
decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None,
38+
generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None,
39+
) -> None:
2940
"""
3041
Constructs a new IQLGenerator instance.
3142
3243
Args:
33-
llm: LLM used to generate IQL
44+
llm: LLM used to generate IQL.
45+
decision_prompt: Prompt template for filtering decision making.
46+
generation_prompt: Prompt template for IQL generation.
3447
"""
3548
self._llm = llm
36-
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
49+
self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE
50+
self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE
51+
52+
async def generate(
53+
self,
54+
question: str,
55+
filters: List[ExposedFunction],
56+
event_tracker: EventTracker,
57+
examples: Optional[List[FewShotExample]] = None,
58+
llm_options: Optional[LLMOptions] = None,
59+
n_retries: int = 3,
60+
) -> Optional[IQLQuery]:
61+
"""
62+
Generates IQL in text form using LLM.
63+
64+
Args:
65+
question: User question.
66+
filters: List of filters exposed by the view.
67+
event_tracker: Event store used to audit the generation process.
68+
examples: List of examples to be injected into the conversation.
69+
llm_options: Options to use for the LLM client.
70+
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.
71+
72+
Returns:
73+
Generated IQL query, or None if the question is judged not to require filtering.
74+
75+
Raises:
76+
LLMError: If LLM text generation fails after all retries.
77+
IQLError: If IQL parsing fails after all retries.
78+
UnsupportedQueryError: If the question is not supported by the view.
79+
"""
80+
decision = await self._decide_on_generation(
81+
question=question,
82+
event_tracker=event_tracker,
83+
llm_options=llm_options,
84+
n_retries=n_retries,
85+
)
86+
if not decision:
87+
return None
88+
89+
return await self._generate_iql(
90+
question=question,
91+
filters=filters,
92+
event_tracker=event_tracker,
93+
examples=examples,
94+
llm_options=llm_options,
95+
n_retries=n_retries,
96+
)
97+
98+
async def _decide_on_generation(
99+
self,
100+
question: str,
101+
event_tracker: EventTracker,
102+
llm_options: Optional[LLMOptions] = None,
103+
n_retries: int = 3,
104+
) -> bool:
105+
"""
106+
Decides whether the question requires filtering or not.
107+
108+
Args:
109+
question: User question.
110+
event_tracker: Event store used to audit the generation process.
111+
llm_options: Options to use for the LLM client.
112+
n_retries: Number of times to retry the LLM API call in case of errors.
113+
114+
Returns:
115+
Decision whether to generate IQL or not.
116+
117+
Raises:
118+
LLMError: If LLM text generation fails after all retries.
119+
"""
120+
prompt_format = FilteringDecisionPromptFormat(question=question)
121+
formatted_prompt = self._decision_prompt.format_prompt(prompt_format)
122+
123+
for retry in range(n_retries + 1):
124+
try:
125+
response = await self._llm.generate_text(
126+
prompt=formatted_prompt,
127+
event_tracker=event_tracker,
128+
options=llm_options,
129+
)
130+
# TODO: Move response parsing to llm generate_text method
131+
return formatted_prompt.response_parser(response)
132+
except LLMError as exc:
133+
if retry == n_retries:
134+
raise exc
37135

38-
async def generate_iql(
136+
async def _generate_iql(
39137
self,
40138
question: str,
41139
filters: List[ExposedFunction],
@@ -68,7 +166,7 @@ async def generate_iql(
68166
filters=filters,
69167
examples=examples,
70168
)
71-
formatted_prompt = self._prompt_template.format_prompt(prompt_format)
169+
formatted_prompt = self._generation_prompt.format_prompt(prompt_format)
72170

73171
for retry in range(n_retries + 1):
74172
try:

src/dbally/iql_generator/prompt.py

+65
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,41 @@ def _validate_iql_response(llm_response: str) -> str:
3434
return llm_response
3535

3636

37+
def _decision_iql_response_parser(response: str) -> bool:
38+
"""
39+
Parses the response from the decision prompt.
40+
41+
Args:
42+
response: Response from the LLM.
43+
44+
Returns:
45+
True if the response is positive, False otherwise.
46+
"""
47+
response = response.lower()
48+
if "decision:" not in response:
49+
return False
50+
51+
_, decision = response.split("decision:", 1)
52+
return "true" in decision
53+
54+
55+
class FilteringDecisionPromptFormat(PromptFormat):
56+
"""
57+
Filtering decision prompt format, providing a question to be used in the conversation.
58+
"""
59+
60+
def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> None:
61+
"""
62+
Constructs a new FilteringDecisionPromptFormat instance.
63+
64+
Args:
65+
question: Question to be asked.
66+
examples: List of examples to be injected into the conversation.
67+
"""
68+
super().__init__(examples)
69+
self.question = question
70+
71+
3772
class IQLGenerationPromptFormat(PromptFormat):
3873
"""
3974
IQL prompt format, providing a question and filters to be used in the conversation.
@@ -85,3 +120,33 @@ def __init__(
85120
],
86121
response_parser=_validate_iql_response,
87122
)
123+
124+
125+
FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat](
126+
[
127+
{
128+
"role": "system",
129+
"content": (
130+
"Given a question, determine whether the answer requires initial data filtering in order to compute it.\n"
131+
"Initial data filtering is a process in which the result set is reduced to only include the rows "
132+
"that meet certain criteria specified in the question.\n\n"
133+
"---\n\n"
134+
"Follow the following format.\n\n"
135+
"Question: ${{question}}\n"
136+
"Hint: ${{hint}}"
137+
"Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n"
138+
"Decision: indicates whether the answer to the question requires initial data filtering. "
139+
"(Respond with True or False)\n\n"
140+
),
141+
},
142+
{
143+
"role": "user",
144+
"content": (
145+
"Question: {question}\n"
146+
"Hint: Look for words indicating data specific features.\n"
147+
"Reasoning: Let's think step by step in order to "
148+
),
149+
},
150+
],
151+
response_parser=_decision_iql_response_parser,
152+
)

src/dbally/views/pandas_base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
from functools import reduce
3+
from typing import Optional
34

45
import pandas as pd
56

@@ -25,7 +26,7 @@ def __init__(self, df: pd.DataFrame) -> None:
2526
self.df = df
2627

2728
# The mask to be applied to the dataframe to filter the data
28-
self._filter_mask: pd.Series = None
29+
self._filter_mask: Optional[pd.Series] = None
2930

3031
async def apply_filters(self, filters: IQLQuery) -> None:
3132
"""

src/dbally/views/sqlalchemy_base.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import asyncio
3+
from typing import Optional
34

45
import sqlalchemy
56

@@ -13,10 +14,11 @@ class SqlAlchemyBaseView(MethodsBaseView):
1314
Base class for views that use SQLAlchemy to generate SQL queries.
1415
"""
1516

16-
def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None:
17+
def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None:
1718
super().__init__()
18-
self._select = self.get_select()
1919
self._sqlalchemy_engine = sqlalchemy_engine
20+
self._select = self.get_select()
21+
self._where_clause: Optional[sqlalchemy.ColumnElement] = None
2022

2123
@abc.abstractmethod
2224
def get_select(self) -> sqlalchemy.Select:
@@ -34,7 +36,7 @@ async def apply_filters(self, filters: IQLQuery) -> None:
3436
Args:
3537
filters: IQLQuery object representing the filters to apply
3638
"""
37-
self._select = self._select.where(await self._build_filter_node(filters.root))
39+
self._where_clause = await self._build_filter_node(filters.root)
3840

3941
async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement:
4042
"""
@@ -75,8 +77,11 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
7577
Results of the query where `results` will be a list of dictionaries representing retrieved rows or an empty\
7678
list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored.
7779
"""
78-
7980
results = []
81+
82+
if self._where_clause is not None:
83+
self._select = self._select.where(self._where_clause)
84+
8085
sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))
8186

8287
if not dry_run:

src/dbally/views/structured.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ async def ask(
6969
examples = self.list_few_shots()
7070

7171
try:
72-
iql = await iql_generator.generate_iql(
72+
iql = await iql_generator.generate(
7373
question=query,
7474
filters=filters,
7575
examples=examples,
@@ -90,10 +90,11 @@ async def ask(
9090
aggregation=None,
9191
) from exc
9292

93-
await self.apply_filters(iql)
93+
if iql:
94+
await self.apply_filters(iql)
9495

9596
result = self.execute(dry_run=dry_run)
96-
result.context["iql"] = f"{iql}"
97+
result.context["iql"] = str(iql) if iql else None
9798

9899
return result
99100

tests/unit/iql/test_iql_parser.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax
77
from dbally.iql._exceptions import (
88
IQLArgumentValidationError,
9-
IQLEmptyExpressionError,
109
IQLFunctionNotExists,
1110
IQLIncorrectNumberArgumentsError,
12-
IQLMultipleExpressionsError,
11+
IQLMultipleStatementsError,
1312
IQLNoExpressionError,
13+
IQLNoStatementError,
1414
IQLSyntaxError,
1515
)
1616
from dbally.iql._processor import IQLProcessor
@@ -95,7 +95,7 @@ async def test_iql_parser_syntax_error():
9595

9696

9797
async def test_iql_parser_multiple_expression_error():
98-
with pytest.raises(IQLMultipleExpressionsError) as exc_info:
98+
with pytest.raises(IQLMultipleStatementsError) as exc_info:
9999
await IQLQuery.parse(
100100
"filter_by_age\nfilter_by_age",
101101
allowed_functions=[
@@ -109,11 +109,11 @@ async def test_iql_parser_multiple_expression_error():
109109
],
110110
)
111111

112-
assert exc_info.match(re.escape("Multiple expressions or statements in IQL are not supported"))
112+
assert exc_info.match(re.escape("Multiple statements in IQL are not supported"))
113113

114114

115115
async def test_iql_parser_empty_expression_error():
116-
with pytest.raises(IQLEmptyExpressionError) as exc_info:
116+
with pytest.raises(IQLNoStatementError) as exc_info:
117117
await IQLQuery.parse(
118118
"",
119119
allowed_functions=[
@@ -127,7 +127,7 @@ async def test_iql_parser_empty_expression_error():
127127
],
128128
)
129129

130-
assert exc_info.match(re.escape("Empty IQL expression"))
130+
assert exc_info.match(re.escape("Empty IQL"))
131131

132132

133133
async def test_iql_parser_no_expression_error():

0 commit comments

Comments
 (0)