Skip to content

Commit

Permalink
updates tests
Browse files Browse the repository at this point in the history
  • Loading branch information
saxix committed Nov 27, 2024
1 parent c061d76 commit c8220a0
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 28 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ requires-python = ">=3.10"
dependencies = [
"celery>=5.4.0",
"django-admin-extra-buttons>=1.5.8",
"django-concurrency>=2.6"
"django-concurrency>=2.6",
"django==4.2.*",
]

[project.optional-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion src/django_celery_boost/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,10 @@ def task_info(self) -> "dict[str, Any]":
# "id": self.async_result.id,
"started_at": started_at,
"completed_at": date_done,
"last_update": date_done,
"status": task_status,
"error": error,
"result": query_result_id,
"query_result_id": query_result_id,
}
return ret

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def pytest_configure(config):
settings.CELERY_TASK_ALWAYS_EAGER = False
settings.CELERY_BROKER_URL = os.environ.get("CELERY_BROKER_URL")
settings.DEMOAPP_PATH = DEMOAPP_PATH
settings.MESSAGE_STORAGE = "demo.messages.PlainCookieStorage"

from celery.fixups.django import DjangoWorkerFixup

Expand Down
183 changes: 183 additions & 0 deletions tests/demoapp/demo/messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Code copied from https://github.com/saxix/django-adminactions - tests/demo/storage.py
# type: ignore
import json
from typing import Any, Optional

from django.conf import settings
from django.contrib.messages.storage.base import BaseStorage, Message
from django.core import signing
from django.http import SimpleCookie
from django.utils.crypto import constant_time_compare, salted_hmac
from django.utils.safestring import SafeData, mark_safe


class MessageEncoder(json.JSONEncoder):
"""
Compactly serialize instances of the ``Message`` class as JSON.
"""

message_key = "__json_message"

def default(self, obj: Any) -> Any:
if isinstance(obj, Message):
# Using 0/1 here instead of False/True to produce more compact json
is_safedata = 1 if isinstance(obj.message, SafeData) else 0
message = [self.message_key, is_safedata, obj.level, obj.message]
if obj.extra_tags:
message.append(obj.extra_tags)
return message
return super().default(obj)


class MessageDecoder(json.JSONDecoder):
"""
Decode JSON that includes serialized ``Message`` instances.
"""

def process_messages(self, obj: Any) -> Any:
if isinstance(obj, list) and obj:
if obj[0] == MessageEncoder.message_key:
if obj[1]:
obj[3] = mark_safe(obj[3])
return Message(*obj[2:])
return [self.process_messages(item) for item in obj]
if isinstance(obj, dict):
return {key: self.process_messages(value) for key, value in obj.items()}
return obj

def decode(self, s: Any, **kwargs: Any) -> Any:
decoded = super().decode(s, **kwargs)
return self.process_messages(decoded)


class PlainCookieStorage(BaseStorage):
"""
Store messages in a cookie.
"""

cookie_name = "messages"
# uwsgi's default configuration enforces a maximum size of 4kb for all the
# HTTP headers. In order to leave some room for other cookies and headers,
# restrict the session cookie to 1/2 of 4kb. See #18781.
max_cookie_size = 2048
not_finished = "__messagesnotfinished__"
key_salt = "django.contrib.messages"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.signer = signing.get_cookie_signer(salt=self.key_salt)

def _get(self, *args, **kwargs) -> Any:
"""
Retrieve a list of messages from the messages cookie. If the
not_finished sentinel value is found at the end of the message list,
remove it and return a result indicating that not all messages were
retrieved by this storage.
"""
data = self.request.COOKIES.get(self.cookie_name)
messages = self._decode(data)
all_retrieved = not (messages and messages[-1] == self.not_finished)
if messages and not all_retrieved:
# remove the sentinel value
messages.pop()
return messages, all_retrieved

def _update_cookie(self, encoded_data, response) -> Any:
"""
Either set the cookie with the encoded data if there is any data to
store, or delete the cookie.
"""
if encoded_data:
response.set_cookie(
self.cookie_name, encoded_data, domain=settings.SESSION_COOKIE_DOMAIN, secure=None, httponly=None
)
else:
response.delete_cookie(self.cookie_name, domain=settings.SESSION_COOKIE_DOMAIN)

def _store(self, messages, response, remove_oldest=True, *args, **kwargs) -> Any:
"""
Store the messages to a cookie and return a list of any messages which
could not be stored.
If the encoded data is larger than ``max_cookie_size``, remove
messages until the data fits (these are the messages which are
returned), and add the not_finished sentinel value to indicate as much.
"""
unstored_messages = []
encoded_data = self._encode(messages)
if self.max_cookie_size:
# data is going to be stored eventually by SimpleCookie, which
# adds its own overhead, which we must account for.
cookie = SimpleCookie() # create outside the loop

def stored_length(val):
return len(cookie.value_encode(val)[1])

while encoded_data and stored_length(encoded_data) > self.max_cookie_size:
if remove_oldest:
unstored_messages.append(messages.pop(0))
else:
unstored_messages.insert(0, messages.pop())
encoded_data = self._encode(messages + [self.not_finished], encode_empty=unstored_messages)
self._update_cookie(encoded_data, response)
return unstored_messages

def _legacy_hash(self, value) -> Any:
"""
# RemovedInDjango40Warning: pre-Django 3.1 hashes will be invalid.
Create an HMAC/SHA1 hash based on the value and the project setting's
SECRET_KEY, modified to make it unique for the present purpose.
"""
# The class wide key salt is not reused here since older Django
# versions had it fixed and making it dynamic would break old hashes if
# self.key_salt is changed.
key_salt = "django.contrib.messages"
return salted_hmac(key_salt, value).hexdigest()

def _encode(self, messages: str, encode_empty: bool = False) -> Any:
"""
Return an encoded version of the messages list which can be stored as
plain text.
Since the data will be retrieved from the client-side, the encoded data
also contains a hash to ensure that the data was not tampered with.
"""
if messages or encode_empty:
encoder = MessageEncoder(separators=(",", ":"))
value = encoder.encode(messages)
return self.signer.sign(value)

def _decode(self, data: str) -> Optional[Message]:
"""
Safely decode an encoded text stream back into a list of messages.
If the encoded text stream contained an invalid hash or was in an
invalid format, return None.
"""
if not data:
return None
try:
decoded = self.signer.unsign(data)
except signing.BadSignature:
# RemovedInDjango40Warning: when the deprecation ends, replace
# with:
# decoded = None.
decoded = self._legacy_decode(data)
if decoded:
try:
return json.loads(decoded, cls=MessageDecoder)
except json.JSONDecodeError:
pass
# Mark the data as used (so it gets removed) since something was wrong
# with the data.
self.used = True
return None

def _legacy_decode(self, data: Optional[Message]) -> Optional[Message]:
# RemovedInDjango40Warning: pre-Django 3.1 hashes will be invalid.
bits = data.split("$", 1)
if len(bits) == 2:
hash_, value = bits
if constant_time_compare(hash_, self._legacy_hash(value)):
return value
return None
56 changes: 33 additions & 23 deletions tests/test_admin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from unittest import mock

import pytest
from demo.factories import user_grant_permission
from demo.models import Job
from django.urls import reverse

pytestmark = [pytest.mark.admin]
Expand Down Expand Up @@ -73,47 +76,54 @@ def test_celery_queue(request, django_app, std_user, job):
res = django_app.get(url, user=std_user, expect_errors=True)
assert res.status_code == 403

with user_grant_permission(std_user, ["demo.queue_job"]):
with user_grant_permission(std_user, ["demo.queue_job", "demo.change_job"]):
res = django_app.get(url, user=std_user)
assert res.status_code == 200
res = res.forms[1].submit()
assert res.status_code == 302

res = django_app.get(url, user=std_user)
res = res.forms[1].submit()
assert res.status_code == 302
res = res.forms[1].submit().follow()
msgs = res.context["messages"]
assert [m.message for m in msgs] == ["Queued"]

res = django_app.get(url, user=std_user).follow()
msgs = res.context["messages"]
assert [m.message for m in msgs] == ["Task has already been queued."]


def test_celery_terminate(request, django_app, std_user, job):
url = reverse("admin:demo_job_celery_terminate", args=[job.pk])
res = django_app.get(url, user=std_user, expect_errors=True)
assert res.status_code == 403

with user_grant_permission(std_user, ["demo.terminate_job"]):
res = django_app.get(url, user=std_user)
assert res.status_code == 200
res = res.forms[1].submit()
assert res.status_code == 302
with user_grant_permission(std_user, ["demo.terminate_job", "demo.change_job"]):
res = django_app.get(url, user=std_user).follow()
msgs = res.context["messages"]
assert [m.message for m in msgs] == ["Task not queued."]

res = django_app.get(url, user=std_user)
res = res.forms[1].submit()
assert res.status_code == 302
with mock.patch.object(Job, "is_queued") as m:
m.return_value = True
res = django_app.get(url, user=std_user)
res = res.forms[1].submit().follow()
msgs = res.context["messages"]
assert [m.message for m in msgs] == ["Terminated"]


def test_celery_revoke(request, django_app, std_user, job):
url = reverse("admin:demo_job_celery_revoke", args=[job.pk])
res = django_app.get(url, user=std_user, expect_errors=True)
assert res.status_code == 403

with user_grant_permission(std_user, ["demo.revoke_job"]):
res = django_app.get(url, user=std_user)
assert res.status_code == 200
res = res.forms[1].submit()
assert res.status_code == 302

res = django_app.get(url, user=std_user)
res = res.forms[1].submit()
assert res.status_code == 302
with mock.patch.object(job, "is_queued", lambda: True):
with user_grant_permission(std_user, ["demo.revoke_job", "demo.change_job"]):
res = django_app.get(url, user=std_user).follow()
msgs = res.context["messages"]
assert [m.message for m in msgs] == ["Task not queued."]

with mock.patch.object(Job, "is_queued") as m:
m.return_value = True
res = django_app.get(url, user=std_user)
res = res.forms[1].submit().follow()
msgs = res.context["messages"]
assert [m.message for m in msgs] == ["Revoked"]


def test_check_status(request, django_app, std_user, job, queued):
Expand Down
8 changes: 5 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,17 @@ def test_model_task_info(db):
job1: Job = JobFactory()
assert job1.version == 1

assert job1.task_info == {"status": Job.NOT_SCHEDULED}
assert job1.task_info == {"status": Job.NOT_SCHEDULED, "completed_at": ""}
assert job1.queue()
job1.refresh_from_db()
assert job1.version == 1
assert job1.task_info == {
"error": "",
"completed_at": None,
"last_update": None,
"query_result_id": None,
"result": None,
"started_at": 0,
"started_at": "-",
"status": Job.PENDING,
}
with mock.patch(
Expand All @@ -127,7 +128,8 @@ def test_model_task_info(db):
"last_update": None,
"query_result_id": None,
"result": None,
"started_at": 0,
"started_at": "-",
"completed_at": None,
"status": Job.REVOKED,
}

Expand Down

0 comments on commit c8220a0

Please sign in to comment.