Skip to content

Commit

Permalink
minor refactoring to lapis.db.mysql, use metatable to inherit base db
Browse files Browse the repository at this point in the history
  • Loading branch information
leafo committed Feb 15, 2023
1 parent eb5cbc1 commit af1f4f0
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 51 deletions.
51 changes: 27 additions & 24 deletions lapis/db/mysql.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ end
local concat
concat = table.concat
local unpack = unpack or table.unpack
local FALSE, NULL, TRUE, build_helpers, format_date, is_raw, raw, is_list, list, is_encodable, clause, is_clause
local NULL, build_helpers, is_raw, is_list
do
local _obj_0 = require("lapis.db.base")
FALSE, NULL, TRUE, build_helpers, format_date, is_raw, raw, is_list, list, is_encodable, clause, is_clause = _obj_0.FALSE, _obj_0.NULL, _obj_0.TRUE, _obj_0.build_helpers, _obj_0.format_date, _obj_0.is_raw, _obj_0.raw, _obj_0.is_list, _obj_0.list, _obj_0.is_encodable, _obj_0.clause, _obj_0.is_clause
NULL, build_helpers, is_raw, is_list = _obj_0.NULL, _obj_0.build_helpers, _obj_0.is_raw, _obj_0.is_list
end
local logger = require("lapis.logging")
local conn
local BACKENDS, set_raw_query, get_raw_query, escape_literal, escape_identifier, connect, raw_query, interpolate_query, encode_values, encode_assigns, encode_clause, append_all, add_cond, query, _select, _insert, _update, _delete, _truncate
BACKENDS = {
local active_connection
local connect, raw_query
local BACKENDS = {
luasql = function()
local config = require("lapis.config").get()
local mysql_config = assert(config.mysql, "missing mysql configuration")
Expand All @@ -30,14 +30,14 @@ BACKENDS = {
table.insert(conn_opts, mysql_config.port)
end
end
conn = assert(luasql:connect(unpack(conn_opts)))
active_connection = assert(luasql:connect(unpack(conn_opts)))
return function(q)
logger.query(q)
local cur = assert(conn:execute(q))
local cur = assert(active_connection:execute(q))
local has_rows = type(cur) ~= "number"
local result = {
affected_rows = has_rows and cur:numrows() or cur,
last_auto_id = conn:getlastautoid()
last_auto_id = active_connection:getlastautoid()
}
if has_rows then
local colnames = cur:getcolnames()
Expand Down Expand Up @@ -156,19 +156,22 @@ BACKENDS = {
end
end
}
local set_raw_query
set_raw_query = function(fn)
raw_query = fn
end
local get_raw_query
get_raw_query = function()
return raw_query
end
local escape_literal
escape_literal = function(val)
local _exp_0 = type(val)
if "number" == _exp_0 then
return tostring(val)
elseif "string" == _exp_0 then
if conn then
return "'" .. tostring(conn:escape(val)) .. "'"
if active_connection then
return "'" .. tostring(active_connection:escape(val)) .. "'"
else
if ngx then
return ngx.quote_sql_str(val)
Expand Down Expand Up @@ -206,6 +209,7 @@ escape_literal = function(val)
end
return error("don't know how to escape value: " .. tostring(val))
end
local escape_identifier
escape_identifier = function(ident)
if is_raw(ident) then
return ident[1]
Expand Down Expand Up @@ -234,12 +238,14 @@ raw_query = function(...)
connect()
return raw_query(...)
end
interpolate_query, encode_values, encode_assigns, encode_clause = build_helpers(escape_literal, escape_identifier)
local interpolate_query, encode_values, encode_assigns, encode_clause = build_helpers(escape_literal, escape_identifier)
local append_all
append_all = function(t, ...)
for i = 1, select("#", ...) do
t[#t + 1] = select(i, ...)
end
end
local add_cond
add_cond = function(buffer, cond, ...)
append_all(buffer, " WHERE ")
local _exp_0 = type(cond)
Expand All @@ -249,15 +255,18 @@ add_cond = function(buffer, cond, ...)
return append_all(buffer, interpolate_query(cond, ...))
end
end
local query
query = function(str, ...)
if select("#", ...) > 0 then
str = interpolate_query(str, ...)
end
return raw_query(str)
end
local _select
_select = function(str, ...)
return query("SELECT " .. str, ...)
end
local _insert
_insert = function(tbl, values, ...)
local buff = {
"INSERT INTO ",
Expand All @@ -267,6 +276,7 @@ _insert = function(tbl, values, ...)
encode_values(values, buff)
return raw_query(concat(buff))
end
local _update
_update = function(table, values, cond, ...)
local buff = {
"UPDATE ",
Expand All @@ -279,6 +289,7 @@ _update = function(table, values, cond, ...)
end
return raw_query(concat(buff))
end
local _delete
_delete = function(table, cond, ...)
local buff = {
"DELETE FROM ",
Expand All @@ -289,30 +300,20 @@ _delete = function(table, cond, ...)
end
return raw_query(concat(buff))
end
local _truncate
_truncate = function(table)
return raw_query("TRUNCATE " .. escape_identifier(table))
end
return {
return setmetatable({
__type = "mysql",
connect = connect,
NULL = NULL,
TRUE = TRUE,
FALSE = FALSE,
raw = raw,
is_raw = is_raw,
list = list,
is_list = is_list,
clause = clause,
is_clause = is_clause,
is_encodable = is_encodable,
encode_values = encode_values,
encode_assigns = encode_assigns,
encode_clause = encode_clause,
interpolate_query = interpolate_query,
query = query,
escape_literal = escape_literal,
escape_identifier = escape_identifier,
format_date = format_date,
set_raw_query = set_raw_query,
get_raw_query = get_raw_query,
parse_clause = function()
Expand All @@ -324,4 +325,6 @@ return {
delete = _delete,
truncate = _truncate,
BACKENDS = BACKENDS
}
}, {
__index = require("lapis.db.base")
})
48 changes: 23 additions & 25 deletions lapis/db/mysql.moon
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,19 @@ import concat from table
unpack = unpack or table.unpack

import
FALSE
NULL
TRUE
build_helpers
format_date
is_raw
raw
is_list
list
is_encodable
clause
is_clause
from require "lapis.db.base"

logger = require "lapis.logging"

local conn
local *
-- NOTE: active connection only stored in local with luasql, otherwise request
-- context is used to store connection
local active_connection

local connect, raw_query

BACKENDS = {
luasql: ->
Expand All @@ -35,16 +30,20 @@ BACKENDS = {
table.insert conn_opts, mysql_config.host
if mysql_config.port then table.insert conn_opts, mysql_config.port

conn = assert luasql\connect unpack(conn_opts)
-- Note that connection is established up front. This is
-- necessary since connection is used for escaping literal when
-- using lua sql. This is distinct from ngx mode which lazily
-- establishes connection on first query
active_connection = assert luasql\connect unpack(conn_opts)

(q) ->
logger.query q
cur = assert conn\execute q
cur = assert active_connection\execute q
has_rows = type(cur) != "number"

result = {
affected_rows: has_rows and cur\numrows! or cur
last_auto_id: conn\getlastautoid!
last_auto_id: active_connection\getlastautoid!
}

if has_rows
Expand Down Expand Up @@ -152,8 +151,8 @@ escape_literal = (val) ->
when "number"
return tostring val
when "string"
if conn
return "'#{conn\escape val}'"
if active_connection
return "'#{active_connection\escape val}'"
else if ngx
return ngx.quote_sql_str(val)
else
Expand Down Expand Up @@ -268,17 +267,17 @@ _truncate = (table) ->
--
-- }

{
setmetatable {
__type: "mysql"

:connect
:NULL, :TRUE, :FALSE

:raw, :is_raw
:list, :is_list
:clause, :is_clause

:is_encodable
-- :NULL, :TRUE, :FALSE
-- :raw, :is_raw
-- :list, :is_list
-- :clause, :is_clause
-- :format_date
-- :is_encodable

:encode_values
:encode_assigns
Expand All @@ -289,8 +288,6 @@ _truncate = (table) ->
:escape_literal
:escape_identifier

:format_date

:set_raw_query
:get_raw_query

Expand All @@ -303,4 +300,5 @@ _truncate = (table) ->
truncate: _truncate

:BACKENDS
}
}, __index: require "lapis.db.base"

4 changes: 2 additions & 2 deletions lapis/spec/db.moon
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

-- NOTE: do not require lapis.db, etc. on the top level as it will try to bind
-- to the connection type to the closure.
-- NOTE: do not require config dependent modules on the top level here, eg.
-- lapis.db

import assert_env from require "lapis.environment"

Expand Down
4 changes: 4 additions & 0 deletions spec/mysql_spec.moon
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ value_table = { hello: db.FALSE, age: 34 }

TESTS = {
-- lapis.db.mysql
{
-> db.format_date 0
"1970-01-01 00:00:00"
}
{
-> db.escape_identifier "dad"
'`dad`'
Expand Down
4 changes: 4 additions & 0 deletions spec/postgres_spec.moon
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import sorted_pairs from require "spec.helpers"

TESTS = {
-- lapis.db.postgres
{
-> db.format_date 0
"1970-01-01 00:00:00"
}
{
-> db.escape_identifier "dad"
'"dad"'
Expand Down

0 comments on commit af1f4f0

Please sign in to comment.