Freeform SQL execution with similarity indexes #97
GitHub Actions / JUnit Test Report
failed
May 8, 2024 in 0s
92 tests run, 91 passed, 0 skipped, 1 failed.
Annotations
Check failure on line 49 in tests/unit/views/text2sql/test_view.py
github-actions / JUnit Test Report
test_view.test_text2sql_view
dbally.data_models.prompts.common_validation_utils.PromptTemplateError: Template format is not correct. It should be system, and then user/assistant alternating.
Raw output
self = <dbally.views.freeform.text2sql._view.Text2SQLFreeformView object at 0x7fc259d26580>
query = 'Show me customers from New York'
llm_client = <Mock id='140472707538032'>
event_tracker = <dbally.audit.event_tracker.EventTracker object at 0x7fc259d26700>
n_retries = 3, dry_run = False
async def ask(
self, query: str, llm_client: LLMClient, event_tracker: EventTracker, n_retries: int = 3, dry_run: bool = False
) -> ViewExecutionResult:
"""
Executes the query and returns the result. It generates the SQL query from the natural language query and
executes it against the database. It retries the process in case of errors.
Args:
query: The natural language query to execute.
llm_client: The LLM client used to execute the query.
event_tracker: The event tracker used to audit the query execution.
n_retries: The number of retries to execute the query in case of errors.
dry_run: If True, the query will not be used to fetch data from the datasource.
Returns:
The result of the query.
Raises:
Text2SQLError: If the text2sql query generation fails after n_retries.
"""
conversation = text2sql_prompt
sql, rows = None, None
exceptions = []
for _ in range(n_retries):
# We want to catch all exceptions to retry the process.
# pylint: disable=broad-except
try:
> sql, parameters, conversation = await self._generate_sql(query, conversation, llm_client, event_tracker)
src/dbally/views/freeform/text2sql/_view.py:131:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/dbally/views/freeform/text2sql/_view.py:165: in _generate_sql
data = json.loads(response)
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/json/__init__.py:357: in loads
return _default_decoder.decode(s)
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/json/decoder.py:337: in decode
obj, end = self.raw_decode(s, idx=_w(s, 0).end())
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <json.decoder.JSONDecoder object at 0x7fc271a9f220>
s = "SELECT * FROM customers WHERE city = 'New York'", idx = 0
def raw_decode(self, s, idx=0):
"""Decode a JSON document from ``s`` (a ``str`` beginning with
a JSON document) and return a 2-tuple of the Python
representation and the index in ``s`` where the document ended.
This can be used to decode a JSON document from a string that may
have extraneous data at the end.
"""
try:
obj, end = self.scan_once(s, idx)
except StopIteration as err:
> raise JSONDecodeError("Expecting value", s, err.value) from None
E json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/json/decoder.py:355: JSONDecodeError
During handling of the above exception, another exception occurred:
sample_db = Engine(sqlite:///:memory:)
async def test_text2sql_view(sample_db: Engine):
mock_llm = Mock()
mock_llm.text_generation = AsyncMock(return_value="SELECT * FROM customers WHERE city = 'New York'")
config = Text2SQLConfig(
tables={
"customers": Text2SQLTableConfig(
ddl="CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, city TEXT)",
description="Customers table",
)
}
)
collection = dbally.create_collection(name="test_collection", llm_client=mock_llm)
collection.add(Text2SQLFreeformView, lambda: Text2SQLFreeformView(sample_db, config))
> response = await collection.ask("Show me customers from New York")
tests/unit/views/text2sql/test_view.py:49:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/dbally/collection.py:205: in ask
view_result = await view.ask(
src/dbally/views/freeform/text2sql/_view.py:139: in ask
conversation = conversation.add_user_message(f"Response is invalid! Error: {e}")
src/dbally/data_models/prompts/prompt_template.py:71: in add_user_message
return self.__class__((*self.chat, {"role": "user", "content": content}))
src/dbally/data_models/prompts/prompt_template.py:54: in __init__
self.chat: ChatFormat = _check_chat_order(chat)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
chat = ({'content': 'You are a very smart database programmer. You have access to the following {dialect} tables:\n{tables}\n... 'role': 'user'}, {'content': 'Response is invalid! Error: Expecting value: line 1 column 1 (char 0)', 'role': 'user'})
def _check_chat_order(chat: ChatFormat) -> ChatFormat:
"""
Pydantic validator. Checks if the chat template is constructed correctly (system, user, assistant alternating).
Args:
chat: Chat template
Raises:
PromptTemplateError: if chat template is not constructed correctly.
Returns:
Chat template
"""
expected_order = ["user", "assistant"]
for i, message in enumerate(chat):
role = message["role"]
if role == "system":
if i != 0:
raise PromptTemplateError("Only first message should come from system")
continue
index = i % len(expected_order)
if role != expected_order[index - 1]:
> raise PromptTemplateError(
"Template format is not correct. It should be system, and then user/assistant alternating."
)
E dbally.data_models.prompts.common_validation_utils.PromptTemplateError: Template format is not correct. It should be system, and then user/assistant alternating.
src/dbally/data_models/prompts/prompt_template.py:30: PromptTemplateError
Loading