diff --git a/Dockerfile b/Dockerfile index 703b5d5..11e9455 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,20 +1,9 @@ -FROM python:3.7.2 +FROM python:3.8 WORKDIR /opt COPY requirements.txt requirements.txt RUN pip install -r requirements.txt -# python:3.7.2 only has sqlite 3.26, temporarily compiling 3.27 instead - -RUN curl -O "https://www.sqlite.org/2019/sqlite-amalgamation-3270000.zip" -RUN unzip -p sqlite-amalgamation-3270000.zip sqlite-amalgamation-3270000/sqlite3.c | gcc \ - -DSQLITE_THREADSAFE=1 -lpthread \ - -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS4 -DSQLITE_ENABLE_FTS5 \ - -DSQLITE_ENABLE_JSON1 -DSQLITE_ENABLE_RTREE \ - -ldl -shared -fPIC -o libsqlite3.so.0 -xc - - -RUN mv libsqlite3.so.0 /usr/lib/x86_64-linux-gnu/libsqlite3.so.0 - COPY markovich markovich CMD ["python3", "-m", "markovich", "/opt/config.json"] diff --git a/docker-compose.yml b/docker-compose.yml index 28991b6..2caa7cc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,18 +1,15 @@ version: "3" services: - markovich: + markovich-irc: build: . volumes: - - /home/user/git/markovich/config.json:/opt/config.json + - /home/user/git/markovich/config-irc.json:/opt/config.json - /home/user/git/markovich/db:/opt/db - ports: - - 6697:6697 - - 6667:6667 - #networks: - # - znc-network restart: unless-stopped -# docker network create znc-network -networks: - znc-network: - external: true \ No newline at end of file + markovich-discord: + build: . + volumes: + - /home/user/git/markovich/config-discord.json:/opt/config.json + - /home/user/git/markovich/db:/opt/db + restart: unless-stopped diff --git a/markovich/__main__.py b/markovich/__main__.py index 8205022..4b081cd 100644 --- a/markovich/__main__.py +++ b/markovich/__main__.py @@ -1,31 +1,26 @@ import sys, os import json -from typing import List, Callable -from asyncio import Future, get_event_loop +import logging from .markovich_discord import run_markovich_discord from .markovich_irc import run_markovich_irc from .markovich_cli import run_markovich_cli def run_with_config(config): - cleanup_functions = [] #type: List[Callable[[], Future]] - aio_loop = get_event_loop() - - try: - if 'irc' in config: - cleanup_fn = run_markovich_irc(config['irc'], eventloop=aio_loop) - #cleanup_functions.append(cleanup_fn) - - if 'discord' in config: - cleanup_fn = run_markovich_discord(config['discord'], eventloop=aio_loop) - cleanup_functions.append(cleanup_fn) + if 'irc' in config and 'discord' in config: + logging.error("Cannot run in both Discord and IRC mode at the same time.") + sys.exit(1) + + if 'irc' in config: + run_markovich_irc(config['irc']) + elif 'discord' in config: + run_markovich_discord(config['discord']) + else: + logging.error("No known configurations in the specified config file.") + sys.exit(1) - aio_loop.run_forever() - except KeyboardInterrupt: - print("Shutting down") - finally: - for cleanup in cleanup_functions: - aio_loop.run_until_complete(cleanup()) +def run_without_config(): + run_markovich_cli() if len(sys.argv) > 1: config_path = sys.argv[1] @@ -35,5 +30,5 @@ def run_with_config(config): run_with_config(config) else: - print("Called with no configuration file, launching in test mode") - run_markovich_cli() + logging.warning("Called with no configuration file, launching in test mode") + run_without_config() diff --git a/markovich/backends/backend_manager.py b/markovich/backends/backend_manager.py index dff4a8d..3837042 100644 --- a/markovich/backends/backend_manager.py +++ b/markovich/backends/backend_manager.py @@ -1,25 +1,25 @@ -from .markov_backend_sqlite import MarkovBackendSQLite, sqlite_db_directory +from .markov_backend_sqlite import open_markov_backend_sqlite, sqlite_db_directory from .markov_backend import MarkovBackend -from typing import Dict +from contextlib import asynccontextmanager +from typing import Dict, AsyncIterator class MarkovManager: + open_backends: Dict[str, MarkovBackend] + def __init__(self, **kwargs): - self.open_backends = {} # type: Dict[str, MarkovBackend] + self.open_backends = {} # Recieves a list of initial mappings, the rest will be created dynamically for k,v in kwargs: if type(v) is MarkovBackendSQLite: open_backends[k] = v - def get_markov(self, context_id: str) -> MarkovBackend: - backend = self.open_backends.get(context_id) - - if backend is None: - db_path = sqlite_db_directory / context_id - db_path = db_path.with_suffix('.db') - print("Opening database \"{}\" ({})".format(context_id, db_path)) + @asynccontextmanager + async def get_markov(self, context_id: str) -> AsyncIterator[MarkovBackend]: + # In the case of SQLite, connections are opened and closed each time, no pooling - backend = MarkovBackendSQLite(db_path) - self.open_backends[context_id] = backend + db_path = sqlite_db_directory / context_id + db_path = db_path.with_suffix('.db') - return backend \ No newline at end of file + async with open_markov_backend_sqlite(db_path) as backend: + yield backend diff --git a/markovich/backends/markov_backend.py b/markovich/backends/markov_backend.py index c9deeb0..f6a786b 100644 --- a/markovich/backends/markov_backend.py +++ b/markovich/backends/markov_backend.py @@ -1,5 +1,14 @@ -from typing import Optional, Pattern +from typing import Optional, Pattern, List class MarkovBackend: - def record_and_generate(self, input_string:str, split_pattern: Pattern, word_limit: int) -> Optional[str]: - raise NotImplementedError \ No newline at end of file + async def record_and_generate(self, input_string:str, split_pattern: Pattern, word_limit: int) -> Optional[str]: + raise NotImplementedError + + async def record_words(self, chopped_string:List[str]): + pass + + async def generate_sentence(self, starting_word:str, word_limit: int): + pass + + # def bulk_learn(self, sentences): + # pass diff --git a/markovich/backends/markov_backend_sqlite.py b/markovich/backends/markov_backend_sqlite.py index e6e068e..bc3979e 100644 --- a/markovich/backends/markov_backend_sqlite.py +++ b/markovich/backends/markov_backend_sqlite.py @@ -1,24 +1,33 @@ -import sqlite3 +import aiosqlite import random import json +from contextlib import asynccontextmanager from pathlib import Path -from typing import Optional, Pattern, List +from typing import Optional, Pattern, List, AsyncIterator from .markov_backend import MarkovBackend sqlite_db_directory = Path("./db") -class MarkovBackendSQLite(MarkovBackend): - def __init__(self, database_path:Path): - MarkovBackendSQLite.check_sqlite_version() - - file_check = database_path.is_file() or (database_path.parent.is_dir() and not database_path.exists()) - assert file_check, "Cannot open or create {}".format(database_path) - - self.conn = sqlite3.connect(database_path.__fspath__()) - self.init_db() - +@asynccontextmanager +async def open_markov_backend_sqlite(database_path: Path) -> AsyncIterator["MarkovBackendSQLite"]: + MarkovBackendSQLite.check_sqlite_version() + + file_check = database_path.is_file() or (database_path.parent.is_dir() and not database_path.exists()) + assert file_check, "Cannot open or create {}".format(database_path) + + async with aiosqlite.connect(database_path.__fspath__()) as conn: + await MarkovBackendSQLite.init_db(conn) + # Alternative to using `ABS(random()) / CAST(0x7FFFFFFFFFFFFFFF AS real)` - self.conn.create_function('random_real', 0, random.random) + await conn.create_function('random_real', 0, random.random) + + yield MarkovBackendSQLite(conn) + +class MarkovBackendSQLite(MarkovBackend): + conn: aiosqlite.Connection + + def __init__(self, conn: aiosqlite.Connection): + self.conn = conn @staticmethod def check_sqlite_version(): @@ -26,88 +35,122 @@ def check_sqlite_version(): # Windowing clauses require SQLite 3.25 # Windowing clauses inside correlated subqueries cause segfaults in SQLite 3.25 and 3.26.0 # Requires the json1 extension, but currently no way to check for that - current_version = sqlite3.sqlite_version_info + current_version = aiosqlite.sqlite_version_info minimum_version = (3,27,0) - version_check = True # For an exact match - - for (current, minimum) in zip(current_version, minimum_version): - if current > minimum: - version_check = True - break - elif current < minimum: - version_check = False - break - - assert version_check, "SQLite {}.{}.{} or greater required (Running on SQLite {}.{}.{})".format(*minimum_version, *current_version) - - def init_db(self) -> None: - self.conn.execute(""" - CREATE TABLE IF NOT EXISTS chain ( - link1 text NOT NULL, - link2 text NOT NULL, -- Space is used as an end-of-sentence sentinel - n integer NOT NULL, - PRIMARY KEY (link1, link2), -- Primary Key requires not null anyway - CHECK (link1 <> ' ') + assert current_version >= minimum_version, "SQLite {}.{}.{} or greater required (Running on SQLite {}.{}.{})".format(*minimum_version, *current_version) + + @staticmethod + async def init_db(conn: aiosqlite.Connection) -> None: + await conn.execute(""" + PRAGMA foreign_keys = true; + """) + + await conn.execute(""" + CREATE TABLE IF NOT EXISTS words_v2( + word_id INTEGER PRIMARY KEY, -- Implicit ROWID + word TEXT UNIQUE NOT NULL -- Nulls are never unique ); """) - - self.conn.execute("CREATE INDEX IF NOT EXISTS chain_link1_idx ON chain (link1);") - def record_and_generate(self, input_string:str, split_pattern: Pattern, word_limit: int) -> Optional[str]: + await conn.execute(""" + CREATE TABLE IF NOT EXISTS chain_v2( + word_id1 INTEGER NOT NULL, + word_id2 INTEGER NOT NULL, + n INTEGER NOT NULL, + PRIMARY KEY(word_id1,word_id2), + FOREIGN KEY(word_id1) REFERENCES words_v2(word_id), + FOREIGN KEY(word_id2) REFERENCES words_v2(word_id) + ); + """) + + await conn.execute(""" + -- End/start of sentence sentinel, the "null word" + -- SQLite will never insert a ROWID of zero on its own, so this is safe + INSERT INTO words_v2(word_id, word) VALUES(0, '') ON CONFLICT DO NOTHING; + """) + + await conn.execute("CREATE INDEX IF NOT EXISTS chain_word_id1_idx ON chain_v2 (word_id1);") + await conn.commit() + + async def record_and_generate(self, input_string:str, split_pattern: Pattern, word_limit: int) -> Optional[str]: chopped_string = split_pattern.split(input_string) - self.record_words(chopped_string) + await self.record_words(chopped_string) try: starting_word = random.choice(chopped_string) - return self.generate_from_pair(starting_word, word_limit) + return await self.generate_sentence(starting_word, word_limit) except IndexError: return None - def record_words(self, chopped_string:List[str]) -> None: - c = self.conn.cursor() - + async def record_words(self, chopped_string:List[str]) -> None: # Passing a json text array instead of creating a query with an arbitrary amount of parmeters each time # SQLite, unlike pgSQL, doesn't have a function to split strings into tables json_encoded = json.dumps(chopped_string) - c.execute(""" - WITH - words(words) AS (SELECT value FROM json_each(?)), - word_chain(link1, link2, n) AS (SELECT words, lead(words, 1) OVER (), 1 AS n FROM words) - INSERT INTO chain(link1, link2, n) - -- link2 would normally be null, but PK constraint disallows that - SELECT link1, COALESCE(link2, ' ') AS link2, SUM(n) AS n FROM word_chain GROUP BY link1, link2 - ON CONFLICT (link1, link2) DO UPDATE SET n = chain.n + EXCLUDED.n - """, (json_encoded,)) - - def generate_from_pair(self, starting_word:str, word_limit: int) -> Optional[str]: + async with self.conn.cursor() as c: + # Insert words + await c.execute(""" + INSERT INTO words_v2(word) + SELECT DISTINCT value FROM json_each(?) + WHERE TRUE -- Avoids parsing ambiguity + ON CONFLICT(word) DO NOTHING; + """, (json_encoded,)) + + # Update chain + await c.execute(""" + WITH + word_list(word) AS (SELECT value FROM json_each(?)), + word_id_list(word_id) AS (SELECT words_v2.word_id FROM word_list INNER JOIN words_v2 ON (word_list.word = words_v2.word)), + -- Use ROWID 0 as a sentinel value for the end of chain. + word_id_chain(link1, link2, n) AS (SELECT word_id, lead(word_id, 1, 0) OVER (), 1 AS n FROM word_id_list) + INSERT INTO chain_v2(word_id1, word_id2, n) + SELECT link1, link2, SUM(n) AS n FROM word_id_chain GROUP BY link1, link2 + ON CONFLICT (word_id1, word_id2) DO UPDATE SET n = chain_v2.n + EXCLUDED.n + """, (json_encoded,)) + + # FIXME(sm15): aiosqlite has trouble with concurrent transactions on shared connections. + # Global commits are acceptable here, but connection pooling would be better. + # https://github.com/omnilib/aiosqlite/issues/19 + await self.conn.commit() + + async def generate_sentence(self, starting_word:str, word_limit: int) -> Optional[str]: if word_limit < 0: return None - - c = self.conn.cursor() - - c.execute(""" - WITH RECURSIVE markov(last_word, current_word, random_const) AS ( - VALUES(NULL, ?, random_real()) - UNION ALL - SELECT markov.current_word, ( - SELECT link2 FROM ( - SELECT link1, link2, n, - SUM(n) OVER (PARTITION BY link1 ROWS UNBOUNDED PRECEDING) AS rank, - SUM(n) OVER (PARTITION BY link1) * markov.random_const AS roll - FROM chain - WHERE link1 = markov.current_word - ) t WHERE roll <= rank LIMIT 1 - ) AS next_word, - random_real() AS random_const - - FROM markov - WHERE current_word <> ' ' - ) - -- Initial pair (NULL, starting_word) needs to be removed - SELECT last_word FROM markov WHERE last_word IS NOT NULL LIMIT ?; - """, (starting_word, word_limit)) - - word_tuples = c.fetchall() - words = [word for (word,) in word_tuples] - return ' '.join(words) + + async with self.conn.cursor() as c: + await c.execute(""" + WITH RECURSIVE markov(prev_id, curr_id, random_const) AS ( + -- Correlated subquery in a VALUES expression + -- I'm going to hell for this :) + VALUES(0, (SELECT word_id FROM words_v2 WHERE word = ?), random_real()) + UNION ALL + SELECT markov.curr_id, ( + SELECT word_id2 FROM ( + SELECT word_id1, word_id2, n, + SUM(n) OVER (PARTITION BY word_id1 ROWS UNBOUNDED PRECEDING) AS rank, + SUM(n) OVER (PARTITION BY word_id1) * markov.random_const AS roll + FROM chain_v2 + WHERE word_id1 = markov.curr_id + ) t WHERE roll <= rank LIMIT 1 + ) AS next_id, + random_real() AS random_const + + FROM markov + WHERE curr_id <> 0 + ) + + SELECT words_v2.word FROM markov + INNER JOIN words_v2 ON prev_id = word_id + -- Initial pair (0, start_id) needs to be removed + WHERE prev_id <> 0 LIMIT ?; + """, (starting_word, word_limit)) + + # FIXME(sm15): Use conn pools instead of shared conns for transaction support? + await self.conn.commit() + + word_tuples = await c.fetchall() + words = [word for (word,) in word_tuples] + return ' '.join(words) + + def bulk_learn(self, sentences): + raise NotImplementedError() diff --git a/markovich/markovich_cli.py b/markovich/markovich_cli.py index 988f2e2..957e6ca 100644 --- a/markovich/markovich_cli.py +++ b/markovich/markovich_cli.py @@ -1,12 +1,17 @@ from .backends import MarkovManager import re +import asyncio + +split_pattern = re.compile(r'[,\s]+') def run_markovich_cli(): backends = MarkovManager() - backend = backends.get_markov("test_db2") - split_pattern = re.compile(r'[,\s]+') while True: input_string = input("<-- ") - output_string = backend.record_and_generate(input_string, split_pattern, 50) - print("--> {}".format(output_string)) \ No newline at end of file + asyncio.run(on_message(input_string, backends)) + +async def on_message(input_string: str, backends: MarkovManager): + async with backends.get_markov("test_db2") as backend: + output_string = await backend.record_and_generate(input_string, split_pattern, 50) + print("--> {}".format(output_string)) diff --git a/markovich/markovich_discord.py b/markovich/markovich_discord.py index a2ce8e8..89d2b5d 100644 --- a/markovich/markovich_discord.py +++ b/markovich/markovich_discord.py @@ -1,34 +1,34 @@ import re import discord +import logging from .backends import MarkovManager -from typing import Dict, Callable +from typing import Dict, Callable, cast from asyncio import Future split_pattern = re.compile(r'[,\s]+') -def run_markovich_discord(discord_config: Dict, eventloop = None) -> Callable[[], Future]: - client = discord.Client(loop = eventloop) - backends = MarkovManager() +class MarkovichDiscord(discord.Client): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.backends = MarkovManager() - @client.event - async def on_ready(): - print("Logged in as: {}".format(client.user)) + async def on_ready(self): + print("Logged in as: {}".format(self.user)) - @client.event - async def on_message(message: discord.Message): - if type(message.channel) is discord.DMChannel or message.author == client.user: return + async def on_message(self, message: discord.Message): + if message.author == self.user or message.is_system(): return + if not isinstance(message.channel, discord.TextChannel): return - print("{}@{} ==> {}".format(message.author, message.channel, message.content)) - - markov_chain = backends.get_markov(f"discord_{message.channel.guild.id}") - reply_length = 50 if client.user.mentioned_in(message) else 0 - - reply = markov_chain.record_and_generate(message.content, split_pattern, reply_length) + async with self.backends.get_markov(f"discord_{message.channel.guild.id}") as markov_chain: + reply_length = 50 if self.user.mentioned_in(message) else 0 + reply = await markov_chain.record_and_generate(message.content, split_pattern, reply_length) if reply: - print("{} <== {}".format(message.channel, reply)) - await message.channel.send(reply) - - aio_loop = client.loop - aio_loop.run_until_complete(client.start(discord_config['token'])) - return client.logout # Return cleanup function + # Any users not in that list will be tagged but not pinged + allowed_mentions = discord.AllowedMentions(everyone=False, roles=False, users=list({message.author, self.user, *message.mentions})) + await message.channel.send(reply, allowed_mentions=allowed_mentions) + + +def run_markovich_discord(discord_config: Dict): + client = MarkovichDiscord() + client.run(discord_config['token']) diff --git a/markovich/markovich_irc.py b/markovich/markovich_irc.py index 794dcd1..87a912e 100644 --- a/markovich/markovich_irc.py +++ b/markovich/markovich_irc.py @@ -18,16 +18,19 @@ async def on_message(self, target:str, source:str, message:str): if source == self.nickname: return is_mentionned = self.nickname.lower() in message.lower() - reply_length = 50 if is_mentionned else 0 - - markov_chain = self.backends.get_markov(f"{self.server_tag}_{target}") - reply = markov_chain.record_and_generate(message, split_pattern, reply_length) + + async with self.backends.get_markov(f"{self.server_tag}_{target}") as markov_chain: + reply_length = 50 if is_mentionned else 0 + reply = await markov_chain.record_and_generate(message, split_pattern, reply_length) if reply: await self.message(target, reply) -def run_markovich_irc(irc_configs: List[Dict], eventloop = None): +def run_markovich_irc(irc_configs: List[Dict]): + pool = pydle.ClientPool() + for irc_config in irc_configs: - client = MarkovichIRC(irc_config['username'], eventloop=eventloop) - client.eventloop.run_until_complete(client.connect(**irc_config['server'])) - + client = MarkovichIRC(irc_config['username']) + pool.connect(client, **irc_config['server']) + + pool.handle_forever() diff --git a/migrate_v1_to_v2.sql b/migrate_v1_to_v2.sql new file mode 100644 index 0000000..4ce8c64 --- /dev/null +++ b/migrate_v1_to_v2.sql @@ -0,0 +1,43 @@ +-- sqlite3 -echo test_db2.db ".read migrate_v1_to_v2.sql" +PRAGMA foreign_keys = true; +BEGIN TRANSACTION; + +CREATE TABLE IF NOT EXISTS words_v2( + word_id INTEGER PRIMARY KEY, -- Implicit ROWID + word TEXT UNIQUE NOT NULL -- Nulls are never unique +); + +CREATE TABLE IF NOT EXISTS chain_v2( + word_id1 INTEGER NOT NULL, + word_id2 INTEGER NOT NULL, + n INTEGER NOT NULL, + PRIMARY KEY(word_id1,word_id2), + FOREIGN KEY(word_id1) REFERENCES words_v2(word_id), + FOREIGN KEY(word_id2) REFERENCES words_v2(word_id) +); + +-- End/start of sentence sentinel, the "null word" +-- SQLite will never insert a ROWID of zero on its own, so this is safe +INSERT INTO words_v2(word_id, word) VALUES(0, '') ON CONFLICT DO NOTHING; + +INSERT INTO words_v2(word) + SELECT DISTINCT link1 FROM chain WHERE TRUE + UNION + SELECT DISTINCT link2 FROM chain WHERE TRUE +ON CONFLICT(word) DO NOTHING; + +INSERT INTO chain_v2(word_id1, word_id2, n) + SELECT + (SELECT word_id FROM main.words_v2 WHERE chain.link1 = words_v2.word OR (words_v2.word_id = 0 AND chain.link1 IS ' ')) AS wid1, + (SELECT word_id FROM main.words_v2 WHERE chain.link2 = words_v2.word OR (words_v2.word_id = 0 AND chain.link2 IS ' ')) AS wid2, + n + FROM chain + WHERE TRUE +ON CONFLICT (word_id1, word_id2) DO UPDATE SET n = chain_v2.n + EXCLUDED.n; + +CREATE INDEX IF NOT EXISTS chain_word_id1_idx ON chain_v2 (word_id1); + +--DROP TABLE chain; + +COMMIT; +VACUUM; diff --git a/requirements.txt b/requirements.txt index 3e529db..3a8b321 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py -pydle \ No newline at end of file +discord.py +pydle +aiosqlite \ No newline at end of file