diff --git a/hushline/settings/__init__.py b/hushline/settings/__init__.py index e1657a2b..cc4f87a7 100644 --- a/hushline/settings/__init__.py +++ b/hushline/settings/__init__.py @@ -43,6 +43,7 @@ from .forms import ( ChangePasswordForm, ChangeUsernameForm, + DeleteBrandLogoForm, DirectoryVisibilityForm, DisplayNameForm, EmailForwardingForm, @@ -344,6 +345,7 @@ async def index() -> str | Response: update_brand_primary_color_form=UpdateBrandPrimaryColorForm(), update_brand_app_name_form=UpdateBrandAppNameForm(), update_brand_logo_form=UpdateBrandLogoForm(), + delete_brand_logo_form=DeleteBrandLogoForm(), email_forwarding_form=email_forwarding_form, change_password_form=change_password_form, change_username_form=change_username_form, @@ -690,39 +692,35 @@ def update_brand_app_name() -> Response | str: @bp.route("/update-brand-logo", methods=["POST"]) @admin_authentication_required def update_brand_logo() -> Response | str: - form = UpdateBrandLogoForm() - if form.validate_on_submit(): - public_store.put(OrganizationSetting.BRAND_LOGO_VALUE, form.logo.data) + update_form = UpdateBrandLogoForm() + delete_form = DeleteBrandLogoForm() + if update_form.validate_on_submit() and update_form.logo.data: + public_store.put(OrganizationSetting.BRAND_LOGO_VALUE, update_form.logo.data) OrganizationSetting.upsert( key=OrganizationSetting.BRAND_LOGO, value=OrganizationSetting.BRAND_LOGO_VALUE, ) flash("👍 Brand logo updated successfully.") - return redirect(url_for(".index")) - - flash("⛔ Invalid form data. Please try again.") - return redirect(url_for(".index")) - - @bp.route("/delete-brand-logo", methods=["POST"]) - @admin_authentication_required - def delete_brand_logo() -> Response | str: - row_count = db.session.execute( - db.delete(OrganizationSetting).where( - OrganizationSetting.key == OrganizationSetting.BRAND_LOGO - ) - ).rowcount - if row_count > 1: - current_app.logger.error( - "Would have deleted multiple rows for OrganizationSetting key=" - + OrganizationSetting.BRAND_LOGO - ) - db.session.rollback() - abort(503) - db.session.commit() + elif delete_form.validate_on_submit(): + row_count = db.session.execute( + db.delete(OrganizationSetting).where( + OrganizationSetting.key == OrganizationSetting.BRAND_LOGO + ) + ).rowcount + if row_count > 1: + current_app.logger.error( + "Would have deleted multiple rows for OrganizationSetting key=" + + OrganizationSetting.BRAND_LOGO + ) + db.session.rollback() + abort(503) + db.session.commit() - public_store.delete(OrganizationSetting.BRAND_LOGO_VALUE) + public_store.delete(OrganizationSetting.BRAND_LOGO_VALUE) - flash("👍 Brand logo deleted.") + flash("👍 Brand logo deleted.") + else: + flash("⛔ Invalid form data. Please try again.") return redirect(url_for(".index")) @bp.route("/delete-account", methods=["POST"]) diff --git a/hushline/settings/forms.py b/hushline/settings/forms.py index 5356c33b..74386ea5 100644 --- a/hushline/settings/forms.py +++ b/hushline/settings/forms.py @@ -3,23 +3,42 @@ from flask import current_app from flask_wtf import FlaskForm -from flask_wtf.file import FileAllowed, FileField, FileRequired, FileSize +from flask_wtf.file import FileAllowed, FileField, FileSize +from markupsafe import Markup from wtforms import ( BooleanField, + Field, FormField, IntegerField, PasswordField, SelectField, StringField, + SubmitField, TextAreaField, ) from wtforms.validators import DataRequired, Email, Length from wtforms.validators import Optional as OptionalField +from wtforms.widgets.core import html_params from ..forms import CanonicalHTML, ComplexPassword, HexColor from ..model import SMTPEncryption +class Button: + html_params = staticmethod(html_params) + + def __call__(self, field: Field, **kwargs: Any) -> Markup: + kwargs.setdefault("id", field.id) + kwargs.setdefault("type", "submit") + kwargs.setdefault("value", field.label.text) + if "value" not in kwargs: + kwargs["value"] = field._value() + if "required" not in kwargs and "required" in getattr(field, "flags", []): + kwargs["required"] = True + params = self.html_params(name=field.name, **kwargs) + return Markup(f"") + + class ChangePasswordForm(FlaskForm): old_password = PasswordField("Old Password", validators=[DataRequired()]) new_password = PasswordField( @@ -192,8 +211,14 @@ class UpdateBrandLogoForm(FlaskForm): logo = FileField( "Logo (.png only)", validators=[ - FileRequired(), + # NOTE: not present because the same form w/ 2 submit buttonts is used for deletions + # FileRequired() FileAllowed(["png"], "Only PNG files are allowed"), FileSize(256 * 1000), # 256 KB ], ) + submit = SubmitField("Update Logo", name="update_logo", widget=Button()) + + +class DeleteBrandLogoForm(FlaskForm): + submit = SubmitField("Delete Logo", name="submit_logo", widget=Button()) diff --git a/hushline/storage.py b/hushline/storage.py index ac35e993..73be03f2 100644 --- a/hushline/storage.py +++ b/hushline/storage.py @@ -58,8 +58,10 @@ class FsDriver(StorageDriver): - BLOB_STORAGE_FS_ROOT """ - def __init__(self, app: Flask, config_prefix: Optional[str] = None) -> None: - super().__init__(config_prefix) + def __init__( + self, app: Flask, config_prefix: Optional[str] = None, is_public: bool = False + ) -> None: + super().__init__(config_prefix, is_public) root = Path(app.config[self._config_name("FS_ROOT")]) if root.absolute() != root: # needed for security checks later @@ -99,8 +101,10 @@ class S3Driver(StorageDriver): - BLOB_STORAGE_S3_CDN_ENDPOINT """ - def __init__(self, app: Flask, config_prefix: Optional[str] = None) -> None: - super().__init__(config_prefix) + def __init__( + self, app: Flask, config_prefix: Optional[str] = None, is_public: bool = False + ) -> None: + super().__init__(config_prefix, is_public) self.__bucket = app.config[self._config_name("S3_BUCKET")] self.__cdn_endpoint = app.config[self._config_name("S3_CDN_ENDPOINT")] @@ -125,21 +129,28 @@ def put(self, path: str, readable: IOBase) -> None: Key=path, Body=readable, ContentType=self.mime_type(path), - ACL="public" if self._is_public else "private", + ACL="public-read" if self._is_public else "private", ) def delete(self, path: str) -> None: self._client.delete_object(Bucket=self.__bucket, Key=path) def serve(self, path: str) -> Response: - url = self._client.generate_presigned_url( - ClientMethod="get_object", - Params={ - "Bucket": self.__bucket, - "Key": path, - }, - ExpiresIn=3600, - ) + if self._is_public: + url = ( + self.__cdn_endpoint + + ("" if self.__cdn_endpoint.endswith("/") or path.startswith("/") else "/") + + path + ) + else: + url = self._client.generate_presigned_url( + ClientMethod="get_object", + Params={ + "Bucket": self.__bucket, + "Key": path, + }, + ExpiresIn=3600, + ) return redirect(url) @@ -157,9 +168,9 @@ def init_app(self, app: Flask) -> None: driver: Optional[StorageDriver] match app.config.get(self._config_name("DRIVER")) or None: case "s3": - driver = S3Driver(app, self._config_prefix) + driver = S3Driver(app, self._config_prefix, self._is_public) case "file-system": - driver = FsDriver(app, self._config_prefix) + driver = FsDriver(app, self._config_prefix, self._is_public) case "none": driver = None case None: @@ -186,4 +197,4 @@ def serve(self, path: str) -> Response: return self._driver.serve(path) -public_store = BlobStorage("PUBLIC") +public_store = BlobStorage("PUBLIC", is_public=True) diff --git a/hushline/templates/settings/branding.html b/hushline/templates/settings/branding.html index 2db3e5d3..562539a3 100644 --- a/hushline/templates/settings/branding.html +++ b/hushline/templates/settings/branding.html @@ -50,15 +50,8 @@

Logo

> {{ update_brand_logo_form.hidden_tag() }} {{ update_brand_logo_form.logo(accept=".png") }} - - - -
- + {{ update_brand_logo_form.submit(class="btn") }} + {{ delete_brand_logo_form.submit(class="btn-danger") }}
{% endif %} diff --git a/tests/test_settings.py b/tests/test_settings.py index a2d8eb85..9bf1e352 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -625,7 +625,10 @@ def test_update_brand_logo(client: FlaskClient, admin: User) -> None: resp = client.post( url_for("settings.update_brand_logo"), - data={"logo": (BytesIO(png), "wat.png")}, + data={ + "logo": (BytesIO(png), "wat.png"), + "update_logo": "", + }, follow_redirects=True, content_type="multipart/form-data", ) @@ -650,3 +653,16 @@ def test_update_brand_logo(client: FlaskClient, admin: User) -> None: resp = client.get(logo_url, follow_redirects=True) assert resp.status_code == 200 assert resp.data == png + + resp = client.post( + url_for("settings.update_brand_logo"), + data={"delete_logo": ""}, + follow_redirects=True, + ) + assert resp.status_code == 200 + assert "Brand logo deleted" in resp.text + + # check the file is not accessible + resp = client.get(logo_url, follow_redirects=True) + # yes this check is ridiculous. why? because we redirect not-founds instead of actually 404-ing + assert "That page doesn" in resp.text