forked from Kozea/Radicale
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
380 lines (334 loc) · 16.1 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
# This file is part of Radicale - CalDAV and CardDAV server
# Copyright © 2008 Nicolas Kandel
# Copyright © 2008 Pascal Halter
# Copyright © 2008-2017 Guillaume Ayoub
# Copyright © 2017-2023 Unrud <[email protected]>
# Copyright © 2024-2024 Peter Bieringer <[email protected]>
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Radicale. If not, see <http://www.gnu.org/licenses/>.
"""
Built-in WSGI server.
"""
import http
import select
import socket
import socketserver
import ssl
import sys
import wsgiref.simple_server
from typing import (Any, Callable, Dict, List, MutableMapping, Optional, Set,
Tuple, Union)
from urllib.parse import unquote
from radicale import Application, config, utils
from radicale.log import logger
# Portable stand-ins for socket constants that not every platform defines.
# Each one falls back to the closest value available on the current platform.

# getaddrinfo error: host has no address in the requested family
COMPAT_EAI_ADDRFAMILY: int
if hasattr(socket, "EAI_ADDRFAMILY"):
    COMPAT_EAI_ADDRFAMILY = socket.EAI_ADDRFAMILY  # type:ignore[attr-defined]
elif hasattr(socket, "EAI_NONAME"):
    # Windows and BSD don't have a special error code for this
    COMPAT_EAI_ADDRFAMILY = socket.EAI_NONAME

# getaddrinfo error: host exists but has no address data
COMPAT_EAI_NODATA: int
if hasattr(socket, "EAI_NODATA"):
    COMPAT_EAI_NODATA = socket.EAI_NODATA
elif hasattr(socket, "EAI_NONAME"):
    # Windows and BSD don't have a special error code for this
    COMPAT_EAI_NODATA = socket.EAI_NONAME

# Protocol level used with setsockopt(..., IPV6_V6ONLY, ...)
COMPAT_IPPROTO_IPV6: int
if hasattr(socket, "IPPROTO_IPV6"):
    COMPAT_IPPROTO_IPV6 = socket.IPPROTO_IPV6
elif sys.platform == "win32":
    # HACK: https://bugs.python.org/issue29515
    COMPAT_IPPROTO_IPV6 = 41

# Socket address as returned by accept()/getaddrinfo()
ADDRESS_TYPE = Union[
    # IPv4: (host, port)
    Tuple[Union[str, bytes, bytearray], int],
    # IPv6: (host, port, flowinfo, scopeid)
    Tuple[str, int, int, int],
]
def format_address(address: ADDRESS_TYPE) -> str:
    """Render a socket address as ``host:port`` for log messages.

    Hosts containing a colon (IPv6 literals) are wrapped in square
    brackets, e.g. ``[::1]:5232``. Trailing IPv6 tuple members (flowinfo,
    scope id) are ignored.

    Raises NotImplementedError if the host part is not a ``str``.
    """
    host, port, *_extra = address
    if not isinstance(host, str):
        raise NotImplementedError("Unsupported address format: %r" %
                                  (address,))
    template = "[%s]:%d" if ":" in host else "%s:%d"
    return template % (host, port)
class ParallelHTTPServer(socketserver.ThreadingMixIn,
                         wsgiref.simple_server.WSGIServer):
    """Threaded WSGI server that tracks running requests via socket pairs.

    For every accepted connection :meth:`get_request` creates a
    ``socket.socketpair()``. One end travels to the worker thread (smuggled
    through the ``client_address`` argument) and is closed when the request
    is finished. The other end is kept in :attr:`worker_sockets` so that a
    main loop can ``select`` on it to detect finished workers.
    """

    configuration: config.Configuration
    # Main-loop ends of the socket pairs of requests still being handled
    worker_sockets: Set[socket.socket]
    # Client socket timeout read from ``[server] timeout``; applied only if > 0
    _timeout: float

    # We wait for child threads ourself (ThreadingMixIn)
    block_on_close: bool = False
    daemon_threads: bool = True

    def __init__(self, configuration: config.Configuration, family: int,
                 address: Tuple[str, int], RequestHandlerClass:
                 Callable[..., http.server.BaseHTTPRequestHandler]) -> None:
        """Create, bind and activate a server listening on `address`.

        `family` (e.g. ``socket.AF_INET`` / ``socket.AF_INET6``) selects the
        address family of the listening socket. It is stored before calling
        the base constructor, which is what creates the socket.
        """
        self.configuration = configuration
        self.address_family = family
        super().__init__(address, RequestHandlerClass)
        self.worker_sockets = set()
        self._timeout = configuration.get("server", "timeout")

    def server_bind(self) -> None:
        if self.address_family == socket.AF_INET6:
            # Only allow IPv6 connections to the IPv6 socket
            self.socket.setsockopt(COMPAT_IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        super().server_bind()

    def get_request(  # type:ignore[override]
            self) -> Tuple[socket.socket, Tuple[ADDRESS_TYPE, socket.socket]]:
        """Accept a client and attach a fresh worker socket to it.

        Returns ``(request, (client_address, worker_socket))``.
        """
        # Set timeout for client
        request: socket.socket
        client_address: ADDRESS_TYPE
        request, client_address = super().get_request()  # type:ignore[misc]
        if self._timeout > 0:
            request.settimeout(self._timeout)
        try:
            worker_socket, worker_socket_out = socket.socketpair()
        except OSError:
            # e.g. out of file descriptors. socketserver silently ignores an
            # OSError raised here, so close the accepted client explicitly
            # instead of leaking it.
            request.close()
            raise
        self.worker_sockets.add(worker_socket_out)
        # HACK: Forward `worker_socket` via `client_address` return value
        # to worker thread.
        # The super class calls `verify_request`, `process_request` and
        # `handle_error` with modified `client_address` value.
        return request, (client_address, worker_socket)

    def verify_request(  # type:ignore[override]
            self, request: socket.socket, client_address_and_socket:
            Tuple[ADDRESS_TYPE, socket.socket]) -> bool:
        """Accept every connection."""
        return True

    def process_request(  # type:ignore[override]
            self, request: socket.socket, client_address_and_socket:
            Tuple[ADDRESS_TYPE, socket.socket]) -> None:
        # HACK: Super class calls `finish_request` in new thread with
        # `client_address_and_socket`
        try:
            return super().process_request(
                request, client_address_and_socket)  # type:ignore[arg-type]
        except Exception:
            # No worker thread took ownership of `worker_socket` (e.g. the
            # thread could not be started). Close it here so its pair in
            # `worker_sockets` reports the request as finished instead of
            # holding a connection slot (and blocking shutdown) forever.
            _, worker_socket = client_address_and_socket
            worker_socket.close()
            raise

    def finish_request(  # type:ignore[override]
            self, request: socket.socket, client_address_and_socket:
            Tuple[ADDRESS_TYPE, socket.socket]) -> None:
        # HACK: Unpack `client_address_and_socket` and call super class
        # `finish_request` with original `client_address`
        client_address, worker_socket = client_address_and_socket
        try:
            return self.finish_request_locked(request, client_address)
        finally:
            # Closing our end wakes up the main loop, which is waiting on
            # the paired socket in `worker_sockets`.
            worker_socket.close()

    def finish_request_locked(self, request: socket.socket,
                              client_address: ADDRESS_TYPE) -> None:
        """Handle the request with the original client address.

        Overridden by subclasses (e.g. the HTTPS server performs the TLS
        handshake here first).
        """
        return super().finish_request(
            request, client_address)  # type:ignore[arg-type]

    def handle_error(  # type:ignore[override]
            self, request: socket.socket,
            client_address_or_client_address_and_socket:
            Union[ADDRESS_TYPE, Tuple[ADDRESS_TYPE, socket.socket]]) -> None:
        """Log the exception that is currently being handled.

        Client timeouts are logged at INFO level, everything else at ERROR.
        """
        # HACK: This method can be called with the modified
        # `client_address_and_socket` or the original `client_address` value
        e = sys.exc_info()[1]
        assert e is not None
        if isinstance(e, socket.timeout):
            logger.info("Client timed out", exc_info=True)
        else:
            logger.error("An exception occurred during request: %s",
                         e, exc_info=True)
class ParallelHTTPSServer(ParallelHTTPServer):
    """TLS-enabled variant of :class:`ParallelHTTPServer`.

    The listening socket is wrapped with an SSL context built from the
    ``[server]`` configuration. The TLS handshake is deferred to the worker
    thread (``do_handshake_on_connect=False``) and performed in
    :meth:`finish_request_locked`.
    """

    def server_bind(self) -> None:
        super().server_bind()
        # Wrap the TCP socket in an SSL socket
        get_option = self.configuration.get
        certfile: str = get_option("server", "certificate")
        keyfile: str = get_option("server", "key")
        cafile: str = get_option("server", "certificate_authority")
        protocol: str = get_option("server", "protocol")
        ciphersuite: str = get_option("server", "ciphersuite")

        # Test if the files can be read, to fail with a readable message
        files_to_check = (("certificate", certfile),
                          ("key", keyfile),
                          ("certificate_authority", cafile))
        for option_name, path in files_to_check:
            type_name = config.DEFAULT_CONFIG_SCHEMA["server"][option_name][
                "type"].__name__
            source = self.configuration.get_source("server", option_name)
            if option_name == "certificate_authority" and not path:
                # The CA file is optional
                continue
            try:
                with open(path):
                    pass
            except OSError as e:
                raise RuntimeError(
                    "Invalid %s value for option %r in section %r in %s: %r "
                    "(%s)" % (type_name, option_name, "server", source, path,
                              e)) from e

        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        logger.info("SSL load files certificate='%s' key='%s'",
                    certfile, keyfile)
        context.load_cert_chain(certfile=certfile, keyfile=keyfile)

        # Protocol versions
        if not protocol:
            logger.info("SSL active protocols: (system-default)")
        else:
            logger.info("SSL set explicit protocols (maybe not all supported by "
                        "underlying OpenSSL): '%s'", protocol)
            context.options = utils.ssl_context_options_by_protocol(
                protocol, context.options)
            context.minimum_version = (
                utils.ssl_context_minimum_version_by_options(context.options))
            if context.minimum_version == 0:
                raise RuntimeError("No SSL minimum protocol active")
            context.maximum_version = (
                utils.ssl_context_maximum_version_by_options(context.options))
            if context.maximum_version == 0:
                raise RuntimeError("No SSL maximum protocol active")
        logger.debug("SSL minimum acceptable protocol: %s",
                     context.minimum_version)
        logger.debug("SSL maximum acceptable protocol: %s",
                     context.maximum_version)
        logger.info("SSL accepted protocols: %s",
                    " ".join(utils.ssl_get_protocols(context)))

        # Cipher suites
        if not ciphersuite:
            logger.info("SSL active ciphersuite: (system-default)")
        else:
            logger.info("SSL set explicit ciphersuite (maybe not all supported by "
                        "underlying OpenSSL): '%s'", ciphersuite)
            context.set_ciphers(ciphersuite)
        cipher_names = [cipher["name"] for cipher in context.get_ciphers()]
        logger.info("SSL accepted ciphers: %s", " ".join(cipher_names))

        # Optional mutual TLS
        if cafile:
            logger.info("SSL enable mandatory client certificate verification "
                        "using CA file='%s'", cafile)
            context.load_verify_locations(cafile=cafile)
            context.verify_mode = ssl.CERT_REQUIRED

        self.socket = context.wrap_socket(
            self.socket, server_side=True, do_handshake_on_connect=False)

    def finish_request_locked(  # type:ignore[override]
            self, request: ssl.SSLSocket, client_address: ADDRESS_TYPE
    ) -> None:
        """Complete the TLS handshake, then handle the request.

        A failed handshake is reported through ``handle_error`` and the
        connection is shut down without processing a request. Timeouts are
        propagated unchanged; other handshake errors are wrapped in a
        RuntimeError.
        """
        try:
            try:
                request.do_handshake()
            except Exception as e:
                if isinstance(e, socket.timeout):
                    raise
                raise RuntimeError("SSL handshake failed: %s" % e) from e
        except Exception:
            try:
                self.handle_error(request, client_address)
            finally:
                self.shutdown_request(request)  # type:ignore[attr-defined]
        else:
            return super().finish_request_locked(request, client_address)
class ServerHandler(wsgiref.simple_server.ServerHandler):
    """WSGI handler that reports exceptions through the Radicale logger."""

    # Don't pollute WSGI environ with OS environment: the wsgiref handler
    # starts every request's environ from a copy of this mapping.
    os_environ: MutableMapping[str, str] = {}

    def log_exception(self, exc_info) -> None:
        exc_value = exc_info[1]
        logger.error("An exception occurred during request: %s",
                     exc_value, exc_info=exc_info)  # type:ignore[arg-type]
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
    """HTTP requests handler."""

    # HACK: Assigned in `socketserver.StreamRequestHandler`
    connection: socket.socket

    def log_request(self, code: Union[int, str] = "-",
                    size: Union[int, str] = "-") -> None:
        # Per-request access logging is intentionally disabled.
        return None

    def log_error(self, format_: str, *args: Any) -> None:
        message = format_ % args
        logger.error("An error occurred during request: %s", message)

    def get_environ(self) -> Dict[str, Any]:
        environ = super().get_environ()
        connection = self.connection
        if isinstance(connection, ssl.SSLSocket):
            # The certificate can be evaluated by the auth module
            environ["REMOTE_CERTIFICATE"] = connection.getpeercert()
        # Parent class only tries latin1 encoding; unquote() decodes the
        # percent-escapes of the path (query string removed) as UTF-8.
        path_without_query = self.path.partition("?")[0]
        environ["PATH_INFO"] = unquote(path_without_query)
        return environ

    def handle(self) -> None:
        """Copy of WSGIRequestHandler.handle with different ServerHandler"""
        self.raw_requestline = self.rfile.readline(65537)
        if len(self.raw_requestline) > 65536:
            # Request line too long: answer 414
            self.requestline = ""
            self.request_version = ""
            self.command = ""
            self.send_error(414)
            return
        if not self.parse_request():
            # parse_request() has already sent the error response
            return
        stderr = self.get_stderr()
        environ = self.get_environ()
        handler = ServerHandler(self.rfile, self.wfile, stderr, environ)
        handler.request_handler = self  # type:ignore[attr-defined]
        handler.run(self.server.get_app())  # type:ignore[attr-defined]
def serve(configuration: config.Configuration,
          shutdown_socket: Optional[socket.socket] = None) -> None:
    """Serve radicale from configuration.

    `shutdown_socket` can be used to gracefully shutdown the server.
    The socket can be created with `socket.socketpair()`, when the other socket
    gets closed the server stops accepting new requests by clients and the
    function returns after all active requests are finished.

    Every entry of the ``[server] hosts`` option is resolved with
    ``socket.getaddrinfo`` and one listening server is created per resolved
    address. Addresses that cannot be resolved or bound are skipped with a
    warning.

    Raises RuntimeError if no listening server could be started.
    """
    logger.info("Starting Radicale (%s)", utils.packages_version())
    # Copy configuration before modifying
    configuration = configuration.copy()
    configuration.update({"server": {"_internal_server": "True"}}, "server",
                         privileged=True)

    use_ssl: bool = configuration.get("server", "ssl")
    server_class = ParallelHTTPSServer if use_ssl else ParallelHTTPServer
    application = Application(configuration)
    # Listening socket -> server that owns it
    servers = {}
    try:
        hosts: List[Tuple[str, int]] = configuration.get("server", "hosts")
        for address_port in hosts:
            # retrieve IPv4/IPv6 address of address
            try:
                addrinfos = socket.getaddrinfo(
                    address_port[0], address_port[1], 0, socket.SOCK_STREAM,
                    socket.IPPROTO_TCP)
            except OSError as e:
                logger.warning(
                    "cannot retrieve IPv4 or IPv6 address of '%s': %s",
                    format_address(address_port), e)
                continue
            # Deferred formatting: the list repr is only built if DEBUG is on
            logger.debug("getaddrinfo of '%s': %s",
                         format_address(address_port), addrinfos)
            for address_family, _, _, _, socket_address in addrinfos:
                logger.debug("try to create server socket on '%s'",
                             format_address(socket_address))
                try:
                    server = server_class(
                        configuration, address_family,
                        (socket_address[0], socket_address[1]),
                        RequestHandler)
                except OSError as e:
                    logger.warning("cannot create server socket on '%s': %s",
                                   format_address(socket_address), e)
                    continue
                servers[server.socket] = server
                server.set_app(application)
                logger.info("Listening on %r%s",
                            format_address(server.server_address),
                            " with SSL" if use_ssl else "")
        if not servers:
            raise RuntimeError("No servers started")

        # Mainloop
        select_timeout = None
        if sys.platform == "win32":
            # Fallback to busy waiting. (select(...) blocks SIGINT on Windows.)
            select_timeout = 1.0
        max_connections: int = configuration.get("server", "max_connections")
        logger.info("Radicale server ready")
        while True:
            rlist: List[socket.socket] = []
            # Wait for finished clients
            for server in servers.values():
                rlist.extend(server.worker_sockets)
            # Accept new connections if max_connections is not reached
            if max_connections <= 0 or len(rlist) < max_connections:
                rlist.extend(servers)
            # Use socket to get notified of program shutdown
            if shutdown_socket is not None:
                rlist.append(shutdown_socket)
            rlist, _, _ = select.select(rlist, [], [], select_timeout)
            rset = set(rlist)
            if shutdown_socket in rset:
                logger.info("Stopping Radicale")
                break
            # A readable worker socket means its request has finished
            for server in servers.values():
                finished_sockets = server.worker_sockets.intersection(rset)
                for s in finished_sockets:
                    s.close()
                    server.worker_sockets.remove(s)
                    rset.remove(s)
                if finished_sockets:
                    server.service_actions()
            # Any remaining readable socket is a listening socket
            if rset:
                active_server = servers.get(rset.pop())
                if active_server:
                    active_server.handle_request()
    finally:
        # Wait for clients to finish and close servers
        for server in servers.values():
            for s in server.worker_sockets:
                # Blocks until the worker closes its end of the pair
                s.recv(1)
                s.close()
            server.server_close()