Skip to content

Commit f671b1d

Browse files
committed
Persist sessions through restarts with sqlite
1 parent d262cc8 commit f671b1d

File tree

5 files changed

+59
-25
lines changed

5 files changed

+59
-25
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
.venv/
22
.offload/
33
.history/
4+
oracle.db
45

56
__pycache__
67
.DS_Store

oracle/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
DB_PATH = 'oracle.db'
2+
13
def get_device():
24
import torch
35

oracle/gradio/__main__.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,37 @@
11
import os
22
import hashlib
3-
from oracle.gradio.gui import demo
3+
import sqlite3
44

5+
from oracle import DB_PATH
6+
from oracle.gradio.gui import demo
57

6-
credentials = {}
78

89
def open_registration(username, password):
10+
connection = sqlite3.connect(DB_PATH)
911
password_hash = hashlib.sha256(password.encode()).hexdigest()
10-
saved_hash = credentials.get(username, password_hash)
11-
credentials[username] = saved_hash
12+
query = connection.execute(
13+
"SELECT password FROM credentials WHERE user=?",
14+
(username,),
15+
)
16+
saved_hash = query.fetchone()
17+
if not saved_hash:
18+
connection.execute(
19+
"INSERT OR REPLACE INTO credentials VALUES (?, ?)",
20+
(username, password_hash),
21+
)
22+
connection.commit()
23+
saved_hash = password_hash
24+
else:
25+
saved_hash = saved_hash[0]
26+
connection.close()
1227
return password_hash == saved_hash
1328

1429
if __name__ == '__main__':
30+
connection = sqlite3.connect(DB_PATH)
31+
try: connection.execute("CREATE TABLE credentials (user TEXT PRIMARY KEY, password TEXT)")
32+
except: pass
33+
connection.close()
34+
1535
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
1636
demo.launch(
1737
server_port=8080,

oracle/gradio/gui.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
# Model
1010

1111
session_state = gr.State(ChatSession)
12-
raw_chat_log = persist(gr.State(list))
12+
raw_chat_log = persist("chat", gr.State(list))
1313

1414
# View
1515

1616
with gr.Accordion('☰', open=False, elem_id='settings'):
1717
with gr.Tab('Context'):
18-
context_input = persist(gr.Dropdown(label='Source'))
19-
motive_input = persist(gr.Textbox(label='Motivation'))
20-
keyword_checkbox = persist(gr.Checkbox(True, label='Ask the model for keywords?'))
21-
debug_checkbox = persist(gr.Checkbox(label='Show debug info?'))
18+
context_input = persist("context", gr.Dropdown(label='Source'))
19+
motive_input = persist("motive", gr.Textbox(label='Motivation'))
20+
keyword_checkbox = persist("keyword", gr.Checkbox(True, label='Ask the model for keywords?'))
21+
debug_checkbox = persist("debug", gr.Checkbox(label='Show debug info?'))
2222

2323
with gr.Tab('Model'):
24-
model_input = persist(gr.Dropdown(label='Chat Model'))
25-
style_input = persist(gr.Dropdown(
24+
model_input = persist("model", gr.Dropdown(label='Chat Model'))
25+
style_input = persist("style", gr.Dropdown(
2626
label='Response Style',
2727
choices=STYLES,
2828
allow_custom_value=True,

oracle/gradio/utils.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import inspect
2-
from uuid import uuid4
2+
import sqlite3
3+
import pickle
34

45
import gradio as gr
56
from gradio.context import Context
67
from gradio.events import Changeable
78

9+
from oracle import DB_PATH
10+
811

912
def locked(**kwargs):
1013
return gr.update(**kwargs, interactive=False)
@@ -26,23 +29,31 @@ def wrapper(fn):
2629
else:
2730
return wrapper
2831

29-
user_sessions = {}
30-
31-
def get_session_id(request):
32-
session_id = user_sessions.get(request.username, None)
33-
if not session_id:
34-
session_id = uuid4()
35-
if request.username:
36-
user_sessions[request.username] = session_id
37-
return session_id
32+
def persist(name, component):
33+
connection = sqlite3.connect(DB_PATH)
34+
try: connection.execute(f"CREATE TABLE {name} (user TEXT PRIMARY KEY, value TEXT)")
35+
except: pass
36+
connection.close()
3837

39-
def persist(component):
40-
sessions = {}
4138
@on(Context.root_block.load)
4239
def resume_session(value: component, request: gr.Request) -> component:
43-
return sessions.get(get_session_id(request), value)
40+
connection = sqlite3.connect(DB_PATH)
41+
query = connection.execute(
42+
f"SELECT value FROM {name} WHERE user=?",
43+
(request.username,),
44+
)
45+
saved_value = query.fetchone()
46+
connection.close()
47+
return pickle.loads(saved_value[0]) if saved_value else value
48+
4449
def update_session(value: component, request: gr.Request):
45-
sessions[get_session_id(request)] = value
50+
connection = sqlite3.connect(DB_PATH)
51+
query = connection.execute(
52+
f"INSERT OR REPLACE INTO {name} VALUES (?, ?)",
53+
(request.username, pickle.dumps(value, 5)),
54+
)
55+
connection.commit()
56+
connection.close()
4657

4758
if hasattr(component, 'change'):
4859
on(component.change, update_session)

0 commit comments

Comments
 (0)