Skip to content

Commit

Permalink
Allow CLI override PostgreSQL connection settings (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
janbjorge authored Apr 11, 2024
1 parent 8102387 commit 9f7e397
Showing 1 changed file with 114 additions and 46 deletions.
160 changes: 114 additions & 46 deletions src/pgcachewatch/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os
import sys

import asyncpg
Expand Down Expand Up @@ -38,6 +39,63 @@ def cliparser() -> argparse.Namespace:
help="Commit changes to database.",
)

common_arguments.add_argument(
"--pg-dsn",
help=(
"Connection string in the libpq URI format, including host, port, user, "
"database, password, passfile, and SSL options. Must be properly quoted; "
"IPv6 addresses must be in brackets. "
"Example: postgres://user:pass@host:port/database. Defaults to PGDSN "
"environment variable if set."
),
default=os.environ.get("PGDSN"),
)

common_arguments.add_argument(
"--pg-host",
help=(
"Database host address, which can be an IP or domain name. "
"Defaults to PGHOST environment variable if set."
),
default=os.environ.get("PGHOST"),
)

common_arguments.add_argument(
"--pg-port",
help=(
"Port number for the server host Defaults to PGPORT environment variable "
"or 5432 if not set."
),
default=os.environ.get("PGPORT", "5432"),
)

common_arguments.add_argument(
"--pg-user",
help=(
"Database role for authentication. Defaults to PGUSER environment "
"variable if set."
),
default=os.environ.get("PGUSER"),
)

common_arguments.add_argument(
"--pg-database",
help=(
"Name of the database to connect to. Defaults to PGDATABASE environment "
"variable if set."
),
default=os.environ.get("PGDATABASE"),
)

common_arguments.add_argument(
"--pg-password",
help=(
"Password for authentication. Defaults to PGPASSWORD "
"environment variable if set"
),
default=os.environ.get("PGPASSWORD"),
)

parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
prog="pgcachewatch",
Expand Down Expand Up @@ -67,53 +125,63 @@ async def main() -> None:
pg_fn_name = f"{parsed.function_name}_{parsed.channel_name}"
pg_tg_name = f"{parsed.trigger_name}_{parsed.channel_name}"

match parsed.command:
case "install":
install = "\n".join(
[
queries.create_notify_function(
channel_name=parsed.channel_name,
function_name=pg_fn_name,
)
]
+ [
queries.create_after_change_trigger(
trigger_name=pg_tg_name,
table_name=table,
function_name=pg_fn_name,
)
for table in parsed.tables
]
)

print(install, flush=True)

if parsed.commit:
await (await asyncpg.connect()).execute(install)
else:
print(
"::: Use '--commit' to write changes to db. :::",
file=sys.stderr,
async with asyncpg.create_pool(
parsed.pg_dsn,
database=parsed.pg_database,
password=parsed.pg_password,
port=parsed.pg_port,
user=parsed.pg_user,
host=parsed.pg_host,
min_size=0,
max_size=1,
) as pool:
match parsed.command:
case "install":
install = "\n".join(
[
queries.create_notify_function(
channel_name=parsed.channel_name,
function_name=pg_fn_name,
)
]
+ [
queries.create_after_change_trigger(
trigger_name=pg_tg_name,
table_name=table,
function_name=pg_fn_name,
)
for table in parsed.tables
]
)

case "uninstall":
trigger_names = await (await asyncpg.connect()).fetch(
queries.fetch_trigger_names(pg_tg_name),
)
combined = "\n".join(
(
"\n".join(
queries.drop_trigger(t["trigger_name"], t["table"])
for t in trigger_names
),
queries.drop_function(pg_fn_name),
print(install, flush=True)

if parsed.commit:
await pool.execute(install)
else:
print(
"::: Use '--commit' to write changes to db. :::",
file=sys.stderr,
)

case "uninstall":
trigger_names = await pool.fetch(
queries.fetch_trigger_names(pg_tg_name),
)
)
print(combined, flush=True)
if parsed.commit:
await (await asyncpg.connect()).execute(combined)
else:
print(
"::: Use '--commit' to write changes to db. :::",
file=sys.stderr,
combined = "\n".join(
(
"\n".join(
queries.drop_trigger(t["trigger_name"], t["table"])
for t in trigger_names
),
queries.drop_function(pg_fn_name),
)
)
print(combined, flush=True)
if parsed.commit:
await pool.execute(combined)
else:
print(
"::: Use '--commit' to write changes to db. :::",
file=sys.stderr,
)

0 comments on commit 9f7e397

Please sign in to comment.