Skip to content

Commit

Permalink
Refactor: Redis usage (#70)
Browse files Browse the repository at this point in the history
* refactor: using redis connection pool to replace create connection each times and reduce duplicate code

* impr: graceful shutdown

* style: lint code
  • Loading branch information
tobiichi3227 authored May 26, 2024
1 parent 6d5147d commit 9a8e4f0
Show file tree
Hide file tree
Showing 15 changed files with 180 additions and 140 deletions.
5 changes: 4 additions & 1 deletion src/handlers/acct.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ async def post(self):
self.error('Eacces')
return

err, _ = await UserService.inst.update_acct(self.acct.acct_id, self.acct.acct_type, self.acct.acct_class, name, photo, cover)
err, _ = await UserService.inst.update_acct(
self.acct.acct_id, self.acct.acct_type, self.acct.acct_class, name, photo, cover
)
if err:
self.error(err)
return
Expand All @@ -109,6 +111,7 @@ async def post(self):

self.error('Eunk')


class SignHandler(RequestHandler):
@reqenv
async def get(self):
Expand Down
28 changes: 24 additions & 4 deletions src/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import asyncio
import datetime
import json

import asyncpg
import tornado.gen
import tornado.template
import tornado.web
import tornado.websocket
from redis import asyncio as aioredis

from services.user import UserService


class RequestHandler(tornado.web.RequestHandler):
def __init__(self, *args, **kwargs):
self.db = kwargs.pop('db')
self.rs = kwargs.pop('rs')
self.db: asyncpg.Pool = kwargs.pop('db')
self.rs: aioredis.Redis = kwargs.pop('rs')
self.tpldr = tornado.template.Loader('static/templ')

super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -56,12 +59,29 @@ def default(self, obj):

class WebSocketHandler(tornado.websocket.WebSocketHandler):
def __init__(self, *args, **kwargs):
self.db = kwargs.pop('db')
self.rs = kwargs.pop('rs')
self.db: asyncpg.Pool = kwargs.pop('db')
self.rs: aioredis.Redis = kwargs.pop('rs')

super().__init__(*args, **kwargs)


class WebSocketSubHandler(tornado.websocket.WebSocketHandler):
def __init__(self, *args, **kwargs):
pool = kwargs.pop('pool')
self.rs: aioredis.Redis = aioredis.Redis(connection_pool=pool)
self.p = self.rs.pubsub()
self.task: asyncio.Task = None

super().__init__(*args, **kwargs)

def check_origin(self, origin: str) -> bool:
return True

def on_close(self) -> None:
self.task.cancel()
asyncio.create_task(self.rs.aclose())


def reqenv(func):
# @tornado.gen.coroutine
async def wrap(self, *args, **kwargs):
Expand Down
38 changes: 16 additions & 22 deletions src/handlers/bulletin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from redis import asyncio as aioredis

from handlers.base import RequestHandler, WebSocketHandler, reqenv
from handlers.base import RequestHandler, WebSocketSubHandler, reqenv
from services.bulletin import BulletinService
from services.judge import JudgeServerClusterService

Expand All @@ -23,31 +23,25 @@ async def get(self, bulletin_id=None):
await self.render('bulletin', bulletin=bulletin)


class BulletinSub(WebSocketHandler):
async def open(self):
self.ars = aioredis.Redis(host='localhost', port=6379, db=1)
await self.ars.incr('online_counter', 1)
await self.ars.sadd('online_counter_set', self.request.remote_ip)
self.p = self.ars.pubsub()
await self.p.subscribe('bulletinsub')
class BulletinSub(WebSocketSubHandler):
async def listen_newbulletin(self):
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

async def test():
async for msg in self.p.listen():
if msg['type'] != 'message':
continue
await self.on_message(str(int(msg['data'])))

await self.on_message(str(int(msg['data'])))
async def open(self):
await self.rs.incr('online_counter', 1)
await self.rs.sadd('online_counter_set', self.request.remote_ip)
await self.p.subscribe('bulletinsub')

self.task = asyncio.tasks.Task(test())
self.task = asyncio.tasks.Task(self.listen_newbulletin())

async def on_message(self, msg):
self.write_message(msg)
await self.write_message(msg)

def on_close(self) -> None:
asyncio.create_task(self.ars.decr('online_counter', 1))
asyncio.create_task(self.ars.srem('online_counter_set', self.request.remote_ip))
self.task.cancel()

def check_origin(self, origin):
# TODO: secure
return True
super().on_close()
asyncio.create_task(self.rs.decr('online_counter', 1))
asyncio.create_task(self.rs.srem('online_counter_set', self.request.remote_ip))
86 changes: 31 additions & 55 deletions src/handlers/chal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import tornado.web

from handlers.base import RequestHandler, WebSocketHandler, reqenv
from handlers.base import RequestHandler, WebSocketSubHandler, reqenv
from services.chal import ChalConst, ChalService
from services.pro import ProService
from services.user import UserService
Expand Down Expand Up @@ -103,52 +103,42 @@ async def get(self, chal_id):
from redis import asyncio as aioredis


class ChalListNewChalHandler(WebSocketHandler):
async def open(self):
self.ars = aioredis.Redis(host='localhost', port=6379, db=1)
self.p = self.ars.pubsub()
await self.p.subscribe('challist_sub')
class ChalListNewChalHandler(WebSocketSubHandler):
async def listen_challistnewchal(self):
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

async def test():
async for msg in self.p.listen():
if msg['type'] != 'message':
continue
await self.on_message(str(int(msg['data'])))

await self.on_message(str(int(msg['data'])))
async def open(self):
await self.p.subscribe('challist_sub')

self.task = asyncio.tasks.Task(test())
self.task = asyncio.tasks.Task(self.listen_challistnewchal())

async def on_message(self, msg):
self.write_message(msg)
await self.write_message(msg)

def on_close(self) -> None:
self.task.cancel()

def check_origin(self, _):
return True
class ChalListNewStateHandler(WebSocketSubHandler):
async def listen_challiststate(self):
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

chal_id = int(msg['data'])
if self.first_chal_id <= chal_id <= self.last_chal_id:
_, new_state = await ChalService.inst.get_single_chal_state_in_list(chal_id, self.acct)
await self.write_message(json.dumps(new_state))

class ChalListNewStateHandler(WebSocketHandler):
async def open(self):
self.first_chal_id = -1
self.last_chal_id = -1
self.acct = None

self.ars = aioredis.Redis(host='localhost', port=6379, db=1)
self.p = self.ars.pubsub()
await self.p.subscribe('challiststatesub')

async def listen_challiststate():
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

chal_id = int(msg['data'])
if self.first_chal_id <= chal_id <= self.last_chal_id:
_, new_state = await ChalService.inst.get_single_chal_state_in_list(chal_id, self.acct)
await self.write_message(json.dumps(new_state))

self.task = asyncio.tasks.Task(listen_challiststate())
self.task = asyncio.tasks.Task(self.listen_challiststate())

async def on_message(self, msg):
if self.acct is None:
Expand All @@ -162,37 +152,23 @@ async def on_message(self, msg):

self.acct = acct

def on_close(self) -> None:
self.task.cancel()

def check_origin(self, _):
return True
class ChalNewStateHandler(WebSocketSubHandler):

async def listen_chalstate(self):
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

if int(msg['data']) == self.chal_id:
_, chal_states = await ChalService.inst.get_chal_state(self.chal_id)
await self.write_message(json.dumps(chal_states))

class ChalNewStateHandler(WebSocketHandler):
async def open(self):
self.chal_id = -1
self.ars = aioredis.Redis(host='localhost', port=6379, db=1)
self.p = self.ars.pubsub()
await self.p.subscribe('chalstatesub')

async def listen_chalstate():
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

if int(msg['data']) == self.chal_id:
_, chal_states = await ChalService.inst.get_chal_state(self.chal_id)
await self.write_message(json.dumps(chal_states))

self.task = asyncio.tasks.Task(listen_chalstate())
self.task = asyncio.tasks.Task(self.listen_chalstate())

async def on_message(self, msg):
if self.chal_id == -1 and msg.isdigit():
self.chal_id = int(msg)

def on_close(self) -> None:
self.task.cancel()

def check_origin(self, _):
return True
37 changes: 16 additions & 21 deletions src/handlers/manage/judge.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
import base64

import config
from msgpack import packb, unpackb
from redis import asyncio as aioredis

from handlers.base import RequestHandler, WebSocketHandler, reqenv, require_permission
import config
from handlers.base import (
RequestHandler,
WebSocketSubHandler,
reqenv,
require_permission,
)
from services.judge import JudgeServerClusterService
from services.log import LogService
from services.user import UserConst
Expand Down Expand Up @@ -70,27 +74,18 @@ async def post(self):
self.finish('S')


class JudgeChalCntSub(WebSocketHandler):
async def open(self):
self.ars = aioredis.Redis(host='localhost', port=6379, db=1)
self.p = self.ars.pubsub()
await self.p.subscribe('judgechalcnt_sub')
class JudgeChalCntSub(WebSocketSubHandler):
async def listen_newchal(self):
async for msg in self.p.listen():
if msg['type'] != 'message':
continue

async def loop():
async for msg in self.p.listen():
if msg['type'] != 'message':
continue
await self.on_message(msg['data'].decode('utf-8'))

await self.on_message(msg['data'].decode('utf-8'))
async def open(self):
await self.p.subscribe('judgechalcnt_sub')

self.task = asyncio.tasks.Task(loop())
self.task = asyncio.tasks.Task(self.listen_newchal())

async def on_message(self, msg):
await self.write_message(msg)

def on_close(self) -> None:
self.task.cancel()

def check_origin(self, origin):
# TODO: secure
return True
6 changes: 4 additions & 2 deletions src/handlers/manage/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
from handlers.manage.question import ManageQuestionHandler


def get_manage_url(db, rs):
def get_manage_url(db, rs, pool):
args = {
'db': db,
'rs': rs,
}

sub_args = {'pool': pool}

return [
('/manage/dash', ManageDashHandler, args),
('/manage/acct', ManageAcctHandler, args),
Expand All @@ -34,6 +36,6 @@ def get_manage_url(db, rs):
('/manage/question/(.+)', ManageQuestionHandler, args),
('/manage/group', ManageGroupHandler, args),
('/manage/judge', ManageJudgeHandler, args),
('/manage/judgecntws', JudgeChalCntSub, args),
('/manage/judgecntws', JudgeChalCntSub, sub_args),
('/manage/pack', ManagePackHandler, args),
]
Loading

0 comments on commit 9a8e4f0

Please sign in to comment.