2
2
3
3
from dbally .audit .event_tracker import EventTracker
4
4
from dbally .iql import IQLError , IQLQuery
5
- from dbally .iql_generator .prompt import IQL_GENERATION_TEMPLATE , IQLGenerationPromptFormat
5
+ from dbally .iql_generator .prompt import (
6
+ FILTERING_DECISION_TEMPLATE ,
7
+ IQL_GENERATION_TEMPLATE ,
8
+ FilteringDecisionPromptFormat ,
9
+ IQLGenerationPromptFormat ,
10
+ )
6
11
from dbally .llms .base import LLM
7
12
from dbally .llms .clients .base import LLMOptions
8
13
from dbally .llms .clients .exceptions import LLMError
@@ -25,17 +30,110 @@ class IQLGenerator:
25
30
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.
26
31
"""
27
32
28
def __init__(
    self,
    llm: LLM,
    *,
    decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None,
    generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None,
) -> None:
    """
    Creates a new IQLGenerator.

    Args:
        llm: LLM used to generate IQL.
        decision_prompt: Prompt template for filtering decision making;
            defaults to FILTERING_DECISION_TEMPLATE when omitted.
        generation_prompt: Prompt template for IQL generation;
            defaults to IQL_GENERATION_TEMPLATE when omitted.
    """
    # Fall back to the package-default templates when none are supplied.
    self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE
    self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE
    self._llm = llm
51
+
52
async def generate(
    self,
    question: str,
    filters: List[ExposedFunction],
    event_tracker: EventTracker,
    examples: Optional[List[FewShotExample]] = None,
    llm_options: Optional[LLMOptions] = None,
    n_retries: int = 3,
) -> Optional[IQLQuery]:
    """
    Generates IQL in text form using LLM.

    First asks the LLM whether the question requires filtering at all;
    only when the answer is positive is the IQL query generated.

    Args:
        question: User question.
        filters: List of filters exposed by the view.
        event_tracker: Event store used to audit the generation process.
        examples: List of examples to be injected into the conversation.
        llm_options: Options to use for the LLM client.
        n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.

    Returns:
        Generated IQL query or None if the decision is not to continue.

    Raises:
        LLMError: If LLM text generation fails after all retries.
        IQLError: If IQL parsing fails after all retries.
        UnsupportedQueryError: If the question is not supported by the view.
    """
    should_generate = await self._decide_on_generation(
        question=question,
        event_tracker=event_tracker,
        llm_options=llm_options,
        n_retries=n_retries,
    )
    if should_generate:
        return await self._generate_iql(
            question=question,
            filters=filters,
            event_tracker=event_tracker,
            examples=examples,
            llm_options=llm_options,
            n_retries=n_retries,
        )
    return None
97
+
98
async def _decide_on_generation(
    self,
    question: str,
    event_tracker: EventTracker,
    llm_options: Optional[LLMOptions] = None,
    n_retries: int = 3,
) -> bool:
    """
    Decides whether the question requires filtering or not.

    Args:
        question: User question.
        event_tracker: Event store used to audit the generation process.
        llm_options: Options to use for the LLM client.
        n_retries: Number of retries to LLM API in case of errors.

    Returns:
        Decision whether to generate IQL or not.

    Raises:
        LLMError: If LLM text generation fails after all retries.
    """
    formatted_prompt = self._decision_prompt.format_prompt(
        FilteringDecisionPromptFormat(question=question)
    )

    attempt = 0
    while True:
        try:
            response = await self._llm.generate_text(
                prompt=formatted_prompt,
                event_tracker=event_tracker,
                options=llm_options,
            )
            # TODO: Move response parsing to llm generate_text method
            return formatted_prompt.response_parser(response)
        except LLMError as exc:
            # Give up only once all allowed attempts are exhausted.
            if attempt == n_retries:
                raise exc
            attempt += 1
37
135
38
- async def generate_iql (
136
+ async def _generate_iql (
39
137
self ,
40
138
question : str ,
41
139
filters : List [ExposedFunction ],
@@ -68,7 +166,7 @@ async def generate_iql(
68
166
filters = filters ,
69
167
examples = examples ,
70
168
)
71
- formatted_prompt = self ._prompt_template .format_prompt (prompt_format )
169
+ formatted_prompt = self ._generation_prompt .format_prompt (prompt_format )
72
170
73
171
for retry in range (n_retries + 1 ):
74
172
try :
0 commit comments