Skip to content

Commit

Permalink
WIP contexts table feature, refs #617
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Nov 13, 2024
1 parent dff53a9 commit b2fce50
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 25 deletions.
14 changes: 12 additions & 2 deletions docs/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def cleanup_sql(sql):
return first_line + '(\n ' + ',\n '.join(columns) + '\n);'
cog.out("```sql\n")
for table in ("conversations", "responses", "responses_fts", "attachments", "prompt_attachments"):
for table in (
"conversations", "responses", "responses_fts", "attachments",
"prompt_attachments", "contexts"
):
schema = db[table].schema
cog.out(format(cleanup_sql(schema)))
cog.out("\n")
Expand All @@ -148,11 +151,13 @@ CREATE TABLE [conversations] (
[name] TEXT,
[model] TEXT
);
CREATE TABLE [responses] (
CREATE TABLE "responses" (
[id] TEXT PRIMARY KEY,
[model] TEXT,
[prompt] TEXT,
[system] TEXT,
[prompt_id] INTEGER REFERENCES [contexts]([id]),
[system_id] INTEGER REFERENCES [contexts]([id]),
[prompt_json] TEXT,
[options_json] TEXT,
[response] TEXT,
Expand Down Expand Up @@ -180,6 +185,11 @@ CREATE TABLE [prompt_attachments] (
PRIMARY KEY ([response_id],
[attachment_id])
);
CREATE TABLE [contexts] (
[id] INTEGER PRIMARY KEY,
[hash] TEXT,
[context] TEXT
);
```
<!-- [[[end]]] -->
`responses_fts` configures [SQLite full-text search](https://www.sqlite.org/fts5.html) against the `prompt` and `response` columns in the `responses` table.
67 changes: 52 additions & 15 deletions llm/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import click
from click_default_group import DefaultGroup
from collections import defaultdict
from dataclasses import asdict
import io
import json
Expand Down Expand Up @@ -30,7 +31,7 @@

from .migrations import migrate
from .plugins import pm
from .utils import mimetype_from_path, mimetype_from_string
from .utils import mimetype_from_path, mimetype_from_string, reverse_replacements
import base64
import httpx
import pathlib
Expand Down Expand Up @@ -569,9 +570,12 @@ def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]:
)
# Inflate that conversation
conversation = Conversation.from_row(row)
for response in db["responses"].rows_where(
"conversation_id = ?", [conversation_id]
):
sql = LOGS_SQL.format(
extra_where=" where responses.conversation_id = ?",
limit="",
columns=LOGS_COLUMNS,
)
for response in db.query(sql, [conversation_id]):
conversation.responses.append(Response.from_row(db, response))
return conversation

Expand Down Expand Up @@ -707,13 +711,17 @@ def logs_turn_off():
responses.model,
responses.prompt,
responses.system,
responses.prompt_id,
responses.system_id,
responses.prompt_json,
responses.options_json,
responses.response,
responses.response_json,
responses.conversation_id,
responses.duration_ms,
responses.datetime_utc,
pc.context AS prompt_context,
sc.context AS system_context,
conversations.name as conversation_name,
conversations.model as conversation_model"""

Expand All @@ -722,18 +730,40 @@ def logs_turn_off():
{columns}
from
responses
left join contexts pc on responses.prompt_id = pc.id
left join contexts sc on responses.system_id = sc.id
left join conversations on responses.conversation_id = conversations.id{extra_where}
order by responses.id desc{limit}
"""
LOGS_SQL_SEARCH = """
with matching_responses as (
-- Direct matches in responses_fts
select
responses.id,
responses_fts.rank as rank
from responses
join responses_fts on responses_fts.rowid = responses.rowid
where responses_fts match :query
union
-- Matches in linked contexts (prompt_id only)
select
responses.id,
contexts_fts.rank as rank
from responses
join contexts on responses.prompt_id = contexts.id
join contexts_fts on contexts_fts.rowid = contexts.rowid
where contexts_fts match :query
)
select
{columns}
{columns}
from
responses
matching_responses
join responses on responses.id = matching_responses.id
left join conversations on responses.conversation_id = conversations.id
join responses_fts on responses_fts.rowid = responses.rowid
where responses_fts match :query{extra_where}
order by responses_fts.rank desc{limit}
left join contexts pc on responses.prompt_id = pc.id
left join contexts sc on responses.system_id = sc.id
{extra_where}
order by matching_responses.rank desc{limit}
"""

ATTACHMENTS_SQL = """
Expand Down Expand Up @@ -861,12 +891,9 @@ def logs_list(
sql_format["extra_where"] = where_ + " and ".join(where_bits)

final_sql = sql.format(**sql_format)
rows = list(
db.query(
final_sql,
{"model": model_id, "query": query, "conversation_id": conversation_id},
)
)
params = {"model": model_id, "query": query, "conversation_id": conversation_id}

rows = list(db.query(final_sql, params))
# Reverse the order - we do this because we 'order by id desc limit 3' to get the
# 3 most recent results, but we still want to display them in chronological order
# ... except for searches where we don't do this
Expand All @@ -881,6 +908,13 @@ def logs_list(
attachments_by_id.setdefault(attachment["response_id"], []).append(attachment)

for row in rows:
replacements = defaultdict(str)
if row["prompt_context"]:
replacements[row["prompt_id"]] = row["prompt_context"]
row["prompt"] = row.pop("prompt_context")
if row["system_context"]:
replacements[row["system_id"]] = row["system_context"]
row["system"] = row.pop("system_context")
if truncate:
row["prompt"] = _truncate_string(row["prompt"])
row["response"] = _truncate_string(row["response"])
Expand All @@ -893,6 +927,9 @@ def logs_list(
else:
row[key] = json.loads(row[key])

if replacements and row["prompt_json"]:
row["prompt_json"] = reverse_replacements(row["prompt_json"], replacements)

if json_output:
# Output as JSON if requested
for row in rows:
Expand Down
47 changes: 47 additions & 0 deletions llm/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,50 @@ def m012_attachments_tables(db):
),
pk=("response_id", "attachment_id"),
)


@migration
def m013_contexts_table(db):
with db.conn:
db["contexts"].create(
{
"id": int,
"hash": str,
"context": str,
},
pk="id",
)
db["contexts"].create_index(["hash"], unique=True)
# Add optional foreign keys prompt_id / system_id to responses
db["responses"].add_column("prompt_id", int)
db["responses"].add_column("system_id", int)
db["responses"].add_foreign_key("prompt_id", "contexts", "id")
db["responses"].add_foreign_key("system_id", "contexts", "id")
db["responses"].transform(
column_order=(
"id",
"model",
"prompt",
"system",
"prompt_id",
"system_id",
"prompt_json",
"options_json",
"response",
"response_json",
"conversation_id",
"duration_ms",
"datetime_utc",
),
)


@migration
def m014_contexts_fts(db):
db["contexts"].enable_fts(["context"], create_triggers=True)


@migration
def m15_reindex_responses(db):
# The .transform() may have recreated rowids so they no longer match
db["responses"].populate_fts(["prompt", "response"])
52 changes: 46 additions & 6 deletions llm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import re
import time
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union
from .utils import mimetype_from_path, mimetype_from_string
from .utils import mimetype_from_path, mimetype_from_string, apply_replacements
from abc import ABC, abstractmethod
import json
from pydantic import BaseModel
from ulid import ULID

CONVERSATION_NAME_LENGTH = 32
PROMPT_THRESHOLD = 100


@dataclass
Expand Down Expand Up @@ -218,12 +219,51 @@ def log_to_db(self, db):
ignore=True,
)
response_id = str(ULID()).lower()

# if prompt/system are long we stash them in contexts
prompt = self.prompt.prompt
system = self.prompt.system
prompt_id = None
system_id = None
replacements = {}

for context, column in (
(prompt, "prompt"),
(system, "system"),
):
if context is not None and len(context) > PROMPT_THRESHOLD:
hash = hashlib.sha256(context.encode("utf-8")).hexdigest()
rows = list(db.query("select id from contexts where hash = ?", [hash]))
if rows:
context_id = rows[0]["id"]
else:
context_id = (
db["contexts"]
.insert(
{
"hash": hash,
"context": context,
},
ignore=True,
)
.last_pk
)
replacements[context_id] = context
if column == "prompt":
prompt_id = context_id
prompt = None
else:
system_id = context_id
system = None

response = {
"id": response_id,
"model": self.model.model_id,
"prompt": self.prompt.prompt,
"system": self.prompt.system,
"prompt_json": self._prompt_json,
"prompt": prompt,
"system": system,
"prompt_id": prompt_id,
"system_id": system_id,
"prompt_json": apply_replacements(self._prompt_json, replacements),
"options_json": {
key: value
for key, value in dict(self.prompt.options).items()
Expand Down Expand Up @@ -290,10 +330,10 @@ def from_row(cls, db, row):
response = cls(
model=model,
prompt=Prompt(
prompt=row["prompt"],
prompt=row["prompt_context"] or row["prompt"],
model=model,
attachments=[],
system=row["system"],
system=row["system_context"] or row["system"],
options=model.Options(**json.loads(row["options_json"])),
),
stream=False,
Expand Down
46 changes: 46 additions & 0 deletions llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,49 @@ def logging_client() -> httpx.Client:
transport=_LogTransport(httpx.HTTPTransport()),
event_hooks={"request": [_no_accept_encoding], "response": [_log_response]},
)


def apply_replacements(obj, replacements):
if isinstance(obj, dict):
return {k: apply_replacements(v, replacements) for k, v in obj.items()}
elif isinstance(obj, list):
return [apply_replacements(item, replacements) for item in obj]
elif isinstance(obj, str):
replaced_parts = []
last_index = 0
found = False

for value, key in replacements.items():
index = obj.find(key)
while index != -1:
found = True
if index > last_index:
replaced_parts.append(obj[last_index:index])
replaced_parts.append(value)
last_index = index + len(key)
index = obj.find(key, last_index)

if found:
if last_index < len(obj):
replaced_parts.append(obj[last_index:])
return {"$r": replaced_parts}
else:
return obj
else:
return obj


def reverse_replacements(obj, replacements):
if isinstance(obj, dict):
if "$r" in obj:
# Reconstruct the original string from the list
return "".join(
(replacements[part] if isinstance(part, int) else part)
for part in obj["$r"]
)
else:
return {k: reverse_replacements(v, replacements) for k, v in obj.items()}
elif isinstance(obj, list):
return [reverse_replacements(item, replacements) for item in obj]
else:
return obj
Loading

0 comments on commit b2fce50

Please sign in to comment.