Skip to content

Commit

Permalink
RPC server first pass
Browse files Browse the repository at this point in the history
  • Loading branch information
InKryption committed Nov 8, 2024
1 parent 3699a0e commit ac1fe4a
Show file tree
Hide file tree
Showing 2 changed files with 380 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/rpc/lib.zig
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
pub const client = @import("client.zig");
pub const server = @import("server.zig");

pub const request = @import("request.zig");
pub const response = @import("response.zig");
pub const types = @import("types.zig");

pub const ClusterType = types.ClusterType;
pub const Client = client.Client;
pub const Server = server.Server;

pub const ClusterType = types.ClusterType;
pub const Request = request.Request;
pub const Response = response.Response;
375 changes: 375 additions & 0 deletions src/rpc/server.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,375 @@
const std = @import("std");
const assert = std.debug.assert;
const sig = @import("../sig.zig");

const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo;
const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo;
const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo;
const ThreadPool = sig.sync.ThreadPool;

const logger_scope = "rpc.Server";
const ScopedLogger = sig.trace.log.ScopedLogger(logger_scope);

pub const Server = struct {
allocator: std.mem.Allocator,
logger: ScopedLogger,

snapshot_dir: std.fs.Dir,
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),

/// Wait group for all currently running tasks, used to wait for
/// all of them to finish before deinitializing.
wait_group: std.Thread.WaitGroup,
exit_internal: std.atomic.Value(bool),
exit_external: *std.atomic.Value(bool),
thread_pool: *ThreadPool,

tcp_server: std.net.Server,
task_buffer_fl: TaskBufferFreelist,

pub const InitParams = struct {
allocator: std.mem.Allocator,
logger: sig.trace.Logger,

/// Not closed by the `Server`, but must live at least as long as it.
snapshot_dir: std.fs.Dir,
/// Should reflect the latest generated snapshot eligible for propagation at any
/// given time with respect to the contents of the specified `snapshot_dir`.
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),

exit: *std.atomic.Value(bool),
thread_pool: *ThreadPool,

/// The socket address to listen on for incoming HTTP and/or RPC requests.
socket_addr: sig.net.SocketAddr,
/// The size for the read buffer allocated to every request.
/// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`.
read_buffer_size: usize,

pub const MIN_READ_BUFFER_SIZE = 256;
};

pub const InitError = std.net.Address.ListenError;

pub fn init(params: InitParams) !Server {
var tcp_server = try params.socket_addr.toAddress().listen(.{ .force_nonblocking = true });
errdefer tcp_server.deinit();
return .{
.allocator = params.allocator,
.logger = params.logger.withScope(logger_scope),

.snapshot_dir = params.snapshot_dir,
.latest_snapshot_gen_info = params.latest_snapshot_gen_info,

.wait_group = .{},
.exit_internal = std.atomic.Value(bool).init(false),
.exit_external = params.exit,
.thread_pool = params.thread_pool,

.tcp_server = tcp_server,

.task_buffer_fl = .{
.read_buffer_size = @max(params.read_buffer_size, InitParams.MIN_READ_BUFFER_SIZE),
.head_mtx = .{},
.head = null,
},
};
}

pub fn deinit(server: *Server) void {
server.exit_internal.store(true, .release);
server.wait_group.wait();
server.tcp_server.deinit();
server.task_buffer_fl.deinit(server.allocator);
}

/// The main loop which handles incoming connections.
pub fn serve(server: *Server) !void {
while (true) {
if (server.exit_internal.load(.acquire)) break;
if (server.exit_external.load(.acquire)) break;

const conn = server.tcp_server.accept() catch |err| switch (err) {
error.ProcessFdQuotaExceeded,
error.SystemFdQuotaExceeded,
error.SystemResources,
error.ProtocolFailure,
error.BlockedByFirewall,
error.NetworkSubsystemFailed,
error.Unexpected,
=> |e| return e,

error.FileDescriptorNotASocket,
error.SocketNotListening,
error.OperationNotSupported,
=> unreachable, // Improperly initialized server.

error.WouldBlock,
=> continue,

error.ConnectionResetByPeer,
error.ConnectionAborted,
=> |e| {
server.logger.warn().logf("{}", .{e});
continue;
},
};
errdefer conn.stream.close();

server.wait_group.start();
errdefer server.wait_group.finish();

const new_hct = try server.task_buffer_fl.popOrCreateAndInit(server.allocator, .{
.wait_group = &server.wait_group,
.conn = conn,
.logger = server.logger,
.latest_snapshot_gen_info = server.latest_snapshot_gen_info,
});
server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task));
}
}
};

const TaskBufferFreelist = struct {
/// Does/must not change after init.
read_buffer_size: usize,
head_mtx: std.Thread.Mutex,
head: ?*TaskBufferNode,

/// Assumes `fl.mutex` is already locked.
fn deinit(fl: *TaskBufferFreelist, allocator: std.mem.Allocator) void {
defer fl.head_mtx.unlock();
while (fl.head) |current| {
fl.head = current.next;
destroyNode(allocator, fl.read_buffer_size, current);
}
}

fn popOrCreateAndInit(
fl: *TaskBufferFreelist,
allocator: std.mem.Allocator,
init_data: HandleConnectionTask.Data,
) std.mem.Allocator.Error!*HandleConnectionTask {
return fl.popAndInit(init_data) orelse try fl.createNode(allocator, init_data);
}

fn createNode(
fl: *TaskBufferFreelist,
allocator: std.mem.Allocator,
init_data: HandleConnectionTask.Data,
) std.mem.Allocator.Error!*HandleConnectionTask {
const buffer_size = HandleConnectionTask.fullAllocSize(fl.read_buffer_size);
const alignment = @alignOf(HandleConnectionTask);
const buffer = try allocator.alignedAlloc(u8, alignment, buffer_size);
return HandleConnectionTask.initBuffer(buffer, fl, init_data);
}

fn popAndInit(
fl: *TaskBufferFreelist,
init_data: HandleConnectionTask.Data,
) ?*HandleConnectionTask {
fl.head_mtx.lock();
defer fl.head_mtx.unlock();

const popped = fl.head orelse return null;
fl.head = popped.next;

const buffer_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(popped);
const buffer = buffer_ptr[0..HandleConnectionTask.fullAllocSize(fl.read_buffer_size)];
return HandleConnectionTask.initBuffer(buffer, fl, init_data);
}

fn push(fl: *TaskBufferFreelist, hct: *HandleConnectionTask) void {
const tbn: *TaskBufferNode = @ptrCast(hct);

fl.head_mtx.lock();
defer fl.head_mtx.unlock();

tbn.* = .{ .next = fl.head };
fl.head = tbn;
}

fn destroyNode(
allocator: std.mem.Allocator,
read_buffer_size: usize,
tbn: *TaskBufferNode,
) void {
const full_alloc_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(tbn);
const full_alloc = full_alloc_ptr[0..HandleConnectionTask.fullAllocSize(read_buffer_size)];
allocator.free(full_alloc);
}

/// Secretly `*TaskBufferNode` = `HandleConnectionTask`.
const TaskBufferNode = extern struct {
next: ?*TaskBufferNode align(@alignOf(HandleConnectionTask)),

comptime {
assert(@sizeOf(TaskBufferNode) <= @sizeOf(HandleConnectionTask));
assert(@alignOf(TaskBufferNode) == @alignOf(HandleConnectionTask));
}
};
};

const HandleConnectionTask = struct {
task: ThreadPool.Task,
task_buffer_fl: *TaskBufferFreelist,
data: Data,

const Data = struct {
wait_group: *std.Thread.WaitGroup,
conn: std.net.Server.Connection,
logger: ScopedLogger,
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),
};

fn fullAllocSize(read_buffer_size: usize) usize {
return @sizeOf(HandleConnectionTask) + read_buffer_size;
}

fn trailingReadBuffer(hct: *HandleConnectionTask) []u8 {
const full_alloc: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct);
return full_alloc[@sizeOf(HandleConnectionTask)..][0..hct.task_buffer_fl.read_buffer_size];
}

fn initBuffer(
/// Must be `== fullAllocSize(task_buffer_fl.read_buffer_size)`
buffer: []align(@alignOf(HandleConnectionTask)) u8,
task_buffer_fl: *TaskBufferFreelist,
data: Data,
) *HandleConnectionTask {
assert(buffer.len == fullAllocSize(task_buffer_fl.read_buffer_size));
const result: *HandleConnectionTask = std.mem.bytesAsValue(
HandleConnectionTask,
buffer[0..@sizeOf(HandleConnectionTask)],
);
result.* = .{
.task = .{ .callback = callback },
.task_buffer_fl = task_buffer_fl,
.data = data,
};
return result;
}

fn callback(task: *ThreadPool.Task) void {
const hct: *HandleConnectionTask = @fieldParentPtr("task", task);

const task_buffer_freelist = hct.task_buffer_fl;
defer task_buffer_freelist.push(hct);

const wait_group = hct.data.wait_group;
defer wait_group.finish();

handleConnection(
hct.data.conn,
hct.trailingReadBuffer(),
hct.data.logger,
hct.data.latest_snapshot_gen_info,
) catch |err| {
const logger = hct.data.logger;
if (@errorReturnTrace()) |stack_trace| {
logger.err().logf("{}\n{}", .{ err, stack_trace });
} else {
logger.err().logf("{}", .{err});
}
};
}
};

fn handleConnection(
conn: std.net.Server.Connection,
read_buffer: []u8,
logger: ScopedLogger,
latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo),
) !void {
var http_server = std.http.Server.init(conn, read_buffer);
var request = try http_server.receiveHead();
const head = request.head;
switch (head.method) {
.POST => {
logger.err().logf("{} tried to invoke our RPC", .{conn.address});
return try request.respond("RPCs are not yet implemented", .{
.status = .service_unavailable,
.keep_alive = false,
});
},
.GET => {
if (std.mem.startsWith(u8, head.target, "/")) {

// we hold the lock for the entirety of this process in order to prevent
// the snapshot generation process from deleting the associated snapshot.
const maybe_latest_snapshot_gen_info, //
var latest_snapshot_info_lg //
= latest_snapshot_gen_info_rw.readWithLock();
defer latest_snapshot_info_lg.unlock();

const full_info: ?FullSnapshotFileInfo, //
const inc_info: ?IncrementalSnapshotFileInfo //
= blk: {
const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse
break :blk .{ null, null };
const latest_full = latest_snapshot_gen_info.full;
const full_info: FullSnapshotFileInfo = .{
.slot = latest_full.slot,
.hash = latest_full.hash,
};
const latest_incremental = latest_snapshot_gen_info.inc orelse
break :blk .{ full_info, null };
const inc_info: IncrementalSnapshotFileInfo = .{
.base_slot = latest_full.slot,
.slot = latest_incremental.slot,
.hash = latest_incremental.hash,
};
break :blk .{ full_info, inc_info };
};

if (full_info) |full| {
const full_archive_name_bounded = full.snapshotNameStr();
const full_archive_name = full_archive_name_bounded.constSlice();
if (std.mem.eql(u8, head.target[1..], full_archive_name)) {
@panic("TODO: send full snapshot");
}
}

if (inc_info) |inc| {
const inc_archive_name_bounded = inc.snapshotNameStr();
const inc_archive_name = inc_archive_name_bounded.constSlice();
if (std.mem.eql(u8, head.target[1..], inc_archive_name)) {
@panic("TODO: send inc snapshot");
}
}
}

logger.err().logf(
"{} requested an unrecognized resource 'GET {s}'",
.{ conn.address, head.target },
);
},
else => {
logger.err().logf(
"{} made an unrecognized request '{} {s}'",
.{ conn.address, methodFmt(head.method), head.target },
);
},
}
try request.respond("", .{
.status = .not_found,
.keep_alive = false,
});
}

fn methodFmt(method: std.http.Method) MethodFmt {
return .{ .method = method };
}
const MethodFmt = struct {
method: std.http.Method,
pub fn format(
fmt: MethodFmt,
comptime fmt_str: []const u8,
fmt_options: std.fmt.FormatOptions,
writer: anytype,
) @TypeOf(writer).Error!void {
_ = fmt_options;
if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt);
try fmt.method.write(writer);
}
};

0 comments on commit ac1fe4a

Please sign in to comment.