From ac1fe4a4e3d21df2c06198f9e89c36104529e1dc Mon Sep 17 00:00:00 2001 From: Trevor Berrange Sanchez Date: Fri, 8 Nov 2024 06:17:05 +0100 Subject: [PATCH] RPC server first pass --- src/rpc/lib.zig | 6 +- src/rpc/server.zig | 375 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 380 insertions(+), 1 deletion(-) create mode 100644 src/rpc/server.zig diff --git a/src/rpc/lib.zig b/src/rpc/lib.zig index 705e3f418..f02c138a2 100644 --- a/src/rpc/lib.zig +++ b/src/rpc/lib.zig @@ -1,9 +1,13 @@ pub const client = @import("client.zig"); +pub const server = @import("server.zig"); + pub const request = @import("request.zig"); pub const response = @import("response.zig"); pub const types = @import("types.zig"); -pub const ClusterType = types.ClusterType; pub const Client = client.Client; +pub const Server = server.Server; + +pub const ClusterType = types.ClusterType; pub const Request = request.Request; pub const Response = response.Response; diff --git a/src/rpc/server.zig b/src/rpc/server.zig new file mode 100644 index 000000000..d9627aee9 --- /dev/null +++ b/src/rpc/server.zig @@ -0,0 +1,375 @@ +const std = @import("std"); +const assert = std.debug.assert; +const sig = @import("../sig.zig"); + +const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; +const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; +const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; +const ThreadPool = sig.sync.ThreadPool; + +const logger_scope = "rpc.Server"; +const ScopedLogger = sig.trace.log.ScopedLogger(logger_scope); + +pub const Server = struct { + allocator: std.mem.Allocator, + logger: ScopedLogger, + + snapshot_dir: std.fs.Dir, + latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), + + /// Wait group for all currently running tasks, used to wait for + /// all of them to finish before deinitializing. + wait_group: std.Thread.WaitGroup, + exit_internal: std.atomic.Value(bool), + exit_external: *std.atomic.Value(bool), + thread_pool: *ThreadPool, + + tcp_server: std.net.Server, + task_buffer_fl: TaskBufferFreelist, + + pub const InitParams = struct { + allocator: std.mem.Allocator, + logger: sig.trace.Logger, + + /// Not closed by the `Server`, but must live at least as long as it. + snapshot_dir: std.fs.Dir, + /// Should reflect the latest generated snapshot eligible for propagation at any + /// given time with respect to the contents of the specified `snapshot_dir`. + latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), + + exit: *std.atomic.Value(bool), + thread_pool: *ThreadPool, + + /// The socket address to listen on for incoming HTTP and/or RPC requests. + socket_addr: sig.net.SocketAddr, + /// The size for the read buffer allocated to every request. + /// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`. + read_buffer_size: usize, + + pub const MIN_READ_BUFFER_SIZE = 256; + }; + + pub const InitError = std.net.Address.ListenError; + + pub fn init(params: InitParams) !Server { + var tcp_server = try params.socket_addr.toAddress().listen(.{ .force_nonblocking = true }); + errdefer tcp_server.deinit(); + return .{ + .allocator = params.allocator, + .logger = params.logger.withScope(logger_scope), + + .snapshot_dir = params.snapshot_dir, + .latest_snapshot_gen_info = params.latest_snapshot_gen_info, + + .wait_group = .{}, + .exit_internal = std.atomic.Value(bool).init(false), + .exit_external = params.exit, + .thread_pool = params.thread_pool, + + .tcp_server = tcp_server, + + .task_buffer_fl = .{ + .read_buffer_size = @max(params.read_buffer_size, InitParams.MIN_READ_BUFFER_SIZE), + .head_mtx = .{}, + .head = null, + }, + }; + } + + pub fn deinit(server: *Server) void { + server.exit_internal.store(true, .release); + server.wait_group.wait(); + server.tcp_server.deinit(); + server.task_buffer_fl.deinit(server.allocator); + } + + /// The main loop which handles incoming connections. + pub fn serve(server: *Server) !void { + while (true) { + if (server.exit_internal.load(.acquire)) break; + if (server.exit_external.load(.acquire)) break; + + const conn = server.tcp_server.accept() catch |err| switch (err) { + error.ProcessFdQuotaExceeded, + error.SystemFdQuotaExceeded, + error.SystemResources, + error.ProtocolFailure, + error.BlockedByFirewall, + error.NetworkSubsystemFailed, + error.Unexpected, + => |e| return e, + + error.FileDescriptorNotASocket, + error.SocketNotListening, + error.OperationNotSupported, + => unreachable, // Improperly initialized server. + + error.WouldBlock, + => continue, + + error.ConnectionResetByPeer, + error.ConnectionAborted, + => |e| { + server.logger.warn().logf("{}", .{e}); + continue; + }, + }; + errdefer conn.stream.close(); + + server.wait_group.start(); + errdefer server.wait_group.finish(); + + const new_hct = try server.task_buffer_fl.popOrCreateAndInit(server.allocator, .{ + .wait_group = &server.wait_group, + .conn = conn, + .logger = server.logger, + .latest_snapshot_gen_info = server.latest_snapshot_gen_info, + }); + server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task)); + } + } +}; + +const TaskBufferFreelist = struct { + /// Does/must not change after init. + read_buffer_size: usize, + head_mtx: std.Thread.Mutex, + head: ?*TaskBufferNode, + + /// Assumes `fl.mutex` is already locked. + fn deinit(fl: *TaskBufferFreelist, allocator: std.mem.Allocator) void { + defer fl.head_mtx.unlock(); + while (fl.head) |current| { + fl.head = current.next; + destroyNode(allocator, fl.read_buffer_size, current); + } + } + + fn popOrCreateAndInit( + fl: *TaskBufferFreelist, + allocator: std.mem.Allocator, + init_data: HandleConnectionTask.Data, + ) std.mem.Allocator.Error!*HandleConnectionTask { + return fl.popAndInit(init_data) orelse try fl.createNode(allocator, init_data); + } + + fn createNode( + fl: *TaskBufferFreelist, + allocator: std.mem.Allocator, + init_data: HandleConnectionTask.Data, + ) std.mem.Allocator.Error!*HandleConnectionTask { + const buffer_size = HandleConnectionTask.fullAllocSize(fl.read_buffer_size); + const alignment = @alignOf(HandleConnectionTask); + const buffer = try allocator.alignedAlloc(u8, alignment, buffer_size); + return HandleConnectionTask.initBuffer(buffer, fl, init_data); + } + + fn popAndInit( + fl: *TaskBufferFreelist, + init_data: HandleConnectionTask.Data, + ) ?*HandleConnectionTask { + fl.head_mtx.lock(); + defer fl.head_mtx.unlock(); + + const popped = fl.head orelse return null; + fl.head = popped.next; + + const buffer_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(popped); + const buffer = buffer_ptr[0..HandleConnectionTask.fullAllocSize(fl.read_buffer_size)]; + return HandleConnectionTask.initBuffer(buffer, fl, init_data); + } + + fn push(fl: *TaskBufferFreelist, hct: *HandleConnectionTask) void { + const tbn: *TaskBufferNode = @ptrCast(hct); + + fl.head_mtx.lock(); + defer fl.head_mtx.unlock(); + + tbn.* = .{ .next = fl.head }; + fl.head = tbn; + } + + fn destroyNode( + allocator: std.mem.Allocator, + read_buffer_size: usize, + tbn: *TaskBufferNode, + ) void { + const full_alloc_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(tbn); + const full_alloc = full_alloc_ptr[0..HandleConnectionTask.fullAllocSize(read_buffer_size)]; + allocator.free(full_alloc); + } + + /// Secretly `*TaskBufferNode` = `HandleConnectionTask`. + const TaskBufferNode = extern struct { + next: ?*TaskBufferNode align(@alignOf(HandleConnectionTask)), + + comptime { + assert(@sizeOf(TaskBufferNode) <= @sizeOf(HandleConnectionTask)); + assert(@alignOf(TaskBufferNode) == @alignOf(HandleConnectionTask)); + } + }; +}; + +const HandleConnectionTask = struct { + task: ThreadPool.Task, + task_buffer_fl: *TaskBufferFreelist, + data: Data, + + const Data = struct { + wait_group: *std.Thread.WaitGroup, + conn: std.net.Server.Connection, + logger: ScopedLogger, + latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), + }; + + fn fullAllocSize(read_buffer_size: usize) usize { + return @sizeOf(HandleConnectionTask) + read_buffer_size; + } + + fn trailingReadBuffer(hct: *HandleConnectionTask) []u8 { + const full_alloc: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct); + return full_alloc[@sizeOf(HandleConnectionTask)..][0..hct.task_buffer_fl.read_buffer_size]; + } + + fn initBuffer( + /// Must be `== fullAllocSize(task_buffer_fl.read_buffer_size)` + buffer: []align(@alignOf(HandleConnectionTask)) u8, + task_buffer_fl: *TaskBufferFreelist, + data: Data, + ) *HandleConnectionTask { + assert(buffer.len == fullAllocSize(task_buffer_fl.read_buffer_size)); + const result: *HandleConnectionTask = std.mem.bytesAsValue( + HandleConnectionTask, + buffer[0..@sizeOf(HandleConnectionTask)], + ); + result.* = .{ + .task = .{ .callback = callback }, + .task_buffer_fl = task_buffer_fl, + .data = data, + }; + return result; + } + + fn callback(task: *ThreadPool.Task) void { + const hct: *HandleConnectionTask = @fieldParentPtr("task", task); + + const task_buffer_freelist = hct.task_buffer_fl; + defer task_buffer_freelist.push(hct); + + const wait_group = hct.data.wait_group; + defer wait_group.finish(); + + handleConnection( + hct.data.conn, + hct.trailingReadBuffer(), + hct.data.logger, + hct.data.latest_snapshot_gen_info, + ) catch |err| { + const logger = hct.data.logger; + if (@errorReturnTrace()) |stack_trace| { + logger.err().logf("{}\n{}", .{ err, stack_trace }); + } else { + logger.err().logf("{}", .{err}); + } + }; + } +}; + +fn handleConnection( + conn: std.net.Server.Connection, + read_buffer: []u8, + logger: ScopedLogger, + latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), +) !void { + var http_server = std.http.Server.init(conn, read_buffer); + var request = try http_server.receiveHead(); + const head = request.head; + switch (head.method) { + .POST => { + logger.err().logf("{} tried to invoke our RPC", .{conn.address}); + return try request.respond("RPCs are not yet implemented", .{ + .status = .service_unavailable, + .keep_alive = false, + }); + }, + .GET => { + if (std.mem.startsWith(u8, head.target, "/")) { + + // we hold the lock for the entirety of this process in order to prevent + // the snapshot generation process from deleting the associated snapshot. + const maybe_latest_snapshot_gen_info, // + var latest_snapshot_info_lg // + = latest_snapshot_gen_info_rw.readWithLock(); + defer latest_snapshot_info_lg.unlock(); + + const full_info: ?FullSnapshotFileInfo, // + const inc_info: ?IncrementalSnapshotFileInfo // + = blk: { + const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse + break :blk .{ null, null }; + const latest_full = latest_snapshot_gen_info.full; + const full_info: FullSnapshotFileInfo = .{ + .slot = latest_full.slot, + .hash = latest_full.hash, + }; + const latest_incremental = latest_snapshot_gen_info.inc orelse + break :blk .{ full_info, null }; + const inc_info: IncrementalSnapshotFileInfo = .{ + .base_slot = latest_full.slot, + .slot = latest_incremental.slot, + .hash = latest_incremental.hash, + }; + break :blk .{ full_info, inc_info }; + }; + + if (full_info) |full| { + const full_archive_name_bounded = full.snapshotNameStr(); + const full_archive_name = full_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, head.target[1..], full_archive_name)) { + @panic("TODO: send full snapshot"); + } + } + + if (inc_info) |inc| { + const inc_archive_name_bounded = inc.snapshotNameStr(); + const inc_archive_name = inc_archive_name_bounded.constSlice(); + if (std.mem.eql(u8, head.target[1..], inc_archive_name)) { + @panic("TODO: send inc snapshot"); + } + } + } + + logger.err().logf( + "{} requested an unrecognized resource 'GET {s}'", + .{ conn.address, head.target }, + ); + }, + else => { + logger.err().logf( + "{} made an unrecognized request '{} {s}'", + .{ conn.address, methodFmt(head.method), head.target }, + ); + }, + } + try request.respond("", .{ + .status = .not_found, + .keep_alive = false, + }); +} + +fn methodFmt(method: std.http.Method) MethodFmt { + return .{ .method = method }; +} +const MethodFmt = struct { + method: std.http.Method, + pub fn format( + fmt: MethodFmt, + comptime fmt_str: []const u8, + fmt_options: std.fmt.FormatOptions, + writer: anytype, + ) @TypeOf(writer).Error!void { + _ = fmt_options; + if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt); + try fmt.method.write(writer); + } +};