Skip to content

Commit

Permalink
use session interface for sqlalchemy
Browse files Browse the repository at this point in the history
  • Loading branch information
brassy-endomorph committed Sep 17, 2024
1 parent 92bf51b commit bb524a8
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 63 deletions.
15 changes: 6 additions & 9 deletions hushline/make_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,25 @@

from hushline import create_app
from hushline.db import db
from hushline.model import User
from hushline.models import Username


def toggle_admin(username: str) -> None:
user = User.query.filter_by(primary_username=username).one_or_none()
if not user:
uname = db.session.scalars(db.select(Username).filter_by(_username=username)).one_or_none()
if not uname:
print("User not found.")
return

# Toggle admin status
user.is_admin = not user.is_admin
uname.user.is_admin = not uname.user.is_admin
db.session.commit()

print(f"User {username} admin status toggled to {user.is_admin}.")
print(f"User {username} admin status toggled to {uname.user.is_admin}.")


if __name__ == "__main__":
if len(sys.argv) != 2: # noqa: PLR2004
print("Usage: python make_admin.py <username>")
sys.exit(1)

username = sys.argv[1]

with create_app().app_context():
toggle_admin(username)
toggle_admin(sys.argv[1])
31 changes: 19 additions & 12 deletions hushline/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from flask_wtf import FlaskForm
from sqlalchemy import select
from sqlalchemy.sql import exists
from werkzeug.wrappers.response import Response
from wtforms import Field, Form, PasswordField, StringField, TextAreaField
from wtforms.validators import DataRequired, Length, Optional, ValidationError
Expand Down Expand Up @@ -131,7 +130,7 @@ def inbox() -> Response | str:
@app.route("/to/<username>", methods=["GET"])
def profile(username: str) -> Response | str:
form = MessageForm()
uname = Username.query.filter_by(_username=username).one_or_none()
uname = db.session.scalars(db.select(Username).filter_by(_username=username)).one_or_none()
if not uname:
flash("🫥 User not found.")
return redirect(url_for("index"))
Expand Down Expand Up @@ -179,7 +178,7 @@ def profile(username: str) -> Response | str:
@app.route("/to/<username>", methods=["POST"])
def submit_message(username: str) -> Response | str:
form = MessageForm()
uname = Username.query.filter_by(_username=username).one_or_none()
uname = db.session.scalars(db.select(Username).filter_by(_username=username)).one_or_none()
if not uname:
flash("🫥 User not found.")
return redirect(url_for("index"))
Expand Down Expand Up @@ -280,17 +279,23 @@ def delete_message(message_id: int) -> Response:
flash("🔑 Please log in to continue.")
return redirect(url_for("login"))

user = User.query.get(session.get("user_id"))
user = db.session.get(User, session.get("user_id"))
if not user:
flash("🫥 User not found. Please log in again.")
return redirect(url_for("login"))

row_count = Message.query.filter(
Message.id == message_id,
Message.username_id.in_(
select(Username.user_id).select_from(Username).filter(Username.user_id == user.id)
),
).delete()
row_count = (
db.delete(Message)
.where(
Message.id == message_id,
Message.username_id.in_(
select(Username.user_id)
.select_from(Username)
.filter(Username.user_id == user.id)
),
)
.delete()
)
match row_count:
case 1:
db.session.commit()
Expand Down Expand Up @@ -328,7 +333,9 @@ def register() -> Response | str | tuple[Response | str, int]:

invite_code_input = form.invite_code.data if require_invite_code else None
if invite_code_input:
invite_code = InviteCode.query.filter_by(code=invite_code_input).one_or_none()
invite_code = db.session.scalars(
db.select(InviteCode).filter_by(code=invite_code_input)
).one_or_none()
if not invite_code or invite_code.expiration_date.replace(
tzinfo=UTC
) < datetime.now(UTC):
Expand All @@ -342,7 +349,7 @@ def register() -> Response | str | tuple[Response | str, int]:
400,
)

if db.session.query(exists(Username).where(Username._username == username)).scalar():
if db.session.query(db.exists(Username).where(Username._username == username)).scalar():
flash("💔 Username already taken.", "error")
return (
render_template(
Expand Down
30 changes: 17 additions & 13 deletions hushline/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
session,
url_for,
)
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import exists
from werkzeug.wrappers.response import Response
from wtforms import Field

Expand Down Expand Up @@ -153,7 +153,7 @@ def handle_change_username_form(

# TODO a better pattern would be to try to commit, catch the exception, and match
# on the name of the unique index that errored
if db.session.query(exists(Username).where(Username._username == new_username)).scalar():
if db.session.query(db.exists(Username).where(Username._username == new_username)).scalar():
flash("💔 This username is already taken.")
else:
username.username = new_username
Expand All @@ -169,13 +169,12 @@ def handle_change_username_form(
def handle_new_alias_form(user: User, new_alias_form: NewAliasForm, redirect_url: str) -> Response:
current_app.logger.debug("Creating alias for {user.primary_username.username}")
# TODO check that users are allowed to add aliases here (is premium, not too many)
# TODO check that alias is not yet taken
uname = Username(_username=new_alias_form.username.data, user_id=user.id, is_primary=False)
db.session.add(uname)
try:
db.session.commit()
except IntegrityError as e:
if 'duplicate key value violates unique constraint "usernames_username_key"' in str(e):
if isinstance(e.orig, UniqueViolation) and '"usernames_username_key"' in str(e.orig):
flash("💔 This username is already taken.")
else:
flash("⛔️ Internal server error. Alias not created.")
Expand Down Expand Up @@ -240,11 +239,11 @@ async def index() -> str | Response:
)
flash("Uh oh. There was an error handling your data. Please notify the admin.")

aliases = (
Username.query.filter_by(is_primary=False, user_id=user.id)
aliases = db.session.scalars(
db.select(Username)
.filter_by(is_primary=False, user_id=user.id)
.order_by(db.func.coalesce(Username._display_name, Username._username))
.all()
)
).all()
# Additional admin-specific data initialization
user_count = two_fa_count = pgp_key_count = two_fa_percentage = pgp_key_percentage = None
all_users = []
Expand All @@ -264,7 +263,11 @@ async def index() -> str | Response:
user_count = len(all_users)
two_fa_percentage = (two_fa_count / user_count * 100) if user_count else 0
pgp_key_percentage = (pgp_key_count / user_count * 100) if user_count else 0
all_users = list(User.query.join(Username).order_by(Username._username).all())
all_users = list(
db.session.scalars(
db.select(User).join(Username).order_by(Username._username)
).all()
)

# Prepopulate form fields
email_forwarding_form.forwarding_enabled.data = user.email is not None
Expand Down Expand Up @@ -620,9 +623,10 @@ def delete_account() -> Response | str:
@authentication_required
@bp.route("/alias/<int:username_id>", methods=["GET", "POST"])
async def alias(username_id: int) -> Response | str:
user = User.query.get(session["user_id"])
alias = Username.query.filter_by(
id=username_id, user_id=user.id, is_primary=False
alias = db.session.scalars(
db.select(Username).filter_by(
id=username_id, user_id=session["user_id"], is_primary=False
)
).one_or_none()
if not alias:
flash("Alias not found.")
Expand Down Expand Up @@ -658,7 +662,7 @@ async def alias(username_id: int) -> Response | str:

return render_template(
"settings/alias.html",
user=user,
user=alias.user,
alias=alias,
display_name_form=display_name_form,
directory_visibility_form=directory_visibility_form,
Expand Down
24 changes: 11 additions & 13 deletions tests/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,36 @@
import pyotp
from flask.testing import FlaskClient

from hushline.db import db
from hushline.model import AuthenticationLog, User, Username


def register_user(client: FlaskClient, username: str, password: str) -> User:
# Prepare the environment to not require invite codes
os.environ["REGISTRATION_CODES_REQUIRED"] = "False"

# User registration data
user_data = {"username": username, "password": password}

# Post request to register a new user
response = client.post("/register", data=user_data, follow_redirects=True)

# Validate response
assert response.status_code == 200
assert b"Registration successful!" in response.data

# Verify user is added to the database
user = User.query.join(Username).filter(Username._username == username).one_or_none()
assert user is not None
user = db.session.scalars(
db.select(User).join(Username).filter(Username._username == username)
).one()
assert user.primary_username.username == username

return user


def register_user_2fa(client: FlaskClient, username: str, password: str) -> tuple[User, str]:
# Register a new user
user_data = {"username": username, "password": password}
response = client.post("/register", data=user_data, follow_redirects=True)
assert response.status_code == 200

# Verify user is added to the database
user = User.query.join(Username).filter(Username._username == username).one_or_none()
assert user is not None
user = db.session.scalars(
db.select(User).join(Username).filter(Username._username == username)
).one()
assert user.primary_username.username == username

# And 2FA is disabled
Expand Down Expand Up @@ -75,7 +71,7 @@ def register_user_2fa(client: FlaskClient, username: str, password: str) -> tupl
assert "Enter your 2FA Code" in login_response.text

# Modify the timestamps on the AuthenticationLog entries to allow for 2FA verification
for log in AuthenticationLog.query.all():
for log in db.session.scalars(db.select(AuthenticationLog)).all():
log.timestamp = datetime.now() - timedelta(minutes=5)

return (user, totp_secret)
Expand All @@ -95,7 +91,9 @@ def login_user(client: FlaskClient, username: str, password: str) -> User | None
f'href="/inbox?username={username}"'.encode() in response.data
), f"Inbox link should be present for the user {username}"

if username := Username.query.filter_by(_username=username).one_or_none():
if username := db.session.scalars(
db.select(Username).filter_by(_username=username)
).one_or_none():
return username.user
return None

Expand Down
15 changes: 12 additions & 3 deletions tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def test_profile_submit_message(client: FlaskClient) -> None:
assert response.status_code == 200
assert b"Message submitted successfully." in response.data

message = Message.query.filter_by(username_id=user.primary_username.id).one()
message = db.session.scalars(
db.select(Message).filter_by(username_id=user.primary_username.id)
).one()
assert message.content == msg_content

response = client.get(url_for("inbox", unamename=username), follow_redirects=True)
Expand Down Expand Up @@ -75,7 +77,9 @@ def test_profile_submit_message_with_contact_method(client: FlaskClient) -> None
assert response.status_code == 200
assert b"Message submitted successfully." in response.data

message = Message.query.filter_by(username_id=user.primary_username.id).one_or_none()
message = db.session.scalars(
db.select(Message).filter_by(username_id=user.primary_username.id)
).one()
expected_content = f"Contact Method: {contact_method}\n\n{message_content}"
assert message.content == expected_content

Expand Down Expand Up @@ -172,4 +176,9 @@ def test_profile_submit_message_with_invalid_captcha(client: FlaskClient) -> Non
assert message_content.encode() in response.data

# Verify that the message is not saved in the database
assert not Message.query.filter_by(username_id=user.primary_username.id).one_or_none()
assert (
db.session.scalars(
db.select(Message).filter_by(username_id=user.primary_username.id)
).one_or_none()
is None
)
4 changes: 2 additions & 2 deletions tests/test_registration_and_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_user_registration_with_invite_code_disabled(client: FlaskClient) -> Non
assert response.status_code == 200
assert "Registration successful!" in response.text

uname = Username.query.filter_by(_username=username).one()
uname = db.session.scalars(db.select(Username).filter_by(_username=username)).one()
assert uname.username == username


Expand All @@ -42,7 +42,7 @@ def test_user_registration_with_invite_code_enabled(client: FlaskClient) -> None
assert response.status_code == 200
assert "Registration successful!" in response.text

uname = Username.query.filter_by(_username=username).one()
uname = db.session.scalars(db.select(Username).filter_by(_username=username)).one()
assert uname.username == "newuser"


Expand Down
Loading

0 comments on commit bb524a8

Please sign in to comment.