Skip to content

Commit

Permalink
ft: add struct for keeping state of ClientHandler
Browse files Browse the repository at this point in the history
This is used to allow Elixir compiler to provide minuscule optimisation,
as with `%{} = data` it will no longer need to check if `data` is an
atom when doing `data.foo`. It also makes it harder to accidentally
access non-existing field as new Elixir versions can check accessed
fields names.
  • Loading branch information
hauleth committed Jun 27, 2024
1 parent 18ad803 commit 943a76a
Showing 1 changed file with 70 additions and 53 deletions.
123 changes: 70 additions & 53 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,31 @@ defmodule Supavisor.ClientHandler do
@behaviour :ranch_protocol
@behaviour :partisan_gen_statem

defstruct [
:id,
:sock,
:trans,
:db_pid,
:tenant,
:user,
:pool,
:manager,
:query_start,
:timeout,
:ps,
:ssl,
:auth_secrets,
:proxy_type,
:mode,
:stats,
:idle_timeout,
:db_name,
:last_query,
:heartbeat_interval,
:connection_start,
:log_level
]

alias Supavisor, as: S
alias Supavisor.DbHandler, as: Db
alias Supavisor.Helpers, as: H
Expand Down Expand Up @@ -45,7 +70,7 @@ defmodule Supavisor.ClientHandler do
:ok = trans.setopts(sock, active: true)
Logger.debug("ClientHandler is: #{inspect(self())}")

data = %{
data = %__MODULE__{
id: nil,
sock: {:gen_tcp, sock},
trans: trans,
Expand Down Expand Up @@ -74,7 +99,7 @@ defmodule Supavisor.ClientHandler do
end

@impl true
def handle_event(:info, {_proto, _, <<"GET", _::binary>>}, :exchange, data) do
def handle_event(:info, {_proto, _, <<"GET", _::binary>>}, :exchange, %__MODULE__{} = data) do
Logger.debug("ClientHandler: Client is trying to request HTTP")

HH.sock_send(
Expand All @@ -93,7 +118,7 @@ defmodule Supavisor.ClientHandler do
end

# send cancel request to db
def handle_event(:info, :cancel_query, :busy, data) do
def handle_event(:info, :cancel_query, :busy, %__MODULE__{} = data) do
key = {data.tenant, data.db_pid}
Logger.debug("ClientHandler: Cancel query for #{inspect(key)}")
{_pool, db_pid} = data.db_pid
Expand All @@ -111,7 +136,7 @@ defmodule Supavisor.ClientHandler do
:keep_state_and_data
end

def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, %{sock: sock} = data) do
def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, %__MODULE__{sock: sock} = data) do
Logger.debug("ClientHandler: Client is trying to connect with SSL")

downstream_cert = H.downstream_cert()
Expand All @@ -131,15 +156,15 @@ defmodule Supavisor.ClientHandler do
{:ok, ssl_sock} ->
socket = {:ssl, ssl_sock}
:ok = HH.setopts(socket, active: true)
{:keep_state, %{data | sock: socket, ssl: true}}
{:keep_state, %__MODULE__{data | sock: socket, ssl: true}}

error ->
Logger.error("ClientHandler: SSL handshake error: #{inspect(error)}")
Telem.client_join(:fail, data.id)
{:stop, {:shutdown, :ssl_handshake_error}}
end
else
Logger.error(
Logger.warning(
"ClientHandler: User requested SSL connection but no downstream cert/key found"
)

Expand All @@ -148,7 +173,7 @@ defmodule Supavisor.ClientHandler do
end
end

def handle_event(:info, {_, _, bin}, :exchange, data) do
def handle_event(:info, {_, _, bin}, :exchange, %__MODULE__{} = data) do
case Server.decode_startup_packet(bin) do
{:ok, hello} ->
Logger.debug("ClientHandler: Client startup message: #{inspect(hello)}")
Expand Down Expand Up @@ -186,7 +211,7 @@ defmodule Supavisor.ClientHandler do
:internal,
{:hello, {type, {user, tenant_or_alias, db_name}}},
:exchange,
%{sock: sock} = data
%__MODULE__{sock: sock} = data
) do
sni_hostname = HH.try_get_sni(sock)

Expand Down Expand Up @@ -274,7 +299,7 @@ defmodule Supavisor.ClientHandler do
:internal,
{:handle, {method, secrets}, info},
_,
%{sock: sock} = data
%__MODULE__{sock: sock} = data
) do
Logger.debug("ClientHandler: Handle exchange, auth method: #{inspect(method)}")

Expand Down Expand Up @@ -341,7 +366,7 @@ defmodule Supavisor.ClientHandler do
end
end

def handle_event(:internal, :subscribe, _, data) do
def handle_event(:internal, :subscribe, _, %__MODULE__{} = data) do
Logger.debug("ClientHandler: Subscribe to tenant #{inspect(data.id)}")

with {:ok, sup} <-
Expand Down Expand Up @@ -374,7 +399,7 @@ defmodule Supavisor.ClientHandler do
end
end

def handle_event(:internal, {:greetings, ps}, _, %{sock: sock} = data) do
def handle_event(:internal, {:greetings, ps}, _, %__MODULE__{sock: sock} = data) do
{header, <<pid::32, key::32>> = payload} = Server.backend_key_data()
msg = [ps, [header, payload], Server.ready_for_query()]
:ok = HH.listen_cancel_query(pid, key)
Expand All @@ -387,7 +412,7 @@ defmodule Supavisor.ClientHandler do
{:keep_state_and_data, {:next_event, :internal, :subscribe}}
end

def handle_event(:timeout, :wait_ps, _, data) do
def handle_event(:timeout, :wait_ps, _, %__MODULE__{} = data) do
Logger.error(
"ClientHandler: Wait parameter status timeout, send default #{inspect(data.ps)}}"
)
Expand All @@ -396,12 +421,12 @@ defmodule Supavisor.ClientHandler do
{:keep_state_and_data, {:next_event, :internal, {:greetings, ps}}}
end

def handle_event(:timeout, :idle_terminate, _, data) do
def handle_event(:timeout, :idle_terminate, _, %__MODULE__{} = data) do
Logger.warning("ClientHandler: Terminate an idle connection by #{data.idle_timeout} timeout")
{:stop, {:shutdown, :idle_terminate}}
end

def handle_event(:timeout, :heartbeat_check, _, data) do
def handle_event(:timeout, :heartbeat_check, _, %__MODULE__{} = data) do
Logger.debug("ClientHandler: Send heartbeat to client")
HH.sock_send(data.sock, Server.application_name())
{:keep_state_and_data, {:timeout, data.heartbeat_interval, :heartbeat_check}}
Expand All @@ -415,22 +440,22 @@ defmodule Supavisor.ClientHandler do
end

# handle Sync message
def handle_event(:info, {proto, _, <<?S, 4::32>>}, :idle, data)
def handle_event(:info, {proto, _, <<?S, 4::32>>}, :idle, %__MODULE__{} = data)
when proto in [:tcp, :ssl] do
Logger.debug("ClientHandler: Receive sync")
:ok = HH.sock_send(data.sock, Server.ready_for_query())
{:keep_state_and_data, handle_actions(data)}
end

def handle_event(:info, {proto, _, <<?S, 4::32, _::binary>> = msg}, _, data)
def handle_event(:info, {proto, _, <<?S, 4::32, _::binary>> = msg}, _, %__MODULE__{} = data)
when proto in [:tcp, :ssl] do
Logger.debug("ClientHandler: Receive sync while not idle")
{_, db_pid} = data.db_pid
Db.cast(db_pid, self(), msg)
:keep_state_and_data
end

def handle_event(:info, {proto, _, <<?H, 4::32, _::binary>> = msg}, _, data)
def handle_event(:info, {proto, _, <<?H, 4::32, _::binary>> = msg}, _, %__MODULE__{} = data)
when proto in [:tcp, :ssl] do
Logger.debug("ClientHandler: Receive flush while not idle")
{_, db_pid} = data.db_pid
Expand All @@ -439,7 +464,7 @@ defmodule Supavisor.ClientHandler do
end

# incoming query with a single pool
def handle_event(:info, {proto, _, bin}, :idle, %{pool: pid} = data)
def handle_event(:info, {proto, _, bin}, :idle, %__MODULE__{pool: pid} = data)
when is_binary(bin) and is_pid(pid) do
ts = System.monotonic_time()
db_pid = db_checkout(:both, :on_query, data)
Expand All @@ -450,7 +475,7 @@ defmodule Supavisor.ClientHandler do
end

# incoming query with read/write pools
def handle_event(:info, {proto, _, bin}, :idle, data) do
def handle_event(:info, {proto, _, bin}, :idle, %__MODULE__{} = data) do
query_type =
with {:ok, payload} <- Client.get_payload(bin),
{:ok, statements} <- Supavisor.PgParser.statements(payload) do
Expand All @@ -477,7 +502,7 @@ defmodule Supavisor.ClientHandler do
end

# forward query to db
def handle_event(_, {proto, _, bin}, :busy, data)
def handle_event(_, {proto, _, bin}, :busy, %__MODULE__{} = data)
when proto in [:tcp, :ssl] do
{_, db_pid} = data.db_pid

Expand Down Expand Up @@ -517,21 +542,21 @@ defmodule Supavisor.ClientHandler do
end

# client closed connection
def handle_event(_, {closed, _}, _, data)
def handle_event(_, {closed, _}, _, %__MODULE__{} = data)
when closed in [:tcp_closed, :ssl_closed] do
Logger.debug("ClientHandler: #{closed} socket closed for #{inspect(data.tenant)}")
{:stop, {:shutdown, :socket_closed}}
end

# linked DbHandler went down
def handle_event(:info, {:EXIT, db_pid, reason}, _, data) do
def handle_event(:info, {:EXIT, db_pid, reason}, _, %__MODULE__{} = data) do
Logger.error("ClientHandler: DbHandler #{inspect(db_pid)} exited #{inspect(reason)}")
HH.sock_send(data.sock, Server.error_message("XX000", "DbHandler exited"))
{:stop, {:shutdown, :db_handler_exit}}
end

# pool's manager went down
def handle_event(:info, {:DOWN, _, _, _, reason}, state, data) do
def handle_event(:info, {:DOWN, _, _, _, reason}, state, %__MODULE__{} = data) do
Logger.error(
"ClientHandler: Manager #{inspect(data.manager)} went down #{inspect(reason)} state #{inspect(state)}"
)
Expand All @@ -554,21 +579,21 @@ defmodule Supavisor.ClientHandler do
end

# emulate handle_cast
def handle_event(:cast, {:client_cast, bin, status}, _, data) do
Logger.debug("ClientHandler: --> --> bin #{inspect(byte_size(bin))} bytes")
def handle_event(:cast, {:client_cast, bin, status}, _, %__MODULE__{} = data) do
Logger.debug("ClientHandler: --> --> bin #{byte_size(bin)} bytes")

case status do
:ready_for_query ->
Logger.debug("ClientHandler: Client is ready")

db_pid = handle_db_pid(data.mode, data.pool, data.db_pid)

{_, stats} = Telem.network_usage(:client, data.sock, data.id, data.stats)
# {_, stats} = Telem.network_usage(:client, data.sock, data.id, data.stats)

Telem.client_query_time(data.query_start, data.id)
:ok = HH.sock_send(data.sock, bin)
actions = handle_actions(data)
{:next_state, :idle, %{data | db_pid: db_pid, stats: stats}, actions}
{:next_state, :idle, %__MODULE__{data | db_pid: db_pid}, actions}

:continue ->
Logger.debug("ClientHandler: Client is not ready")
Expand All @@ -590,8 +615,8 @@ defmodule Supavisor.ClientHandler do
end

# emulate handle_call
def handle_event({:call, from}, {:client_call, bin, _}, _, data) do
Logger.debug("ClientHandler: --> --> bin call #{inspect(byte_size(bin))} bytes")
def handle_event({:call, from}, {:client_call, bin, _}, _, %__MODULE__{} = data) do
Logger.debug("ClientHandler: --> --> bin call #{byte_size(bin)} bytes")
{:keep_state_and_data, {:reply, from, HH.sock_send(data.sock, bin)}}
end

Expand All @@ -612,7 +637,7 @@ defmodule Supavisor.ClientHandler do
def terminate(
{:timeout, {_, _, [_, {:checkout, _, _}, _]}},
_,
data
%__MODULE__{} = data
) do
msg =
case data.mode do
Expand Down Expand Up @@ -789,15 +814,15 @@ defmodule Supavisor.ClientHandler do

defp handle_db_pid(:session, _, db_pid), do: db_pid

defp update_user_data(data, info, user, id, db_name, mode) do
defp update_user_data(%__MODULE__{} = data, info, user, id, db_name, mode) do
proxy_type =
if info.tenant.require_user do
:password
else
:auth_query
end

%{
%__MODULE__{
data
| tenant: info.tenant.external_id,
user: user,
Expand Down Expand Up @@ -931,11 +956,6 @@ defmodule Supavisor.ClientHandler do

def try_get_sni(_), do: nil

@spec timeout_check(atom, non_neg_integer) :: {:timeout, non_neg_integer, atom}
defp timeout_check(key, timeout) do
{:timeout, timeout, key}
end

defp db_pid_meta({_, {_, pid}} = _key) do
rkey = Supavisor.Registry.PoolPids
fnode = node(pid)
Expand All @@ -948,10 +968,10 @@ defmodule Supavisor.ClientHandler do
end

@spec handle_prepared_statements({pid, pid}, binary, map) :: :ok | nil
defp handle_prepared_statements({_, pid}, bin, %{mode: :transaction} = data) do
defp handle_prepared_statements({_, pid}, bin, %__MODULE__{mode: :transaction} = data) do
with {:ok, payload} <- Client.get_payload(bin),
{:ok, statamets} <- Supavisor.PgParser.statements(payload),
true <- Enum.member?([["PrepareStmt"], ["DeallocateStmt"]], statamets) do
{:ok, statements} when statements in [["PrepareStmt"], ["DeallocateStmt"]] <-
Supavisor.PgParser.statements(payload) do
Logger.info("ClientHandler: Handle prepared statement #{inspect(payload)}")

GenServer.call(data.pool, :get_all_workers)
Expand All @@ -976,18 +996,15 @@ defmodule Supavisor.ClientHandler do
defp handle_prepared_statements(_, _, _), do: nil

@spec handle_actions(map) :: [{:timeout, non_neg_integer, atom}]
defp handle_actions(data) do
Enum.flat_map(data, fn
{:heartbeat_interval, v} = t when v > 0 ->
Logger.debug("ClientHandler: Call timeout #{inspect(t)}")
[timeout_check(:heartbeat_check, v)]

{:idle_timeout, v} = t when v > 0 ->
Logger.debug("ClientHandler: Call timeout #{inspect(t)}")
[timeout_check(:idle_terminate, v)]

_ ->
[]
end)
defp handle_actions(%__MODULE__{} = data) do
heartbeat =
if data.heartbeat_interval > 0,
do: [{:timeout, data.heartbeat_interval, :heartbeat_check}],
else: []

idle =
if data.idle_timeout > 0, do: [{:timeout, data.idle_timeout, :idle_timeout}], else: []

heartbeat ++ idle
end
end

0 comments on commit 943a76a

Please sign in to comment.