Skip to content

Commit

Permalink
Update EasyAuditMiddleware to support async context
Browse files Browse the repository at this point in the history
Replaced standard threading with 'asgiref.local' in EasyAuditMiddleware. Also, made EasyAuditMiddleware extend Django's MiddlewareMixin to automatically handle sync and async execution modes.
Github issue: #291

Update EasyAuditMiddleware for async compatibility

The EasyAuditMiddleware class has been updated to be compatible with both synchronous and asynchronous processes. The class initialization now checks if the 'get_response' function is a coroutine and sets the class' async capability accordingly. An '__acall__' method has been introduced to handle asynchronous calls.

Refactor sync_to_async calls in test_main.py

Replaced sync_to_async keyword usage throughout the test_main.py file with direct asyncio calls, specifically in the ASGIRequestEvent tests. This change both simplifies the code and reduces reliance on the sync_to_async function.

Add test for async capability in middleware

A new test has been added to verify the async capability of the EasyAuditMiddleware. This ensures that the Django logger does not emit a debug message for asynchronous handler adaptation for the middleware. This is aligned with the Django documentation recommendations on async views.
  • Loading branch information
Kamil Niski committed Nov 25, 2024
1 parent 70e192b commit a495dd1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 24 deletions.
41 changes: 24 additions & 17 deletions easyaudit/middleware/easyaudit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# makes easy-audit thread-safe
import contextlib
from threading import local
from typing import Callable

from asgiref.local import Local
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.http.request import HttpRequest
from django.http.response import HttpResponse


class MockRequest:
Expand All @@ -10,7 +15,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


_thread_locals = local()
_thread_locals = Local()


def get_current_request():
Expand Down Expand Up @@ -38,30 +43,32 @@ def clear_request():


class EasyAuditMiddleware:
"""Makes request available to this app signals."""
async_capable = True
sync_capable = True

def __init__(self, get_response=None):
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
self.get_response = get_response
if iscoroutinefunction(self.get_response):
markcoroutinefunction(self)

def __call__(self, request):
_thread_locals.request = (
request # seems redundant w/process_request, but keeping in for now.
)
if hasattr(self, "process_request"):
response = self.process_request(request)
response = response or self.get_response(request)
if hasattr(self, "process_response"):
response = self.process_response(request, response)
return response
def __call__(self, request: HttpRequest) -> HttpResponse:
if iscoroutinefunction(self):
return self.__acall__(request)

def process_request(self, request):
_thread_locals.request = request
response = self.get_response(request)

def process_response(self, request, response):
with contextlib.suppress(AttributeError):
del _thread_locals.request

return response

def process_exception(self, request, exception):
async def __acall__(self, request: HttpRequest) -> HttpResponse:
_thread_locals.request = request

response = await self.get_response(request)

with contextlib.suppress(AttributeError):
del _thread_locals.request

return response
32 changes: 25 additions & 7 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging

import pytest
from asgiref.sync import sync_to_async
Expand Down Expand Up @@ -297,16 +298,16 @@ def test_middleware_logged_in_user_in_request(self, user, client):
class TestASGIRequestEvent:
async def test_login(self, async_user, async_client, username, password):
await sync_to_async(async_client.login)(username=username, password=password)
assert await sync_to_async(RequestEvent.objects.count)() == 0
assert await RequestEvent.objects.acount() == 0

resp = await async_client.get(reverse("test_app:index"))
assert resp.status_code == 200

qs = await sync_to_async(RequestEvent.objects.filter)(user=async_user)
assert await sync_to_async(qs.exists)()
qs = RequestEvent.objects.filter(user=async_user)
assert await qs.aexists()

async def test_remote_addr_default(self, async_client):
assert await sync_to_async(RequestEvent.objects.count)() == 0
assert await RequestEvent.objects.acount() == 0

resp = await async_client.request(
method="GET",
Expand All @@ -318,11 +319,11 @@ async def test_remote_addr_default(self, async_client):
)
assert resp.status_code == 200

event = await sync_to_async(RequestEvent.objects.get)(url=reverse("test_app:index"))
event = await RequestEvent.objects.aget(url=reverse("test_app:index"))
assert event.remote_ip == "127.0.0.1"

async def test_remote_addr_another(self, async_client):
assert await sync_to_async(RequestEvent.objects.count)() == 0
assert await RequestEvent.objects.acount() == 0

resp = await async_client.request(
method="GET",
Expand All @@ -335,9 +336,26 @@ async def test_remote_addr_another(self, async_client):
)
assert resp.status_code == 200

event = await sync_to_async(RequestEvent.objects.get)(url=reverse("test_app:index"))
event = await RequestEvent.objects.aget(url=reverse("test_app:index"))
assert event.remote_ip == "10.0.0.1"

async def test_middleware_is_async_capable(self, async_client, caplog, settings):
"""Test for async capability of EasyAuditMiddleware.
If the EasyAuditMiddleware is async capable Django `django.request` logger
will not emit debug message 'Asynchronous handler adapted for middleware …'
See: https://docs.djangoproject.com/en/5.0/topics/async/#async-views
"""
unwanted_log_message = (
"Asynchronous handler adapted for middleware "
"easyaudit.middleware.easyaudit.EasyAuditMiddleware"
)
settings.DEBUG = True
with caplog.at_level(logging.DEBUG, "django.request"):
await async_client.get(reverse("test_app:index"))
assert unwanted_log_message not in caplog.text


@pytest.mark.django_db
class TestWSGIRequestEvent:
Expand Down

0 comments on commit a495dd1

Please sign in to comment.