feat(rpc,accountsdb): RPC Server for snapshot propagation #369

Draft: wants to merge 3 commits into base `main`. Changes from all commits.
1 change: 1 addition & 0 deletions src/accountsdb/db.zig
@@ -131,6 +131,7 @@ pub const AccountsDB = struct {
/// Used to potentially skip the first `computeAccountHashesAndLamports`.
first_snapshot_load_info: RwMux(?SnapshotGenerationInfo),
/// Info for the largest slot for which a full snapshot (and optionally an incremental snapshot relative to it) currently exists.
/// It also guards access to the snapshot archive files it refers to: a caller holding a lock on this also holds a lock on the snapshot archives.
latest_snapshot_gen_info: RwMux(?SnapshotGenerationInfo),

// TODO: populate this during snapshot load
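For clarity on the locking discipline described in the new doc comment: a minimal sketch (not part of the diff) of how a consumer pins the snapshot archives, assuming sig's `RwMux.readWithLock()` API as used by `src/rpc/server.zig` below:

```zig
const sig = @import("../sig.zig");
const AccountsDB = sig.accounts_db.AccountsDB;

// Hypothetical consumer; `readLatestArchives` is illustrative only.
// Holding the read lock pins the archive files, so snapshot generation
// cannot delete them mid-read.
fn readLatestArchives(db: *AccountsDB) void {
    const maybe_info, var lock_guard = db.latest_snapshot_gen_info.readWithLock();
    defer lock_guard.unlock(); // the archives may only be deleted after this
    if (maybe_info.*) |info| {
        // safe to open and stream the archive files described by `info` here
        _ = info;
    }
}
```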
6 changes: 5 additions & 1 deletion src/rpc/lib.zig
@@ -1,9 +1,13 @@
pub const client = @import("client.zig");
pub const server = @import("server.zig");

pub const request = @import("request.zig");
pub const response = @import("response.zig");
pub const types = @import("types.zig");

pub const Client = client.Client;
pub const Server = server.Server;

pub const ClusterType = types.ClusterType;
pub const Request = request.Request;
pub const Response = response.Response;
375 changes: 375 additions & 0 deletions src/rpc/server.zig
@@ -0,0 +1,375 @@
const std = @import("std");
const assert = std.debug.assert;
const sig = @import("../sig.zig");

const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo;
const FullSnapshotFileInfo = sig.accounts_db.snapshots.FullSnapshotFileInfo;
const IncrementalSnapshotFileInfo = sig.accounts_db.snapshots.IncrementalSnapshotFileInfo;
const ThreadPool = sig.sync.ThreadPool;

const logger_scope = "rpc.Server";
const ScopedLogger = sig.trace.log.ScopedLogger(logger_scope);

pub const Server = struct {
allocator: std.mem.Allocator,
logger: ScopedLogger,

snapshot_dir: std.fs.Dir,
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),

/// Wait group for all currently running tasks, used to wait for
/// all of them to finish before deinitializing.
wait_group: std.Thread.WaitGroup,
exit_internal: std.atomic.Value(bool),
exit_external: *std.atomic.Value(bool),
thread_pool: *ThreadPool,

tcp_server: std.net.Server,
task_buffer_fl: TaskBufferFreelist,

pub const InitParams = struct {
allocator: std.mem.Allocator,
logger: sig.trace.Logger,

/// Not closed by the `Server`, but must live at least as long as it.
snapshot_dir: std.fs.Dir,
/// Should reflect the latest generated snapshot eligible for propagation at any
/// given time with respect to the contents of the specified `snapshot_dir`.
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),

exit: *std.atomic.Value(bool),
thread_pool: *ThreadPool,

/// The socket address to listen on for incoming HTTP and/or RPC requests.
socket_addr: sig.net.SocketAddr,
/// The size for the read buffer allocated to every request.
/// Clamped to be greater than or equal to `MIN_READ_BUFFER_SIZE`.
read_buffer_size: usize,

pub const MIN_READ_BUFFER_SIZE = 256;
};

pub const InitError = std.net.Address.ListenError;

pub fn init(params: InitParams) !Server {
var tcp_server = try params.socket_addr.toAddress().listen(.{ .force_nonblocking = true });
errdefer tcp_server.deinit();
return .{
.allocator = params.allocator,
.logger = params.logger.withScope(logger_scope),

.snapshot_dir = params.snapshot_dir,
.latest_snapshot_gen_info = params.latest_snapshot_gen_info,

.wait_group = .{},
.exit_internal = std.atomic.Value(bool).init(false),
.exit_external = params.exit,
.thread_pool = params.thread_pool,

.tcp_server = tcp_server,

.task_buffer_fl = .{
.read_buffer_size = @max(params.read_buffer_size, InitParams.MIN_READ_BUFFER_SIZE),
.head_mtx = .{},
.head = null,
},
};
}

pub fn deinit(server: *Server) void {
server.exit_internal.store(true, .release);
server.wait_group.wait();
server.tcp_server.deinit();
server.task_buffer_fl.deinit(server.allocator);
}

/// The main loop which handles incoming connections.
pub fn serve(server: *Server) !void {
while (true) {
if (server.exit_internal.load(.acquire)) break;
if (server.exit_external.load(.acquire)) break;

const conn = server.tcp_server.accept() catch |err| switch (err) {
error.ProcessFdQuotaExceeded,
error.SystemFdQuotaExceeded,
error.SystemResources,
error.ProtocolFailure,
error.BlockedByFirewall,
error.NetworkSubsystemFailed,
error.Unexpected,
=> |e| return e,

error.FileDescriptorNotASocket,
error.SocketNotListening,
error.OperationNotSupported,
=> unreachable, // Improperly initialized server.

error.WouldBlock,
=> continue,

error.ConnectionResetByPeer,
error.ConnectionAborted,
=> |e| {
server.logger.warn().logf("{}", .{e});
continue;
},
};
errdefer conn.stream.close();

server.wait_group.start();
errdefer server.wait_group.finish();

const new_hct = try server.task_buffer_fl.popOrCreateAndInit(server.allocator, .{
.wait_group = &server.wait_group,
.conn = conn,
.logger = server.logger,
.latest_snapshot_gen_info = server.latest_snapshot_gen_info,
});
server.thread_pool.schedule(ThreadPool.Batch.from(&new_hct.task));
}
}
};
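For context, here is a hypothetical sketch of how this server is intended to be wired up. The port, buffer size, and the `sig.net.SocketAddr.initIpv4` constructor are assumptions for illustration, not part of this diff:

```zig
const std = @import("std");
const sig = @import("../sig.zig");
const SnapshotGenerationInfo = sig.accounts_db.AccountsDB.SnapshotGenerationInfo;

// Hypothetical wiring; the address, port, and buffer size are illustrative.
fn startRpcServer(
    allocator: std.mem.Allocator,
    logger: sig.trace.Logger,
    snapshot_dir: std.fs.Dir,
    latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),
    exit: *std.atomic.Value(bool),
    thread_pool: *sig.sync.ThreadPool,
) !void {
    var server = try sig.rpc.Server.init(.{
        .allocator = allocator,
        .logger = logger,
        .snapshot_dir = snapshot_dir,
        .latest_snapshot_gen_info = latest_snapshot_gen_info,
        .exit = exit,
        .thread_pool = thread_pool,
        // assumed constructor; any sig.net.SocketAddr would do here
        .socket_addr = sig.net.SocketAddr.initIpv4(.{ 0, 0, 0, 0 }, 8899),
        .read_buffer_size = 4096,
    });
    defer server.deinit(); // signals internal exit, waits for in-flight tasks
    try server.serve(); // runs until `exit` is set or a fatal accept error
}
```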

const TaskBufferFreelist = struct {
/// Does not (and must not) change after init.
read_buffer_size: usize,
head_mtx: std.Thread.Mutex,
head: ?*TaskBufferNode,

/// Assumes `fl.head_mtx` is already locked.
fn deinit(fl: *TaskBufferFreelist, allocator: std.mem.Allocator) void {
defer fl.head_mtx.unlock();
while (fl.head) |current| {
fl.head = current.next;
destroyNode(allocator, fl.read_buffer_size, current);
}
}

fn popOrCreateAndInit(
fl: *TaskBufferFreelist,
allocator: std.mem.Allocator,
init_data: HandleConnectionTask.Data,
) std.mem.Allocator.Error!*HandleConnectionTask {
return fl.popAndInit(init_data) orelse try fl.createNode(allocator, init_data);
}

fn createNode(
fl: *TaskBufferFreelist,
allocator: std.mem.Allocator,
init_data: HandleConnectionTask.Data,
) std.mem.Allocator.Error!*HandleConnectionTask {
const buffer_size = HandleConnectionTask.fullAllocSize(fl.read_buffer_size);
const alignment = @alignOf(HandleConnectionTask);
const buffer = try allocator.alignedAlloc(u8, alignment, buffer_size);
return HandleConnectionTask.initBuffer(buffer, fl, init_data);
}

fn popAndInit(
fl: *TaskBufferFreelist,
init_data: HandleConnectionTask.Data,
) ?*HandleConnectionTask {
fl.head_mtx.lock();
defer fl.head_mtx.unlock();

const popped = fl.head orelse return null;
fl.head = popped.next;

const buffer_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(popped);
const buffer = buffer_ptr[0..HandleConnectionTask.fullAllocSize(fl.read_buffer_size)];
return HandleConnectionTask.initBuffer(buffer, fl, init_data);
}

fn push(fl: *TaskBufferFreelist, hct: *HandleConnectionTask) void {
const tbn: *TaskBufferNode = @ptrCast(hct);

fl.head_mtx.lock();
defer fl.head_mtx.unlock();

tbn.* = .{ .next = fl.head };
fl.head = tbn;
}

fn destroyNode(
allocator: std.mem.Allocator,
read_buffer_size: usize,
tbn: *TaskBufferNode,
) void {
const full_alloc_ptr: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(tbn);
const full_alloc = full_alloc_ptr[0..HandleConnectionTask.fullAllocSize(read_buffer_size)];
allocator.free(full_alloc);
}

/// Secretly, a `*TaskBufferNode` is a `*HandleConnectionTask` in disguise.
const TaskBufferNode = extern struct {
next: ?*TaskBufferNode align(@alignOf(HandleConnectionTask)),

comptime {
assert(@sizeOf(TaskBufferNode) <= @sizeOf(HandleConnectionTask));
assert(@alignOf(TaskBufferNode) == @alignOf(HandleConnectionTask));
}
};
};
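The freelist above avoids a heap allocation per connection by writing a `next` pointer directly over the memory of a finished task. A standalone sketch of that overlay trick, with illustrative names (the real code additionally guards `head` with `head_mtx`):

```zig
const std = @import("std");
const assert = std.debug.assert;

// Illustrative types: a free node is stored in-place over a dead object's
// memory, so the freelist needs no bookkeeping allocations of its own.
const Object = struct { id: usize, payload: [56]u8 };
const FreeNode = extern struct {
    next: ?*FreeNode align(@alignOf(Object)),
    comptime {
        assert(@sizeOf(FreeNode) <= @sizeOf(Object));
        assert(@alignOf(FreeNode) == @alignOf(Object));
    }
};

fn push(head: *?*FreeNode, obj: *Object) void {
    const node: *FreeNode = @ptrCast(obj); // reuse the object's memory
    node.* = .{ .next = head.* };
    head.* = node;
}

fn pop(head: *?*FreeNode) ?*Object {
    const node = head.* orelse return null;
    head.* = node.next;
    const obj: *Object = @ptrCast(node); // hand the memory back out
    return obj;
}
```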

const HandleConnectionTask = struct {
task: ThreadPool.Task,
task_buffer_fl: *TaskBufferFreelist,
data: Data,

const Data = struct {
wait_group: *std.Thread.WaitGroup,
conn: std.net.Server.Connection,
logger: ScopedLogger,
latest_snapshot_gen_info: *sig.sync.RwMux(?SnapshotGenerationInfo),
};

fn fullAllocSize(read_buffer_size: usize) usize {
return @sizeOf(HandleConnectionTask) + read_buffer_size;
}

fn trailingReadBuffer(hct: *HandleConnectionTask) []u8 {
const full_alloc: [*]align(@alignOf(HandleConnectionTask)) u8 = @ptrCast(hct);
return full_alloc[@sizeOf(HandleConnectionTask)..][0..hct.task_buffer_fl.read_buffer_size];
}

fn initBuffer(
/// Must be `== fullAllocSize(task_buffer_fl.read_buffer_size)`
buffer: []align(@alignOf(HandleConnectionTask)) u8,
task_buffer_fl: *TaskBufferFreelist,
data: Data,
) *HandleConnectionTask {
assert(buffer.len == fullAllocSize(task_buffer_fl.read_buffer_size));
const result: *HandleConnectionTask = std.mem.bytesAsValue(
HandleConnectionTask,
buffer[0..@sizeOf(HandleConnectionTask)],
);
result.* = .{
.task = .{ .callback = callback },
.task_buffer_fl = task_buffer_fl,
.data = data,
};
return result;
}

fn callback(task: *ThreadPool.Task) void {
const hct: *HandleConnectionTask = @fieldParentPtr("task", task);

const wait_group = hct.data.wait_group;
defer wait_group.finish();

const task_buffer_freelist = hct.task_buffer_fl;
defer task_buffer_freelist.push(hct);

handleConnection(
hct.data.conn,
hct.trailingReadBuffer(),
hct.data.logger,
hct.data.latest_snapshot_gen_info,
) catch |err| {
const logger = hct.data.logger;
if (@errorReturnTrace()) |stack_trace| {
logger.err().logf("{}\n{}", .{ err, stack_trace });
} else {
logger.err().logf("{}", .{err});
}
};
}
};
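`HandleConnectionTask` and its per-connection read buffer live in a single allocation; `fullAllocSize` and `trailingReadBuffer` carve that block up. A standalone sketch of the same layout pattern, using a hypothetical `Header` type:

```zig
const std = @import("std");

// One block = struct header + trailing byte buffer.
const Header = struct { buffer_len: usize };

fn create(allocator: std.mem.Allocator, buffer_len: usize) !*Header {
    const bytes = try allocator.alignedAlloc(u8, @alignOf(Header), @sizeOf(Header) + buffer_len);
    const header = std.mem.bytesAsValue(Header, bytes[0..@sizeOf(Header)]);
    header.* = .{ .buffer_len = buffer_len };
    return header;
}

fn trailing(header: *Header) []u8 {
    const base: [*]align(@alignOf(Header)) u8 = @ptrCast(header);
    return base[@sizeOf(Header)..][0..header.buffer_len];
}

fn destroy(allocator: std.mem.Allocator, header: *Header) void {
    const base: [*]align(@alignOf(Header)) u8 = @ptrCast(header);
    allocator.free(base[0 .. @sizeOf(Header) + header.buffer_len]);
}
```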

fn handleConnection(
conn: std.net.Server.Connection,
read_buffer: []u8,
logger: ScopedLogger,
latest_snapshot_gen_info_rw: *sig.sync.RwMux(?SnapshotGenerationInfo),
) !void {
var http_server = std.http.Server.init(conn, read_buffer);
var request = try http_server.receiveHead();
const head = request.head;
switch (head.method) {
.POST => {
logger.err().logf("{} tried to invoke our RPC", .{conn.address});
return try request.respond("RPCs are not yet implemented", .{
.status = .service_unavailable,
.keep_alive = false,
});
},
.GET => {
if (std.mem.startsWith(u8, head.target, "/")) {

// we hold the lock for the entirety of this process in order to prevent
// the snapshot generation process from deleting the associated snapshot.
const maybe_latest_snapshot_gen_info, //
var latest_snapshot_info_lg //
= latest_snapshot_gen_info_rw.readWithLock();
defer latest_snapshot_info_lg.unlock();

const full_info: ?FullSnapshotFileInfo, //
const inc_info: ?IncrementalSnapshotFileInfo //
= blk: {
const latest_snapshot_gen_info = maybe_latest_snapshot_gen_info.* orelse
break :blk .{ null, null };
const latest_full = latest_snapshot_gen_info.full;
const full_info: FullSnapshotFileInfo = .{
.slot = latest_full.slot,
.hash = latest_full.hash,
};
const latest_incremental = latest_snapshot_gen_info.inc orelse
break :blk .{ full_info, null };
const inc_info: IncrementalSnapshotFileInfo = .{
.base_slot = latest_full.slot,
.slot = latest_incremental.slot,
.hash = latest_incremental.hash,
};
break :blk .{ full_info, inc_info };
};

if (full_info) |full| {
const full_archive_name_bounded = full.snapshotNameStr();
const full_archive_name = full_archive_name_bounded.constSlice();
if (std.mem.eql(u8, head.target[1..], full_archive_name)) {
@panic("TODO: send full snapshot");
}
}

if (inc_info) |inc| {
const inc_archive_name_bounded = inc.snapshotNameStr();
const inc_archive_name = inc_archive_name_bounded.constSlice();
if (std.mem.eql(u8, head.target[1..], inc_archive_name)) {
@panic("TODO: send inc snapshot");
}
}
}

logger.err().logf(
"{} requested an unrecognized resource 'GET {s}'",
.{ conn.address, head.target },
);
},
else => {
logger.err().logf(
"{} made an unrecognized request '{} {s}'",
.{ conn.address, methodFmt(head.method), head.target },
);
},
}
try request.respond("", .{
.status = .not_found,
.keep_alive = false,
});
}
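For reference, the `GET` targets matched above are exactly the archive file names produced by `snapshotNameStr()`. Assuming Solana's standard archive naming convention, a client would request, e.g.:

```
GET /snapshot-<slot>-<hash>.tar.zst
GET /incremental-snapshot-<base_slot>-<slot>-<hash>.tar.zst
```

Any other target falls through to the `.not_found` response at the end of `handleConnection`.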

fn methodFmt(method: std.http.Method) MethodFmt {
return .{ .method = method };
}
const MethodFmt = struct {
method: std.http.Method,
pub fn format(
fmt: MethodFmt,
comptime fmt_str: []const u8,
fmt_options: std.fmt.FormatOptions,
writer: anytype,
) @TypeOf(writer).Error!void {
_ = fmt_options;
if (fmt_str.len != 0) std.fmt.invalidFmtError(fmt_str, fmt);
try fmt.method.write(writer);
}
};