Skip to content

Commit a99eaa6

Browse files
authored
Merge pull request #4 from amirreza8002/middleware3
implement make_middleware-decorator and related utils
2 parents 11dfb0b + 8b48db2 commit a99eaa6

File tree

6 files changed

+385
-0
lines changed

6 files changed

+385
-0
lines changed

django_async_extensions/utils/__init__.py

Whitespace-only changes.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from functools import wraps
2+
3+
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
4+
5+
6+
def decorator_from_middleware_with_args(middleware_class):
7+
"""
8+
Like decorator_from_middleware, but return a function
9+
that accepts the arguments to be passed to the middleware_class.
10+
Use like::
11+
12+
cache_page = decorator_from_middleware_with_args(CacheMiddleware)
13+
# ...
14+
15+
@cache_page(3600)
16+
def my_view(request):
17+
# ...
18+
"""
19+
return make_middleware_decorator(middleware_class)
20+
21+
22+
def decorator_from_middleware(middleware_class):
23+
"""
24+
Given a middleware class (not an instance), return a view decorator. This
25+
lets you use middleware functionality on a per-view basis. The middleware
26+
is created with no params passed.
27+
"""
28+
return make_middleware_decorator(middleware_class)()
29+
30+
31+
def make_middleware_decorator(middleware_class):
32+
def _make_decorator(*m_args, **m_kwargs):
33+
def _decorator(view_func):
34+
middleware = middleware_class(view_func, *m_args, **m_kwargs)
35+
36+
async def _pre_process_request(request, *args, **kwargs):
37+
if hasattr(middleware, "process_request"):
38+
result = await middleware.process_request(request)
39+
if result is not None:
40+
return result
41+
if hasattr(middleware, "process_view"):
42+
if iscoroutinefunction(middleware.process_view):
43+
result = await middleware.process_view(
44+
request, view_func, args, kwargs
45+
)
46+
else:
47+
result = await sync_to_async(middleware.process_view)(
48+
request, view_func, args, kwargs
49+
)
50+
if result is not None:
51+
return result
52+
return None
53+
54+
async def _process_exception(request, exception):
55+
if hasattr(middleware, "process_exception"):
56+
if iscoroutinefunction(middleware.process_exception):
57+
result = await middleware.process_exception(request, exception)
58+
else:
59+
result = await sync_to_async(middleware.process_exception)(
60+
request, exception
61+
)
62+
if result is not None:
63+
return result
64+
raise
65+
66+
async def _post_process_request(request, response):
67+
if hasattr(response, "render") and callable(response.render):
68+
if hasattr(middleware, "process_template_response"):
69+
if iscoroutinefunction(middleware.process_template_response):
70+
response = await middleware.process_template_response(
71+
request, response
72+
)
73+
else:
74+
response = await sync_to_async(
75+
middleware.process_template_response
76+
)(request, response)
77+
# Defer running of process_response until after the template
78+
# has been rendered:
79+
if hasattr(middleware, "process_response"):
80+
81+
async def callback(response):
82+
return await middleware.process_response(request, response)
83+
84+
response.add_post_render_callback(async_to_sync(callback))
85+
else:
86+
if hasattr(middleware, "process_response"):
87+
return await middleware.process_response(request, response)
88+
return response
89+
90+
if iscoroutinefunction(view_func):
91+
92+
async def _view_wrapper(request, *args, **kwargs):
93+
result = await _pre_process_request(request, *args, **kwargs)
94+
if result is not None:
95+
return result
96+
97+
try:
98+
response = await view_func(request, *args, **kwargs)
99+
except Exception as e:
100+
result = await _process_exception(request, e)
101+
if result is not None:
102+
return result
103+
104+
return await _post_process_request(request, response)
105+
106+
else:
107+
108+
def _view_wrapper(request, *args, **kwargs):
109+
result = async_to_sync(_pre_process_request)(
110+
request, *args, **kwargs
111+
)
112+
if result is not None:
113+
return result
114+
115+
try:
116+
response = view_func(request, *args, **kwargs)
117+
except Exception as e:
118+
result = async_to_sync(_process_exception)(request, e)
119+
if result is not None:
120+
return result
121+
122+
return async_to_sync(_post_process_request)(request, response)
123+
124+
return wraps(view_func)(_view_wrapper)
125+
126+
return _decorator
127+
128+
return _make_decorator

docs/middleware/base.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ with the following specification:
1212
```
1313
where `get_response` is an **async function**, sync functions are not supported and **will raise** an error.
1414

15+
**Note:** you can use middlewares drove from this base class with normal django middlewares, you can even write sync views
16+
`get_response` is usually provided by django, so you don't have to worry about it being async.
17+
1518
----------------------------
1619

1720
other methods are as follows:

docs/middleware/decorate_views.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
`django_async_extensions.utils.decorators.decorator_from_middleware` and
2+
`django_async_extensions.utils.decorators.decorator_from_middleware_with_args`
3+
are provided to decorate a view with an async middleware directly.
4+
5+
they work almost exactly like django's [decorator_from_middleware](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware)
6+
and [decorator_from_middleware_with_args](https://docs.djangoproject.com/en/5.1/ref/utils/#django.utils.decorators.decorator_from_middleware_with_args)
7+
but it expects an async middleware as described in [AsyncMiddlewareMixin](base.md)
8+
9+
**Important:** if you are using a middleware that inherits from [AsyncMiddlewareMixin](base.md) you can only decorate async views
10+
if you need to decorate a sync view change middleware's `__init__()` method to accept async `get_response` argument.
11+
12+
with an async view
13+
```python
14+
from django.http.response import HttpResponse
15+
16+
from django_async_extensions.middleware.base import AsyncMiddlewareMixin
17+
from django_async_extensions.utils.decorators import decorator_from_middleware
18+
19+
class MyAsyncMiddleware(AsyncMiddlewareMixin):
20+
async def process_request(self, request):
21+
return HttpResponse()
22+
23+
24+
deco = decorator_from_middleware(MyAsyncMiddleware)
25+
26+
27+
@deco
28+
async def my_view(request):
29+
return HttpResponse()
30+
```
31+
32+
33+
if you need to use a sync view design your middleware like this
34+
```python
35+
from django_async_extensions.middleware.base import AsyncMiddlewareMixin
36+
37+
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
38+
39+
40+
class MyMiddleware(AsyncMiddlewareMixin):
41+
sync_capable = True
42+
43+
def __init__(self, get_response):
44+
if get_response is None:
45+
raise ValueError("get_response must be provided.")
46+
self.get_response = get_response
47+
48+
self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction(
49+
getattr(self.get_response, "__call__", None)
50+
)
51+
if self.async_mode:
52+
# Mark the class as async-capable.
53+
markcoroutinefunction(self)
54+
55+
super().__init__()
56+
```

tests/test_async_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# named like this to not conflict with something from django :/
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from asgiref.sync import sync_to_async
2+
3+
import pytest
4+
5+
from django.http import HttpResponse
6+
from django.template import engines
7+
from django.template.response import TemplateResponse
8+
from django.test import RequestFactory
9+
10+
from django_async_extensions.middleware.base import AsyncMiddlewareMixin
11+
from django_async_extensions.utils.decorators import decorator_from_middleware
12+
13+
14+
class ProcessViewMiddleware(AsyncMiddlewareMixin):
15+
def __init__(self, get_response):
16+
self.get_response = get_response
17+
18+
async def process_view(self, request, view_func, view_args, view_kwargs):
19+
pass
20+
21+
22+
process_view_dec = decorator_from_middleware(ProcessViewMiddleware)
23+
24+
25+
@process_view_dec
26+
async def async_process_view(request):
27+
return HttpResponse()
28+
29+
30+
@process_view_dec
31+
def process_view(request):
32+
return HttpResponse()
33+
34+
35+
class ClassProcessView:
36+
def __call__(self, request):
37+
return HttpResponse()
38+
39+
40+
class_process_view = process_view_dec(ClassProcessView())
41+
42+
43+
class AsyncClassProcessView:
44+
async def __call__(self, request):
45+
return HttpResponse()
46+
47+
48+
async_class_process_view = process_view_dec(AsyncClassProcessView())
49+
50+
51+
class FullMiddleware(AsyncMiddlewareMixin):
52+
def __init__(self, get_response):
53+
self.get_response = get_response
54+
55+
async def process_request(self, request):
56+
request.process_request_reached = True
57+
58+
async def process_view(self, request, view_func, view_args, view_kwargs):
59+
request.process_view_reached = True
60+
61+
async def process_template_response(self, request, response):
62+
request.process_template_response_reached = True
63+
return response
64+
65+
async def process_response(self, request, response):
66+
# This should never receive unrendered content.
67+
request.process_response_content = response.content
68+
request.process_response_reached = True
69+
return response
70+
71+
72+
full_dec = decorator_from_middleware(FullMiddleware)
73+
74+
75+
class TestDecoratorFromMiddleware:
76+
"""
77+
Tests for view decorators created using
78+
``django.utils.decorators.decorator_from_middleware``.
79+
"""
80+
81+
rf = RequestFactory()
82+
83+
def test_process_view_middleware(self):
84+
"""
85+
Test a middleware that implements process_view.
86+
"""
87+
process_view(self.rf.get("/"))
88+
89+
async def test_process_view_middleware_async(self, async_rf):
90+
await async_process_view(async_rf.get("/"))
91+
92+
async def test_sync_process_view_raises_in_async_context(self):
93+
msg = (
94+
"You cannot use AsyncToSync in the same thread as an async event loop"
95+
" - just await the async function directly."
96+
)
97+
with pytest.raises(RuntimeError, match=msg):
98+
process_view(self.rf.get("/"))
99+
100+
def test_callable_process_view_middleware(self):
101+
"""
102+
Test a middleware that implements process_view, operating on a callable class.
103+
"""
104+
class_process_view(self.rf.get("/"))
105+
106+
async def test_callable_process_view_middleware_async(self, async_rf):
107+
await async_process_view(async_rf.get("/"))
108+
109+
def test_full_dec_normal(self):
110+
"""
111+
All methods of middleware are called for normal HttpResponses
112+
"""
113+
114+
@full_dec
115+
def normal_view(request):
116+
template = engines["django"].from_string("Hello world")
117+
return HttpResponse(template.render())
118+
119+
request = self.rf.get("/")
120+
normal_view(request)
121+
assert getattr(request, "process_request_reached", False)
122+
assert getattr(request, "process_view_reached", False)
123+
# process_template_response must not be called for HttpResponse
124+
assert getattr(request, "process_template_response_reached", False) is False
125+
assert getattr(request, "process_response_reached", False)
126+
127+
async def test_full_dec_normal_async(self, async_rf):
128+
"""
129+
All methods of middleware are called for normal HttpResponses
130+
"""
131+
132+
@full_dec
133+
async def normal_view(request):
134+
template = engines["django"].from_string("Hello world")
135+
return HttpResponse(template.render())
136+
137+
request = async_rf.get("/")
138+
await normal_view(request)
139+
assert getattr(request, "process_request_reached", False)
140+
assert getattr(request, "process_view_reached", False)
141+
# process_template_response must not be called for HttpResponse
142+
assert getattr(request, "process_template_response_reached", False) is False
143+
assert getattr(request, "process_response_reached", False)
144+
145+
def test_full_dec_templateresponse(self):
146+
"""
147+
All methods of middleware are called for TemplateResponses in
148+
the right sequence.
149+
"""
150+
151+
@full_dec
152+
def template_response_view(request):
153+
template = engines["django"].from_string("Hello world")
154+
return TemplateResponse(request, template)
155+
156+
request = self.rf.get("/")
157+
response = template_response_view(request)
158+
assert getattr(request, "process_request_reached", False)
159+
assert getattr(request, "process_view_reached", False)
160+
assert getattr(request, "process_template_response_reached", False)
161+
# response must not be rendered yet.
162+
assert response._is_rendered is False
163+
# process_response must not be called until after response is rendered,
164+
# otherwise some decorators like csrf_protect and gzip_page will not
165+
# work correctly. See #16004
166+
assert getattr(request, "process_response_reached", False) is False
167+
response.render()
168+
assert getattr(request, "process_response_reached", False)
169+
# process_response saw the rendered content
170+
assert request.process_response_content == b"Hello world"
171+
172+
async def test_full_dec_templateresponse_async(self, async_rf):
173+
"""
174+
All methods of middleware are called for TemplateResponses in
175+
the right sequence.
176+
"""
177+
178+
@full_dec
179+
async def template_response_view(request):
180+
template = engines["django"].from_string("Hello world")
181+
return TemplateResponse(request, template)
182+
183+
request = async_rf.get("/")
184+
response = await template_response_view(request)
185+
assert getattr(request, "process_request_reached", False)
186+
assert getattr(request, "process_view_reached", False)
187+
assert getattr(request, "process_template_response_reached", False)
188+
# response must not be rendered yet.
189+
assert response._is_rendered is False
190+
# process_response must not be called until after response is rendered,
191+
# otherwise some decorators like csrf_protect and gzip_page will not
192+
# work correctly. See #16004
193+
assert getattr(request, "process_response_reached", False) is False
194+
await sync_to_async(response.render)()
195+
assert getattr(request, "process_response_reached", False)
196+
# process_response saw the rendered content
197+
assert request.process_response_content == b"Hello world"

0 commit comments

Comments
 (0)