From 63bd1f83172d313a80df75fabb5780e01b65d384 Mon Sep 17 00:00:00 2001 From: Max Kellermann Date: Mon, 4 Dec 2023 19:13:01 +0100 Subject: [PATCH] lua/Class: add WrapMethod() --- src/lua/Class.hxx | 16 +++++++++++ src/lua/io/XattrTable.cxx | 29 ++++++------------- src/lua/pg/Connection.cxx | 60 +++++++++++++++------------------------ src/lua/pg/Result.cxx | 24 +++++----------- 4 files changed, 55 insertions(+), 74 deletions(-) diff --git a/src/lua/Class.hxx b/src/lua/Class.hxx index 1f67f9e10..2177bc729 100644 --- a/src/lua/Class.hxx +++ b/src/lua/Class.hxx @@ -125,6 +125,22 @@ struct Class { return *(pointer)luaL_checkudata(L, idx, name); } + /** + * Generate a wrapper function which invokes Cast() and calls + * the specified method. + */ + template + static constexpr lua_CFunction WrapMethod() noexcept + requires std::is_class_v && + std::is_member_function_pointer_v { + static_assert(std::is_same_v); + + return [](lua_State *L) { + reference object = Cast(L, 1); + return (object.*method)(L); + }; + } + private: static int l_gc(lua_State *L) { const ScopeCheckStack check_stack(L); diff --git a/src/lua/io/XattrTable.cxx b/src/lua/io/XattrTable.cxx index f0273d381..3065b66d3 100644 --- a/src/lua/io/XattrTable.cxx +++ b/src/lua/io/XattrTable.cxx @@ -19,42 +19,31 @@ class XattrTable final { explicit XattrTable(UniqueFileDescriptor &&_fd) noexcept :fd(std::move(_fd)) {} - static int Close(lua_State *L); - static int Index(lua_State *L); - -private: - int _Close() { - fd.Close(); - return 0; - } - - int _Index(lua_State *L, const char *name); + int Close(lua_State *L); + int Index(lua_State *L); }; static constexpr char lua_xattr_table_class[] = "io.XattrTable"; using XattrTableClass = Lua::Class; -int +inline int XattrTable::Close(lua_State *L) { if (lua_gettop(L) != 1) return luaL_error(L, "Invalid parameters"); - return XattrTableClass::Cast(L, 1)._Close(); + fd.Close(); + return 0; } -int +inline int XattrTable::Index(lua_State *L) { if (lua_gettop(L) != 2) return luaL_error(L, "Invalid parameters"); - return XattrTableClass::Cast(L, 1)._Index(L, luaL_checkstring(L, 2)); -} + const char *const name = luaL_checkstring(L, 2); -inline int -XattrTable::_Index(lua_State *L, const char *name) -{ if (!fd.IsDefined()) luaL_error(L, "Stale object"); @@ -86,8 +75,8 @@ void InitXattrTable(lua_State *L) noexcept { XattrTableClass::Register(L); - SetField(L, RelativeStackIndex{-1}, "__index", XattrTable::Index); - SetField(L, RelativeStackIndex{-1}, "__close", XattrTable::Close); + SetField(L, RelativeStackIndex{-1}, "__index", XattrTableClass::WrapMethod<&XattrTable::Index>()); + SetField(L, RelativeStackIndex{-1}, "__close", XattrTableClass::WrapMethod<&XattrTable::Close>()); lua_pop(L, 1); } diff --git a/src/lua/pg/Connection.cxx b/src/lua/pg/Connection.cxx index ee2bda0be..faadb612a 100644 --- a/src/lua/pg/Connection.cxx +++ b/src/lua/pg/Connection.cxx @@ -125,28 +125,25 @@ class PgConnection final : Pg::SharedConnectionHandler { return connection.GetEventLoop(); } -private: - static int Execute(lua_State *L); - int Execute(lua_State *L, int sql, int params); - static int Listen(lua_State *L); - int Listen(lua_State *L, int name_idx, int handler_idx); + int Execute(lua_State *L); + int Listen(lua_State *L); +private: /* virtual methods from class Pg::SharedConnectionHandler */ void OnPgConnect() override; void OnPgNotify(const char *name) override; void OnPgError(std::exception_ptr e) noexcept override; - -public: - static constexpr struct luaL_Reg methods [] = { - {"execute", Execute}, - {"listen", Listen}, - {nullptr, nullptr} - }; }; static constexpr char lua_pg_connection_class[] = "pg.Connection"; using PgConnectionClass = Lua::Class; +static constexpr struct luaL_Reg lua_pg_connection_methods[] = { + {"execute", PgConnectionClass::WrapMethod<&PgConnection::Execute>()}, + {"listen", PgConnectionClass::WrapMethod<&PgConnection::Listen>()}, + {nullptr, nullptr} +}; + class PgRequest final : public Pg::SharedConnectionQuery, Pg::AsyncResultHandler { @@ -243,15 +240,6 @@ static constexpr char lua_pg_request_class[] = "pg.Request"; using PgRequestClass = Lua::Class; inline int -PgConnection::Execute(lua_State *L, int sql, int params) -{ - auto *request = PgRequestClass::New(L, L, connection, - sql, params); - connection.ScheduleQuery(*request); - return lua_yield(L, 1); -} - -int PgConnection::Execute(lua_State *L) { if (lua_gettop(L) < 2) @@ -268,13 +256,23 @@ PgConnection::Execute(lua_State *L) params = 3; } - auto &connection = PgConnectionClass::Cast(L, 1); - return connection.Execute(L, sql, params); + auto *request = PgRequestClass::New(L, L, connection, + sql, params); + connection.ScheduleQuery(*request); + return lua_yield(L, 1); } inline int -PgConnection::Listen(lua_State *L, int name_idx, int handler_idx) +PgConnection::Listen(lua_State *L) { + if (lua_gettop(L) < 3) + return luaL_error(L, "Not enough parameters"); + if (lua_gettop(L) > 3) + return luaL_error(L, "Too many parameters"); + + constexpr int name_idx = 2; + constexpr int handler_idx = 3; + const char *name = luaL_checkstring(L, name_idx); luaL_checktype(L, 3, LUA_TFUNCTION); @@ -290,18 +288,6 @@ PgConnection::Listen(lua_State *L, int name_idx, int handler_idx) return 0; } -int -PgConnection::Listen(lua_State *L) -{ - if (lua_gettop(L) < 3) - return luaL_error(L, "Not enough parameters"); - if (lua_gettop(L) > 3) - return luaL_error(L, "Too many parameters"); - - auto &connection = PgConnectionClass::Cast(L, 1); - return connection.Listen(L, 2, 3); -} - void PgConnection::OnPgConnect() { @@ -457,7 +443,7 @@ void InitPgConnection(lua_State *L) noexcept { PgConnectionClass::Register(L); - luaL_newlib(L, PgConnection::methods); + luaL_newlib(L, lua_pg_connection_methods); lua_setfield(L, -2, "__index"); lua_pop(L, 1); diff --git a/src/lua/pg/Result.cxx b/src/lua/pg/Result.cxx index e04157622..80267594e 100644 --- a/src/lua/pg/Result.cxx +++ b/src/lua/pg/Result.cxx @@ -23,29 +23,19 @@ class PgResult final { explicit PgResult(Pg::Result &&_result) noexcept :result(std::move(_result)) {} -private: - static int Fetch(lua_State *L); - int _Fetch(lua_State *L); - -public: - static constexpr struct luaL_Reg methods [] = { - {"fetch", Fetch}, - {nullptr, nullptr} - }; + int Fetch(lua_State *L); }; static constexpr char lua_pg_result_class[] = "pg.Result"; using PgResultClass = Lua::Class; -int -PgResult::Fetch(lua_State *L) -{ - auto &result = PgResultClass::Cast(L, 1); - return result._Fetch(L); -} +static constexpr struct luaL_Reg lua_pg_result_methods [] = { + {"fetch", PgResultClass::WrapMethod<&PgResult::Fetch>()}, + {nullptr, nullptr} +}; inline int -PgResult::_Fetch(lua_State *L) +PgResult::Fetch(lua_State *L) { if (next_row >= result.GetRowCount()) return 0; @@ -86,7 +76,7 @@ void InitPgResult(lua_State *L) noexcept { PgResultClass::Register(L); - luaL_newlib(L, PgResult::methods); + luaL_newlib(L, lua_pg_result_methods); lua_setfield(L, -2, "__index"); lua_pop(L, 1); }