1
+ from __future__ import annotations
2
+
1
3
import asyncio
2
4
import json
3
5
import time
4
6
import uuid
5
7
from typing import Any , Callable , Dict , MutableMapping , Optional , cast
6
8
7
- import aioredis
8
9
import pytest
9
10
from aiohttp import web
10
11
from aiohttp .test_utils import TestClient
11
12
from pytest_mock import MockFixture
13
+ from redis import asyncio as aioredis
12
14
13
15
from aiohttp_session import Handler , Session , get_session , session_middleware
14
16
from aiohttp_session .redis_storage import RedisStorage
18
20
19
21
def create_app (
20
22
handler : Handler ,
21
- redis : aioredis .Redis ,
23
+ redis : aioredis .Redis [ bytes ] ,
22
24
max_age : Optional [int ] = None ,
23
25
key_factory : Callable [[], str ] = lambda : uuid .uuid4 ().hex ,
24
26
) -> web .Application :
@@ -31,7 +33,7 @@ def create_app(
31
33
32
34
33
35
async def make_cookie (
34
- client : TestClient , redis : aioredis .Redis , data : Dict [Any , Any ]
36
+ client : TestClient , redis : aioredis .Redis [ bytes ] , data : Dict [Any , Any ]
35
37
) -> None :
36
38
session_data = {"session" : data , "created" : int (time .time ())}
37
39
value = json .dumps (session_data )
@@ -40,23 +42,21 @@ async def make_cookie(
40
42
client .session .cookie_jar .update_cookies ({"AIOHTTP_SESSION" : key })
41
43
42
44
43
- async def make_cookie_with_bad_value (client : TestClient , redis : aioredis .Redis ) -> None :
45
+ async def make_cookie_with_bad_value (client : TestClient , redis : aioredis .Redis [ bytes ] ) -> None :
44
46
key = uuid .uuid4 ().hex
45
47
await redis .set ("AIOHTTP_SESSION_" + key , "" )
46
48
client .session .cookie_jar .update_cookies ({"AIOHTTP_SESSION" : key })
47
49
48
50
49
- async def load_cookie (client : TestClient , redis : aioredis .Redis ) -> Any :
51
+ async def load_cookie (client : TestClient , redis : aioredis .Redis [ bytes ] ) -> Any :
50
52
cookies = client .session .cookie_jar .filter_cookies (client .make_url ("/" ))
51
53
key = cookies ["AIOHTTP_SESSION" ]
52
- encoded = await redis .get ("AIOHTTP_SESSION_" + key .value )
53
- s = encoded .decode ("utf-8" )
54
- value = json .loads (s )
55
- return value
54
+ value_bytes = await redis .get ("AIOHTTP_SESSION_" + key .value )
55
+ return None if value_bytes is None else json .loads (value_bytes )
56
56
57
57
58
58
async def test_create_new_session (
59
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
59
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
60
60
) -> None :
61
61
async def handler (request : web .Request ) -> web .StreamResponse :
62
62
session = await get_session (request )
@@ -72,7 +72,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
72
72
73
73
74
74
async def test_load_existing_session (
75
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
75
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
76
76
) -> None :
77
77
async def handler (request : web .Request ) -> web .StreamResponse :
78
78
session = await get_session (request )
@@ -89,7 +89,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
89
89
90
90
91
91
async def test_load_bad_session (
92
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
92
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
93
93
) -> None :
94
94
async def handler (request : web .Request ) -> web .StreamResponse :
95
95
session = await get_session (request )
@@ -106,7 +106,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
106
106
107
107
108
108
async def test_change_session (
109
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
109
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
110
110
) -> None :
111
111
async def handler (request : web .Request ) -> web .StreamResponse :
112
112
session = await get_session (request )
@@ -133,7 +133,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
133
133
134
134
135
135
async def test_clear_cookie_on_session_invalidation (
136
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
136
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
137
137
) -> None :
138
138
async def handler (request : web .Request ) -> web .StreamResponse :
139
139
session = await get_session (request )
@@ -154,7 +154,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
154
154
155
155
156
156
async def test_create_cookie_in_handler (
157
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
157
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
158
158
) -> None :
159
159
async def handler (request : web .Request ) -> web .StreamResponse :
160
160
session = await get_session (request )
@@ -181,7 +181,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
181
181
182
182
183
183
async def test_set_ttl_on_session_saving (
184
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
184
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
185
185
) -> None :
186
186
async def handler (request : web .Request ) -> web .StreamResponse :
187
187
session = await get_session (request )
@@ -201,7 +201,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
201
201
202
202
203
203
async def test_set_ttl_manually_set (
204
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
204
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
205
205
) -> None :
206
206
async def handler (request : web .Request ) -> web .StreamResponse :
207
207
session = await get_session (request )
@@ -222,7 +222,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
222
222
223
223
224
224
async def test_create_new_session_if_key_doesnt_exists_in_redis (
225
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
225
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
226
226
) -> None :
227
227
async def handler (request : web .Request ) -> web .StreamResponse :
228
228
session = await get_session (request )
@@ -236,7 +236,7 @@ async def handler(request: web.Request) -> web.StreamResponse:
236
236
237
237
238
238
async def test_create_storage_with_custom_key_factory (
239
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
239
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
240
240
) -> None :
241
241
async def handler (request : web .Request ) -> web .StreamResponse :
242
242
session = await get_session (request )
@@ -259,7 +259,7 @@ def key_factory() -> str:
259
259
260
260
261
261
async def test_redis_session_fixation (
262
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
262
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
263
263
) -> None :
264
264
async def login (request : web .Request ) -> web .StreamResponse :
265
265
session = await get_session (request )
@@ -288,7 +288,7 @@ async def test_redis_from_create_pool(redis_url: str) -> None:
288
288
async def handler (request : web .Request ) -> web .StreamResponse :
289
289
pass
290
290
291
- redis = aioredis .from_url (redis_url ) # type: ignore[no-untyped-call]
291
+ redis = aioredis .from_url (redis_url )
292
292
create_app (handler = handler , redis = redis )
293
293
await redis .close ()
294
294
@@ -314,17 +314,13 @@ async def test_old_aioredis_version(mocker: MockFixture) -> None:
314
314
async def handler (request : web .Request ) -> web .StreamResponse :
315
315
pass
316
316
317
- class Dummy :
318
- def __init__ (self , * args : object , ** kwargs : object ) -> None :
319
- self .version = (0 , 3 )
320
-
321
- mocker .patch ("aiohttp_session.redis_storage.StrictVersion" , Dummy )
317
+ mocker .patch ("aiohttp_session.redis_storage.REDIS_VERSION" , (0 , 3 , "dev0" ))
322
318
with pytest .raises (RuntimeError ):
323
319
create_app (handler = handler , redis = None ) # type: ignore[arg-type]
324
320
325
321
326
322
async def test_load_session_dont_load_expired_session (
327
- aiohttp_client : AiohttpClient , redis : aioredis .Redis
323
+ aiohttp_client : AiohttpClient , redis : aioredis .Redis [ bytes ]
328
324
) -> None :
329
325
async def handler (request : web .Request ) -> web .StreamResponse :
330
326
session = await get_session (request )
0 commit comments