-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.py
175 lines (143 loc) · 5.6 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# coding: utf-8
import asyncio
import socket
import ssl
import os
from safe_block import Block, DecryptError
from xybase import StreamBase
from aisle import SyncLogger
from config_parse import PyxyConfig
class Server(StreamBase):
    """Server side of the proxy.

    Listens for TLS connections, reads an encrypted negotiation block that
    names the real target, opens a connection to that target and hands both
    stream pairs to ``exchange_stream``.

    NOTE(review): ``self.key``, ``self.logger``, ``self.total_conn_count``,
    ``self.try_close`` and ``self.exchange_stream`` are not defined in this
    class; they are presumably provided by ``StreamBase`` - confirm there.
    """

    def __init__(self, config: PyxyConfig, name: str = None):
        """Build the server from *config*.

        Args:
            config: Parsed configuration. ``config.general["key"]`` is passed
                to ``StreamBase.__init__``; ``config.server`` supplies the
                listener settings and the certificate/key file paths.
            name: Optional suffix used to derive a child logger.
        """
        self.key_string = config.general["key"]
        super().__init__(self.key_string)
        self.config = config.server
        # Name the logger after the current process id.
        self.logger.name = str(os.getpid())
        if name:
            self.logger = self.logger.get_child(suffix=name)
        # Build the TLS context used for incoming connections.
        # You can load your own cert and key files here.
        self.safe_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        self.safe_context.load_cert_chain(
            certfile=self.config["crt_file"], keyfile=self.config["key_file"]
        )

    async def start(self):
        """Start listening and serve until cancelled.

        Takes no arguments; the listener is configured from ``self.config``:

        * ``ipv4_address`` - address to listen on
        * ``port`` - port to listen on
        * ``backlog`` - listen-queue length passed to ``asyncio.start_server``
        """
        self.logger: SyncLogger = self.logger.get_child(
            f"{self.config['ipv4_address']}:{self.config['port']}"
        )
        server = await asyncio.start_server(
            self.handler,
            self.config["ipv4_address"],
            self.config["port"],
            # limit=4096,  # buffer size of the created streams
            ssl=self.safe_context,
            backlog=self.config["backlog"],
            reuse_port=True,
        )
        self.logger.warning(
            f"Server starting at {self.config['ipv4_address']}:{self.config['port']}"
        )
        async with server:
            await server.serve_forever()

    @StreamBase.handlerDeco
    async def handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Handle one client connection; failures are logged, not propagated.

        Steps:
            1. Receive the negotiation block (target ip / domain / port).
            2. Resolve the domain, if one was given, to an IPv4 address.
            3. Connect to the target and always send the client a reply block
               with the local bind address/port (``""`` / ``0`` on failure).
            4. Hand both stream pairs to ``exchange_stream``.

        Both writers are closed via ``try_close`` on every exit path.

        Args:
            reader: Stream reader of the client connection.
            writer: Stream writer of the client connection.
        """
        request_id = self.total_conn_count
        logger = self.logger.get_child(f"{request_id}")
        # Stays None until the upstream connection exists, so the cleanup
        # below knows whether there is anything to close.
        true_writer = None
        try:
            # 1. Negotiation
            try:
                true_ip, true_domain, true_port = await self.__exchange_block(
                    reader, writer
                )
                logger.info(f"Get request > {true_ip}|{true_domain}:{true_port}")
            except Exception as err:
                logger.warning(f"Protocol fail > {type(err)} {err}")
                return
            # 2. Normalise the target address
            if (not true_ip) and (not true_domain):
                logger.error("NO IP OR DOMAIN")
                return
            if true_domain:
                # gethostbyname blocks; run it in the default executor so the
                # event loop keeps serving other connections meanwhile.
                loop = asyncio.get_running_loop()
                true_ip = await loop.run_in_executor(
                    None, socket.gethostbyname, true_domain
                )
            logger.info(f"Start true connect > {true_ip}|{true_domain}:{true_port}")
            # 3. Try to open the real connection
            bind_address, bind_port = "", 0
            try:
                true_reader, true_writer = await asyncio.open_connection(
                    true_ip, true_port
                )
                # IPv6 sockets report a 4-tuple; only host and port are needed.
                bind_address, bind_port = true_writer.get_extra_info("sockname")[:2]
            except Exception as error:
                logger.warning(f"Unexpected error > {type(error)}:{error}")
                raise
            finally:
                # The client gets a reply whether or not the connect succeeded.
                await self.__exchange_block(
                    reader,
                    writer,
                    {"bind_address": bind_address, "bind_port": bind_port},
                )
            # 4. Start forwarding
            await self.exchange_stream(reader, writer, true_reader, true_writer)
        # Error handling for everything after step 1
        except socket.gaierror as error:
            logger.error(f"DNS failure > {error}")
        except ConnectionResetError as error:
            logger.warning(f"Connection Reset > {error}")
            return
        except ConnectionRefusedError as error:
            logger.warning(f"Connection Refused > {error}")
        except TimeoutError as error:
            logger.warning(f"Connection timeout > {error}")
        except OSError as error:
            logger.warning(f"System fail connection > {error}")
        except Exception as error:
            logger.error(f"Unknown error > {type(error)} {error}")
        finally:
            closers = []
            if true_writer is not None:
                closers.append(self.try_close(true_writer))
            closers.append(self.try_close(writer))
            await asyncio.gather(*closers)
        # Wrap-up
        logger.debug("Request Handle End")

    async def __exchange_block(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        payload: dict = None,
    ):
        """Send or receive one negotiation block on the client connection.

        Args:
            reader: Stream reader of the client connection.
            writer: Stream writer of the client connection.
            payload: When given, it is wrapped in a ``Block`` built with
                ``self.key`` and written to ``writer``. When omitted, a block
                is read from ``reader`` instead.

        Returns:
            ``None`` after sending; ``(ip, domain, port)`` taken from the
            received block's payload after receiving.

        Raises:
            DecryptError: Raised by name in the original code, presumably from
                ``Block.from_bytes`` when decryption fails - confirm there.
            KeyError: The payload lacks ``"ip"``, ``"domain"`` or ``"port"``.
        """
        if payload is not None:
            # Send
            request = Block(self.key, payload)
            writer.write(request.block_bytes)
            await writer.drain()
            return None
        # Receive
        response = await reader.read(4096)  # TODO: add a timeout
        block = Block.from_bytes(self.key, response)
        true_ip = block.payload["ip"]
        true_domain = block.payload["domain"]
        true_port = block.payload["port"]
        return true_ip, true_domain, true_port
if __name__ == "__main__":
    # Script entry point: start a single listener on the address and port
    # from the server configuration and run it until the process is stopped.
    config = PyxyConfig()
    server_ipv4 = Server(config)
    # asyncio.run creates the event loop and closes it on exit.
    asyncio.run(server_ipv4.start())