-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcli_event_handler.py
145 lines (120 loc) · 6.09 KB
/
cli_event_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import re
from typing import Optional

try:
    from rich import print as pprint
    from rich.console import Console
    from rich.markup import escape
    from rich.syntax import Syntax
    from rich.text import Text

    RICH_OUTPUT = True
except ImportError:
    RICH_OUTPUT = False
    pprint = print  # type: ignore

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent

_RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"}
_RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]"
class CLIEventHandler(EventHandler):
    """
    This handler displays all interactions between LLM and user happening during `Collection.ask`\
    execution inside the terminal.

    ### Usage

    ```python
    import dbally
    from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler

    dbally.event_handlers = [CLIEventHandler()]
    my_collection = dbally.create_collection("my_collection", llm)
    ```

    After registering `CLIEventHandler`, every `Collection.ask` execution prints the user's question, each
    LLM prompt and response, similarity lookups, fallbacks and the final result to the terminal. When the
    `rich` package is installed the output is colorized and highlighted; otherwise the markup tags are
    stripped and plain text is printed.
    """

    def __init__(self) -> None:
        super().__init__()
        # With `rich` available, output goes through a Console created with `record=True` (rich keeps a copy
        # of everything printed so it can be exported later); otherwise plain `print` is used.
        self._console = Console(record=True) if RICH_OUTPUT else None

    def _escape(self, value: object) -> str:
        """
        Formats `value` for safe interpolation into a rich markup template.

        When output goes through rich, square brackets that look like markup tags (e.g. "[/INST]" in an
        LLM response) are escaped so `Text.from_markup` neither restyles them nor raises `MarkupError`.
        In plain-text mode the formatted value is returned unchanged.

        Args:
            value: Any value; it is formatted exactly like an f-string placeholder would format it.

        Returns:
            The formatted, markup-safe string.
        """
        text = f"{value}"
        return escape(text) if self._console is not None else text

    def _print_syntax(self, content: str, lexer: Optional[str] = None) -> None:
        """
        Prints `content` to the terminal.

        Args:
            content: Rich markup when `lexer` is not given; otherwise raw text to be highlighted.
            lexer: Name of the lexer passed to `rich.syntax.Syntax` (e.g. "text", "psql").
        """
        if self._console is not None:
            if lexer:
                renderable = Syntax(content, lexer, word_wrap=True)
            else:
                renderable = Text.from_markup(content)
            self._console.print(renderable)
        elif lexer:
            # Raw content (prompt messages, SQL) contains no markup; stripping tags from it would corrupt it.
            print(content)
        else:
            print(re.sub(_RICH_FORMATING_PATTERN, "", content))

    def _print_separator(self) -> None:
        """Prints the grey double rule that closes a request/event section."""
        self._print_syntax("[grey53]\n=======================================")
        self._print_syntax("[grey53]=======================================\n")

    async def request_start(self, user_request: RequestStart) -> None:
        """
        Displays the start of a request, including the user's question, in the terminal.

        Args:
            user_request: Object containing name of collection and asked query
        """
        self._print_syntax(
            "[orange3 bold]Request starts... \n"
            f"[orange3 bold]MESSAGE: [grey53]{self._escape(user_request.question)}"
        )
        self._print_separator()

    # pylint: disable=unused-argument
    async def event_start(self, event: Event, request_context: None) -> None:
        """
        Displays information that an event has started. For LLM events, all messages of the prompt are
        printed as well.

        Args:
            event: db-ally event to be logged with all the details.
            request_context: Optional context passed from request_start method
        """
        if isinstance(event, LLMEvent):
            self._print_syntax(
                "[cyan bold]LLM event starts... \n"
                f"[cyan bold]LLM EVENT PROMPT TYPE: [grey53]{self._escape(event.type)}"
            )
            if isinstance(event.prompt, tuple):
                # Chat-style prompt: a sequence of {"role": ..., "content": ...} messages.
                for msg in event.prompt:
                    self._print_syntax(f"\n[orange3]{self._escape(msg['role'])}")
                    self._print_syntax(msg["content"], "text")
            else:
                self._print_syntax(f"{event.prompt}", "text")
        elif isinstance(event, SimilarityEvent):
            self._print_syntax(
                "[cyan bold]Similarity event starts... \n"
                f"[cyan bold]INPUT: [grey53]{self._escape(event.input_value)}\n"
                f"[cyan bold]STORE: [grey53]{self._escape(event.store)}\n"
                f"[cyan bold]FETCHER: [grey53]{self._escape(event.fetcher)}\n"
            )
        elif isinstance(event, FallbackEvent):
            # "orange" is not a rich color name (the header rendered unstyled); use "orange3" like elsewhere.
            self._print_syntax(
                "[grey53]\n=======================================\n"
                "[grey53]=======================================\n"
                "[orange3 bold]Fallback event starts \n"
                f"[orange3 bold]Triggering collection: [grey53]{self._escape(event.triggering_collection_name)}\n"
                f"[orange3 bold]Triggering view name: [grey53]{self._escape(event.triggering_view_name)}\n"
                f"[orange3 bold]Error description: [grey53]{self._escape(event.error_description)}\n"
                f"[orange3 bold]Fallback collection name: [grey53]{self._escape(event.fallback_collection_name)}\n"
                "[grey53]=======================================\n"
                "[grey53]=======================================\n"
            )

    # pylint: disable=unused-argument
    async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None:
        """
        Displays the result of an event: the LLM response or the similarity output.

        Args:
            event: db-ally event to be logged with all the details.
            request_context: Optional context passed from request_start method
            event_context: Optional context passed from event_start method
        """
        if isinstance(event, LLMEvent):
            self._print_syntax(f"\n[green bold]RESPONSE: {self._escape(event.response)}")
            self._print_separator()
        elif isinstance(event, SimilarityEvent):
            self._print_syntax(f"[green bold]OUTPUT: {self._escape(event.output_value)}")
            self._print_separator()

    # pylint: disable=unused-argument
    async def request_end(self, output: RequestEnd, request_context: Optional[dict] = None) -> None:
        """
        Displays the output of the request: the number of result rows and, when present in the result
        context, the generated SQL.

        Args:
            output: The output of the request.
            request_context: Optional context passed from request_start method
        """
        if not output.result:
            self._print_syntax("[red bold]No results found")
            return

        self._print_syntax("[green bold]REQUEST OUTPUT:")
        self._print_syntax(f"Number of rows: {len(output.result.results)}")
        # NOTE(review): guard in case a view returns no context at all — confirm the declared type.
        context = output.result.context or {}
        if "sql" in context:
            self._print_syntax(f"{context['sql']}", "psql")