-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3699a0e
commit ac1fe4a
Showing
2 changed files
with
380 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,13 @@ | ||
pub const client = @import("client.zig"); | ||
pub const server = @import("server.zig"); | ||
|
||
pub const request = @import("request.zig"); | ||
pub const response = @import("response.zig"); | ||
pub const types = @import("types.zig"); | ||
|
||
pub const ClusterType = types.ClusterType; | ||
pub const Client = client.Client; | ||
pub const Server = server.Server; | ||
|
||
pub const ClusterType = types.ClusterType; | ||
pub const Request = request.Request; | ||
pub const Response = response.Response; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,375 @@ | ||
const std = @import("std"); | ||
const assert = std.debug.assert; | ||
const sig = @import("../sig.zig"); | ||
|
||
const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo; | ||
const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo; | ||
const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo; | ||
const ThreadPool = sig.sync.ThreadPool; | ||
|
||
const logger_scope = "rpc.Server"; | ||
const ScopedLogger = sig.trace.log.ScopedLogger(logger_scope); | ||
|
||
pub const Server = struct { | ||
allocator: std.mem.Allocator, | ||
logger: ScopedLogger, | ||
|
||
snapshot_dir: std.fs.Dir, | ||
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), | ||
|
||
/// Wait group for all currently running tasks, used to wait for | ||
/// all of them to finish before deinitializing. | ||
wait_group: std.Thread.WaitGroup, | ||
exit_internal: std.atomic.Value(bool), | ||
exit_external: *std.atomic.Value(bool), | ||
thread_pool: *ThreadPool, | ||
|
||
tcp_server: std.net.Server, | ||
task_buffer_fl: TaskBufferFreelist, | ||
|
||
pub const InitParams = struct { | ||
allocator: std.mem.Allocator, | ||
logger: sig.trace.Logger, | ||
|
||
/// Not closed by the `Server`, but must live at least as long as it. | ||
snapshot_dir: std.fs.Dir, | ||
/// Should reflect the latest generated snapshot eligible for propagation at any | ||
/// given time with respect to the contents of the specified `snapshot_dir`. | ||
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), | ||
|
||
exit: *std.atomic.Value(bool), | ||
thread_pool: *ThreadPool, | ||
|
||
/// The socket address to listen on for incoming HTTP and/or RPC requests. | ||
socket_addr: sig.net.SocketAddr, | ||
/// The size for the read buffer allocated to every request. | ||
/// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`. | ||
read_buffer_size: usize, | ||
|
||
pub const MIN_READ_BUFFER_SIZE = 256; | ||
}; | ||
|
||
pub const InitError = std.net.Address.ListenError; | ||
|
||
pub fn init(params: InitParams) !Server { | ||
var tcp_server = try params.socket_addr.toAddress().listen(.{ .force_nonblocking = true }); | ||
errdefer tcp_server.deinit(); | ||
return .{ | ||
.allocator = params.allocator, | ||
.logger = params.logger.withScope(logger_scope), | ||
|
||
.snapshot_dir = params.snapshot_dir, | ||
.latest_snapshot_gen_info = params.latest_snapshot_gen_info, | ||
|
||
.wait_group = .{}, | ||
.exit_internal = std.atomic.Value(bool).init(false), | ||
.exit_external = params.exit, | ||
.thread_pool = params.thread_pool, | ||
|
||
.tcp_server = tcp_server, | ||
|
||
.task_buffer_fl = .{ | ||
.read_buffer_size = @max(params.read_buffer_size, InitParams.MIN_READ_BUFFER_SIZE), | ||
.head_mtx = .{}, | ||
.head = null, | ||
}, | ||
}; | ||
} | ||
|
||
pub fn deinit(server: *Server) void { | ||
server.exit_internal.store(true, .release); | ||
server.wait_group.wait(); | ||
server.tcp_server.deinit(); | ||
server.task_buffer_fl.deinit(server.allocator); | ||
} | ||
|
||
/// The main loop which handles incoming connections. | ||
pub fn serve(server: *Server) !void { | ||
while (true) { | ||
if (server.exit_internal.load(.acquire)) break; | ||
if (server.exit_external.load(.acquire)) break; | ||
|
||
const conn = server.tcp_server.accept() catch |err| switch (err) { | ||
error.ProcessFdQuotaExceeded, | ||
error.SystemFdQuotaExceeded, | ||
error.SystemResources, | ||
error.ProtocolFailure, | ||
error.BlockedByFirewall, | ||
error.NetworkSubsystemFailed, | ||
error.Unexpected, | ||
=> |e| return e, | ||
|
||
error.FileDescriptorNotASocket, | ||
error.SocketNotListening, | ||
error.OperationNotSupported, | ||
=> unreachable, // Improperly initialized server. | ||
|
||
error.WouldBlock, | ||
=> continue, | ||
|
||
error.ConnectionResetByPeer, | ||
error.ConnectionAborted, | ||
=> |e| { | ||
server.logger.warn().logf("{}", .{e}); | ||
continue; | ||
}, | ||
}; | ||
errdefer conn.stream.close(); | ||
|
||
server.wait_group.start(); | ||
errdefer server.wait_group.finish(); | ||
|
||
const new_hct = try server.task_buffer_fl.popOrCreateAndInit(server.allocator, .{ | ||
.wait_group = &server.wait_group, | ||
.conn = conn, | ||
.logger = server.logger, | ||
.latest_snapshot_gen_info = server.latest_snapshot_gen_info, | ||
}); | ||
server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task)); | ||
} | ||
} | ||
}; | ||
|
||
const TaskBufferFreelist = struct { | ||
/// Does/must not change after init. | ||
read_buffer_size: usize, | ||
head_mtx: std.Thread.Mutex, | ||
head: ?*TaskBufferNode, | ||
|
||
/// Assumes `fl.mutex` is already locked. | ||
fn deinit(fl: *TaskBufferFreelist, allocator: std.mem.Allocator) void { | ||
defer fl.head_mtx.unlock(); | ||
while (fl.head) |current| { | ||
fl.head = current.next; | ||
destroyNode(allocator, fl.read_buffer_size, current); | ||
} | ||
} | ||
|
||
fn popOrCreateAndInit( | ||
fl: *TaskBufferFreelist, | ||
allocator: std.mem.Allocator, | ||
init_data: HandleConnectionTask.Data, | ||
) std.mem.Allocator.Error!*HandleConnectionTask { | ||
return fl.popAndInit(init_data) orelse try fl.createNode(allocator, init_data); | ||
} | ||
|
||
fn createNode( | ||
fl: *TaskBufferFreelist, | ||
allocator: std.mem.Allocator, | ||
init_data: HandleConnectionTask.Data, | ||
) std.mem.Allocator.Error!*HandleConnectionTask { | ||
const buffer_size = HandleConnectionTask.fullAllocSize(fl.read_buffer_size); | ||
const alignment = @alignOf(HandleConnectionTask); | ||
const buffer = try allocator.alignedAlloc(u8, alignment, buffer_size); | ||
return HandleConnectionTask.initBuffer(buffer, fl, init_data); | ||
} | ||
|
||
fn popAndInit( | ||
fl: *TaskBufferFreelist, | ||
init_data: HandleConnectionTask.Data, | ||
) ?*HandleConnectionTask { | ||
fl.head_mtx.lock(); | ||
defer fl.head_mtx.unlock(); | ||
|
||
const popped = fl.head orelse return null; | ||
fl.head = popped.next; | ||
|
||
const buffer_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(popped); | ||
const buffer = buffer_ptr[0..HandleConnectionTask.fullAllocSize(fl.read_buffer_size)]; | ||
return HandleConnectionTask.initBuffer(buffer, fl, init_data); | ||
} | ||
|
||
fn push(fl: *TaskBufferFreelist, hct: *HandleConnectionTask) void { | ||
const tbn: *TaskBufferNode = @ptrCast(hct); | ||
|
||
fl.head_mtx.lock(); | ||
defer fl.head_mtx.unlock(); | ||
|
||
tbn.* = .{ .next = fl.head }; | ||
fl.head = tbn; | ||
} | ||
|
||
fn destroyNode( | ||
allocator: std.mem.Allocator, | ||
read_buffer_size: usize, | ||
tbn: *TaskBufferNode, | ||
) void { | ||
const full_alloc_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(tbn); | ||
const full_alloc = full_alloc_ptr[0..HandleConnectionTask.fullAllocSize(read_buffer_size)]; | ||
allocator.free(full_alloc); | ||
} | ||
|
||
/// Secretly `*TaskBufferNode` = `HandleConnectionTask`. | ||
const TaskBufferNode = extern struct { | ||
next: ?*TaskBufferNode align(@alignOf(HandleConnectionTask)), | ||
|
||
comptime { | ||
assert(@sizeOf(TaskBufferNode) <= @sizeOf(HandleConnectionTask)); | ||
assert(@alignOf(TaskBufferNode) == @alignOf(HandleConnectionTask)); | ||
} | ||
}; | ||
}; | ||
|
||
const HandleConnectionTask = struct { | ||
task: ThreadPool.Task, | ||
task_buffer_fl: *TaskBufferFreelist, | ||
data: Data, | ||
|
||
const Data = struct { | ||
wait_group: *std.Thread.WaitGroup, | ||
conn: std.net.Server.Connection, | ||
logger: ScopedLogger, | ||
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo), | ||
}; | ||
|
||
fn fullAllocSize(read_buffer_size: usize) usize { | ||
return @sizeOf(HandleConnectionTask) + read_buffer_size; | ||
} | ||
|
||
fn trailingReadBuffer(hct: *HandleConnectionTask) []u8 { | ||
const full_alloc: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct); | ||
return full_alloc[@sizeOf(HandleConnectionTask)..][0..hct.task_buffer_fl.read_buffer_size]; | ||
} | ||
|
||
fn initBuffer( | ||
/// Must be `== fullAllocSize(task_buffer_fl.read_buffer_size)` | ||
buffer: []align(@alignOf(HandleConnectionTask)) u8, | ||
task_buffer_fl: *TaskBufferFreelist, | ||
data: Data, | ||
) *HandleConnectionTask { | ||
assert(buffer.len == fullAllocSize(task_buffer_fl.read_buffer_size)); | ||
const result: *HandleConnectionTask = std.mem.bytesAsValue( | ||
HandleConnectionTask, | ||
buffer[0..@sizeOf(HandleConnectionTask)], | ||
); | ||
result.* = .{ | ||
.task = .{ .callback = callback }, | ||
.task_buffer_fl = task_buffer_fl, | ||
.data = data, | ||
}; | ||
return result; | ||
} | ||
|
||
fn callback(task: *ThreadPool.Task) void { | ||
const hct: *HandleConnectionTask = @fieldParentPtr("task", task); | ||
|
||
const task_buffer_freelist = hct.task_buffer_fl; | ||
defer task_buffer_freelist.push(hct); | ||
|
||
const wait_group = hct.data.wait_group; | ||
defer wait_group.finish(); | ||
|
||
handleConnection( | ||
hct.data.conn, | ||
hct.trailingReadBuffer(), | ||
hct.data.logger, | ||
hct.data.latest_snapshot_gen_info, | ||
) catch |err| { | ||
const logger = hct.data.logger; | ||
if (@errorReturnTrace()) |stack_trace| { | ||
logger.err().logf("{}\n{}", .{ err, stack_trace }); | ||
} else { | ||
logger.err().logf("{}", .{err}); | ||
} | ||
}; | ||
} | ||
}; | ||
|
||
fn handleConnection( | ||
conn: std.net.Server.Connection, | ||
read_buffer: []u8, | ||
logger: ScopedLogger, | ||
latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo), | ||
) !void { | ||
var http_server = std.http.Server.init(conn, read_buffer); | ||
var request = try http_server.receiveHead(); | ||
const head = request.head; | ||
switch (head.method) { | ||
.POST => { | ||
logger.err().logf("{} tried to invoke our RPC", .{conn.address}); | ||
return try request.respond("RPCs are not yet implemented", .{ | ||
.status = .service_unavailable, | ||
.keep_alive = false, | ||
}); | ||
}, | ||
.GET => { | ||
if (std.mem.startsWith(u8, head.target, "/")) { | ||
|
||
// we hold the lock for the entirety of this process in order to prevent | ||
// the snapshot generation process from deleting the associated snapshot. | ||
const maybe_latest_snapshot_gen_info, // | ||
var latest_snapshot_info_lg // | ||
= latest_snapshot_gen_info_rw.readWithLock(); | ||
defer latest_snapshot_info_lg.unlock(); | ||
|
||
const full_info: ?FullSnapshotFileInfo, // | ||
const inc_info: ?IncrementalSnapshotFileInfo // | ||
= blk: { | ||
const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse | ||
break :blk .{ null, null }; | ||
const latest_full = latest_snapshot_gen_info.full; | ||
const full_info: FullSnapshotFileInfo = .{ | ||
.slot = latest_full.slot, | ||
.hash = latest_full.hash, | ||
}; | ||
const latest_incremental = latest_snapshot_gen_info.inc orelse | ||
break :blk .{ full_info, null }; | ||
const inc_info: IncrementalSnapshotFileInfo = .{ | ||
.base_slot = latest_full.slot, | ||
.slot = latest_incremental.slot, | ||
.hash = latest_incremental.hash, | ||
}; | ||
break :blk .{ full_info, inc_info }; | ||
}; | ||
|
||
if (full_info) |full| { | ||
const full_archive_name_bounded = full.snapshotNameStr(); | ||
const full_archive_name = full_archive_name_bounded.constSlice(); | ||
if (std.mem.eql(u8, head.target[1..], full_archive_name)) { | ||
@panic("TODO: send full snapshot"); | ||
} | ||
} | ||
|
||
if (inc_info) |inc| { | ||
const inc_archive_name_bounded = inc.snapshotNameStr(); | ||
const inc_archive_name = inc_archive_name_bounded.constSlice(); | ||
if (std.mem.eql(u8, head.target[1..], inc_archive_name)) { | ||
@panic("TODO: send inc snapshot"); | ||
} | ||
} | ||
} | ||
|
||
logger.err().logf( | ||
"{} requested an unrecognized resource 'GET {s}'", | ||
.{ conn.address, head.target }, | ||
); | ||
}, | ||
else => { | ||
logger.err().logf( | ||
"{} made an unrecognized request '{} {s}'", | ||
.{ conn.address, methodFmt(head.method), head.target }, | ||
); | ||
}, | ||
} | ||
try request.respond("", .{ | ||
.status = .not_found, | ||
.keep_alive = false, | ||
}); | ||
} | ||
|
||
fn methodFmt(method: std.http.Method) MethodFmt { | ||
return .{ .method = method }; | ||
} | ||
const MethodFmt = struct { | ||
method: std.http.Method, | ||
pub fn format( | ||
fmt: MethodFmt, | ||
comptime fmt_str: []const u8, | ||
fmt_options: std.fmt.FormatOptions, | ||
writer: anytype, | ||
) @TypeOf(writer).Error!void { | ||
_ = fmt_options; | ||
if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt); | ||
try fmt.method.write(writer); | ||
} | ||
}; |