diff --git a/apps/api_accounts/lib/api_accounts/key.ex b/apps/api_accounts/lib/api_accounts/key.ex index 8ff20452..b0559d73 100644 --- a/apps/api_accounts/lib/api_accounts/key.ex +++ b/apps/api_accounts/lib/api_accounts/key.ex @@ -18,6 +18,8 @@ defmodule ApiAccounts.Key do field(:requested_date, :datetime) field(:approved, :boolean, default: false) field(:locked, :boolean, default: false) + field(:static_concurrent_limit, :integer) + field(:streaming_concurrent_limit, :integer) field(:daily_limit, :integer) field(:rate_request_pending, :boolean, default: false) field(:api_version, :string) @@ -28,7 +30,7 @@ defmodule ApiAccounts.Key do @doc false def changeset(struct, params \\ %{}) do fields = ~w( - created requested_date approved locked daily_limit rate_request_pending api_version description allowed_domains + created requested_date approved locked static_concurrent_limit streaming_concurrent_limit daily_limit rate_request_pending api_version description allowed_domains )a cast(struct, params, fields) end diff --git a/apps/api_web/config/config.exs b/apps/api_web/config/config.exs index 0758942f..7b63f9c4 100644 --- a/apps/api_web/config/config.exs +++ b/apps/api_web/config/config.exs @@ -23,6 +23,13 @@ config :api_web, ApiWeb.Endpoint, config :api_web, :signing_salt, "NdisAeo6Jf02spiKqa" +config :api_web, RateLimiter.Memcache, + connection_opts: [ + namespace: "api_dev_rate_limit", + hostname: "localhost", + coder: Memcache.Coder.JSON + ] + config :api_web, :rate_limiter, clear_interval: 60_000, limiter: ApiWeb.RateLimiter.ETS, @@ -30,6 +37,19 @@ config :api_web, :rate_limiter, max_registered_per_interval: 100_000, wait_time_ms: 0 +config :api_web, :rate_limiter_concurrent, + enabled: false, + memcache: false, + log_statistics: true, + limit_users: false, + # How many seconds tolerated when calculating whether a connection is still open + # 45 - 30 (see ApiWeb.EventStream.Initialize's timeout value) gives us a buffer of 15 seconds: + heartbeat_tolerance: 45, + # Default concurrent connections - these can be overridden on a per-key basis in the admin UI: + max_anon_static: 5, + max_registered_streaming: 10, + max_registered_static: 20 + config :api_web, ApiWeb.Plugs.ModifiedSinceHandler, check_caller: false config :api_web, :api_pipeline, diff --git a/apps/api_web/config/dev.exs b/apps/api_web/config/dev.exs index 98a75d8e..39d6dbe7 100644 --- a/apps/api_web/config/dev.exs +++ b/apps/api_web/config/dev.exs @@ -28,3 +28,5 @@ config :logger, :console, format: "[$level] $message\n", level: :debug # Do not configure such in production as keeping # and calculating stacktraces is usually expensive. config :phoenix, :stacktrace_depth, 20 + +config :api_web, :rate_limiter_concurrent, enabled: false, memcache: false diff --git a/apps/api_web/config/prod.exs b/apps/api_web/config/prod.exs index 3b4948d1..44f0f16d 100644 --- a/apps/api_web/config/prod.exs +++ b/apps/api_web/config/prod.exs @@ -54,6 +54,10 @@ config :ehmon, :report_mf, {:ehmon, :info_report} config :logster, :filter_parameters, ~w(password password_confirm) +config :api_web, :rate_limiter_concurrent, + enabled: true, + memcache: true + config :api_web, :rate_limiter, clear_interval: 60_000, max_anon_per_interval: 20, diff --git a/apps/api_web/config/test.exs b/apps/api_web/config/test.exs index e1ceb7c1..88af1335 100644 --- a/apps/api_web/config/test.exs +++ b/apps/api_web/config/test.exs @@ -13,7 +13,8 @@ config :api_web, :rate_limiter, config :api_web, RateLimiter.Memcache, connection_opts: [ namespace: "api_test_rate_limit", - hostname: "localhost" + hostname: "localhost", + coder: Memcache.Coder.JSON ] config :api_web, ApiWeb.Plugs.ModifiedSinceHandler, check_caller: true @@ -26,3 +27,7 @@ config :recaptcha, # Print only warnings and errors during test config :logger, level: :warn + +config :api_web, :rate_limiter_concurrent, + enabled: false, + memcache: false diff --git a/apps/api_web/lib/api_web.ex b/apps/api_web/lib/api_web.ex index f2e48450..54ec13a8 100644 --- a/apps/api_web/lib/api_web.ex +++ b/apps/api_web/lib/api_web.ex @@ -15,7 +15,9 @@ defmodule ApiWeb do # no cover children = [ # Start the endpoint when the application starts + ApiWeb.RateLimiter.Memcache.Supervisor, ApiWeb.RateLimiter, + ApiWeb.RateLimiter.RateLimiterConcurrent, {RequestTrack, [name: ApiWeb.RequestTrack]}, ApiWeb.EventStream.Supervisor, ApiWeb.Endpoint, diff --git a/apps/api_web/lib/api_web/api_controller_helpers.ex b/apps/api_web/lib/api_web/api_controller_helpers.ex index 874fc8dc..52085c55 100644 --- a/apps/api_web/lib/api_web/api_controller_helpers.ex +++ b/apps/api_web/lib/api_web/api_controller_helpers.ex @@ -27,6 +27,7 @@ defmodule ApiWeb.ApiControllerHelpers do plug(:split_include) plug(ApiWeb.Plugs.ModifiedSinceHandler, caller: __MODULE__) plug(ApiWeb.Plugs.RateLimiter) + plug(ApiWeb.Plugs.RateLimiterConcurrent) def index(conn, params), do: ApiControllerHelpers.index(__MODULE__, conn, params) diff --git a/apps/api_web/lib/api_web/event_stream.ex b/apps/api_web/lib/api_web/event_stream.ex index fcb090ce..91d7439f 100644 --- a/apps/api_web/lib/api_web/event_stream.ex +++ b/apps/api_web/lib/api_web/event_stream.ex @@ -7,6 +7,7 @@ defmodule ApiWeb.EventStream do import Plug.Conn alias __MODULE__.Supervisor alias ApiWeb.Plugs.CheckForShutdown + alias ApiWeb.RateLimiter.RateLimiterConcurrent require Logger @enforce_keys [:conn, :pid, :timeout] @@ -53,6 +54,13 @@ defmodule ApiWeb.EventStream do @spec hibernate_loop(state) :: Plug.Conn.t() def hibernate_loop(state) do + if Map.has_key?(state.conn.assigns, :api_user) do + # Update the concurrent rate limit cache to ensure any flushing doesn't impact long-running connections: + RateLimiterConcurrent.add_lock(state.conn.assigns.api_user, self(), true) + else + Logger.warn("#{__MODULE__} missing_api_user - cannot rate limit!") + end + case receive_result(state) do {:continue, state} -> :proc_lib.hibernate(__MODULE__, :hibernate_loop, [state]) @@ -130,6 +138,13 @@ defmodule ApiWeb.EventStream do end defp unsubscribe(state) do + if Map.has_key?(state.conn.assigns, :api_user) do + # clean up our concurrent connections lock: + RateLimiterConcurrent.remove_lock(state.conn.assigns.api_user, self(), true) + else + Logger.warn("#{__MODULE__} missing_api_user - cannot rate limit!") + end + # consume any extra messages received after unsubscribing receive do {:events, _} -> diff --git a/apps/api_web/lib/api_web/plugs/rate_limiter_concurrent.ex b/apps/api_web/lib/api_web/plugs/rate_limiter_concurrent.ex new file mode 100644 index 00000000..f99e89d0 --- /dev/null +++ b/apps/api_web/lib/api_web/plugs/rate_limiter_concurrent.ex @@ -0,0 +1,70 @@ +defmodule ApiWeb.Plugs.RateLimiterConcurrent do + @moduledoc """ + Plug to invoke the concurrent rate limiter. + """ + + import Plug.Conn + import Phoenix.Controller, only: [render: 3, put_view: 2] + + require Logger + + alias ApiWeb.RateLimiter.RateLimiterConcurrent + + @rate_limit_concurrent_config Application.compile_env!(:api_web, :rate_limiter_concurrent) + + def init(opts), do: opts + + def call(conn, _opts) do + if enabled?() do + event_stream? = Plug.Conn.get_req_header(conn, "accept") == ["text/event-stream"] + + {at_limit?, remaining, limit} = + RateLimiterConcurrent.check_concurrent_rate_limit(conn.assigns.api_user, event_stream?) + + if log_statistics?() do + Logger.info( + "ApiWeb.Plugs.RateLimiterConcurrent event=request_statistics api_user=#{conn.assigns.api_user.id} at_limit=#{at_limit?} remaining=#{remaining - 1} limit=#{limit} event_stream=#{event_stream?}" + ) + end + + # Allow negative limits to allow unlimited use: + if limit_users?() and limit >= 0 and at_limit? do + conn + |> put_concurrent_rate_limit_headers(limit, remaining) + |> put_status(429) + |> put_view(ApiWeb.ErrorView) + |> render("429.json-api", []) + |> halt() + else + RateLimiterConcurrent.add_lock(conn.assigns.api_user, self(), event_stream?) + + conn + |> put_concurrent_rate_limit_headers(limit, remaining - 1) + |> register_before_send(fn conn -> + RateLimiterConcurrent.remove_lock(conn.assigns.api_user, self(), event_stream?) + conn + end) + end + else + conn + end + end + + defp put_concurrent_rate_limit_headers(conn, limit, remaining) do + conn + |> put_resp_header("x-concurrent-ratelimit-limit", "#{limit}") + |> put_resp_header("x-concurrent-ratelimit-remaining", "#{remaining}") + end + + def enabled? do + Keyword.fetch!(@rate_limit_concurrent_config, :enabled) + end + + def limit_users? do + Keyword.fetch!(@rate_limit_concurrent_config, :limit_users) + end + + def log_statistics? do + Keyword.fetch!(@rate_limit_concurrent_config, :log_statistics) + end +end diff --git a/apps/api_web/lib/api_web/rate_limiter/memcache.ex b/apps/api_web/lib/api_web/rate_limiter/memcache.ex index 86278bc2..9413f70b 100644 --- a/apps/api_web/lib/api_web/rate_limiter/memcache.ex +++ b/apps/api_web/lib/api_web/rate_limiter/memcache.ex @@ -4,16 +4,16 @@ defmodule ApiWeb.RateLimiter.Memcache do """ @behaviour ApiWeb.RateLimiter.Limiter alias ApiWeb.RateLimiter.Memcache.Supervisor + use GenServer @impl ApiWeb.RateLimiter.Limiter def start_link(opts) do - clear_interval_ms = Keyword.fetch!(opts, :clear_interval) - clear_interval = div(clear_interval_ms, 1000) - - connection_opts = - [ttl: clear_interval * 2] ++ ApiWeb.config(RateLimiter.Memcache, :connection_opts) + GenServer.start_link(__MODULE__, opts, name: __MODULE__) + end - Supervisor.start_link(connection_opts) + @impl true + def init(opts) do + {:ok, opts} end @impl ApiWeb.RateLimiter.Limiter diff --git a/apps/api_web/lib/api_web/rate_limiter/memcache/supervisor.ex b/apps/api_web/lib/api_web/rate_limiter/memcache/supervisor.ex index 36da061a..9a137ca6 100644 --- a/apps/api_web/lib/api_web/rate_limiter/memcache/supervisor.ex +++ b/apps/api_web/lib/api_web/rate_limiter/memcache/supervisor.ex @@ -4,16 +4,34 @@ defmodule ApiWeb.RateLimiter.Memcache.Supervisor do """ @worker_count 5 @registry_name __MODULE__.Registry + @rate_limit_config Application.compile_env!(:api_web, :rate_limiter) - def start_link(connection_opts) do + use Agent + + def start_link(_) do registry = {Registry, keys: :unique, name: @registry_name} - workers = - for i <- 1..@worker_count do - Supervisor.child_spec({Memcache, [connection_opts, [name: worker_name(i)]]}, id: i) - end + children = + if memcache_required?() do + clear_interval_ms = Keyword.fetch!(@rate_limit_config, :clear_interval) + clear_interval = div(clear_interval_ms, 1000) + + connection_opts_config = + :api_web + |> Application.fetch_env!(RateLimiter.Memcache) + |> Keyword.fetch!(:connection_opts) + + connection_opts = [ttl: clear_interval * 2] ++ connection_opts_config - children = [registry | workers] + workers = + for i <- 1..@worker_count do + Supervisor.child_spec({Memcache, [connection_opts, [name: worker_name(i)]]}, id: i) + end + + [registry | workers] + else + [registry] + end Supervisor.start_link( children, @@ -31,7 +49,13 @@ defmodule ApiWeb.RateLimiter.Memcache.Supervisor do {:via, Registry, {@registry_name, index}} end - defp random_child do + defp memcache_required? do + (ApiWeb.RateLimiter.RateLimiterConcurrent.enabled?() and + ApiWeb.RateLimiter.RateLimiterConcurrent.memcache?()) or + ApiWeb.config(:rate_limiter, :limiter) == ApiWeb.RateLimiter.Memcache + end + + def random_child do worker_name(:rand.uniform(@worker_count)) end end diff --git a/apps/api_web/lib/api_web/rate_limiter/rate_limiter_concurrent.ex b/apps/api_web/lib/api_web/rate_limiter/rate_limiter_concurrent.ex new file mode 100644 index 00000000..42bbffad --- /dev/null +++ b/apps/api_web/lib/api_web/rate_limiter/rate_limiter_concurrent.ex @@ -0,0 +1,162 @@ +defmodule ApiWeb.RateLimiter.RateLimiterConcurrent do + @moduledoc """ + Rate limits a user's concurrent connections based on their API key or by their IP address if no + API key is provided. Split by static and event-stream requests. + """ + + use GenServer + require Logger + alias ApiWeb.RateLimiter.Memcache.Supervisor + + @rate_limit_concurrent_config Application.compile_env!(:api_web, :rate_limiter_concurrent) + @uuid_key "ApiWeb.RateLimiter.RateLimiterConcurrent_uuid" + def start_link([]), do: GenServer.start_link(__MODULE__, nil, name: __MODULE__) + + def init(_) do + uuid = UUID.uuid1() + :persistent_term.put(@uuid_key, uuid) + {:ok, %{uuid: uuid}} + end + + defp lookup(%ApiWeb.User{} = user, event_stream?) do + type = if event_stream?, do: "event_stream", else: "static" + "concurrent_#{user.id}_#{type}" + end + + def get_pid_key(pid) do + sub_key = pid |> :erlang.pid_to_list() |> to_string + get_uuid() <> sub_key + end + + defp get_uuid do + :persistent_term.get(@uuid_key) + end + + defp get_current_unix_ts do + System.system_time(:second) + end + + defp get_heartbeat_tolerance do + Keyword.fetch!(@rate_limit_concurrent_config, :heartbeat_tolerance) + end + + def mutate_locks(%ApiWeb.User{} = user, event_stream?, before_commit \\ fn value -> value end) do + if enabled?() do + current_timestamp = get_current_unix_ts() + heartbeat_tolerance = get_heartbeat_tolerance() + key = lookup(user, event_stream?) + + memcache_update(key, %{}, fn locks -> + valid_locks = + :maps.filter( + fn _, timestamp -> + timestamp + heartbeat_tolerance >= current_timestamp + end, + locks + ) + + before_commit.(valid_locks) + end) + else + {:ok, %{}} + end + end + + @spec check_concurrent_rate_limit(ApiWeb.User.t(), boolean()) :: + {false, number(), number()} | {true, number(), number()} + def check_concurrent_rate_limit(user, event_stream?) do + {:ok, locks} = user |> mutate_locks(event_stream?) + active_connections = locks |> Map.keys() |> length + + limit = + case {event_stream?, user.type} do + {true, :registered} -> + if user.streaming_concurrent_limit >= 0, + do: + max( + user.streaming_concurrent_limit || 0, + Keyword.fetch!( + @rate_limit_concurrent_config, + :max_registered_streaming + ) + ), + else: user.streaming_concurrent_limit + + {false, :registered} -> + if user.static_concurrent_limit >= 0, + do: + max( + user.static_concurrent_limit || 0, + Keyword.fetch!( + @rate_limit_concurrent_config, + :max_registered_static + ) + ), + else: user.static_concurrent_limit + + {false, :anon} -> + Keyword.fetch!( + @rate_limit_concurrent_config, + :max_anon_static + ) + end + + remaining = limit - active_connections + at_limit? = remaining <= 0 + {at_limit?, remaining, limit} + end + + def add_lock(%ApiWeb.User{} = user, pid, event_stream?) do + if enabled?() do + key = lookup(user, event_stream?) + pid_key = get_pid_key(pid) + timestamp = get_current_unix_ts() + + Logger.info( + "#{__MODULE__} event=add_lock user=#{inspect(user)} pid_key=#{pid_key} key=#{key} timestamp=#{timestamp}" + ) + + locks = + user |> mutate_locks(event_stream?, fn locks -> Map.put(locks, pid_key, timestamp) end) + + Logger.info( + "#{__MODULE__} event=add_lock_after user=#{inspect(user)} pid_key=#{pid_key} key=#{key} timestamp=#{timestamp} locks=#{inspect(locks)}" + ) + end + + nil + end + + def remove_lock( + %ApiWeb.User{} = user, + pid, + event_stream?, + pid_key \\ nil + ) do + if enabled?() and memcache?() do + key = lookup(user, event_stream?) + pid_key = if pid_key, do: pid_key, else: get_pid_key(pid) + + {:ok, _locks} = + mutate_locks(user, event_stream?, fn locks -> Map.delete(locks, pid_key) end) + + Logger.info( + "#{__MODULE__} event=remove_lock user_id=#{user.id} pid_key=#{pid_key} key=#{key}" + ) + end + + nil + end + + def enabled? do + Keyword.fetch!(@rate_limit_concurrent_config, :enabled) + end + + def memcache? do + Keyword.fetch!(@rate_limit_concurrent_config, :memcache) + end + + def memcache_update(key, default_value, update_fn) do + Memcache.cas(Supervisor.random_child(), key, update_fn, default: default_value) + end +end diff --git a/apps/api_web/lib/api_web/templates/admin/accounts/key/form.html.heex b/apps/api_web/lib/api_web/templates/admin/accounts/key/form.html.heex index 74f5562d..ea2700b0 100644 --- a/apps/api_web/lib/api_web/templates/admin/accounts/key/form.html.heex +++ b/apps/api_web/lib/api_web/templates/admin/accounts/key/form.html.heex @@ -18,6 +18,18 @@ <%= error_tag(f, :daily_limit) %> +
+ <%= label(f, :static_concurrent_limit, class: "control-label") %> + <%= text_input(f, :static_concurrent_limit, class: "form-control") %> + <%= error_tag(f, :static_concurrent_limit) %> +
+ +
+ <%= label(f, :streaming_concurrent_limit, class: "control-label") %> + <%= text_input(f, :streaming_concurrent_limit, class: "form-control") %> + <%= error_tag(f, :streaming_concurrent_limit) %> +
+
<%= label(f, :description, class: "control-label") %> <%= text_input(f, :description, class: "form-control") %> diff --git a/apps/api_web/lib/api_web/user.ex b/apps/api_web/lib/api_web/user.ex index 4db1183c..46c620f9 100644 --- a/apps/api_web/lib/api_web/user.ex +++ b/apps/api_web/lib/api_web/user.ex @@ -9,6 +9,8 @@ defmodule ApiWeb.User do :id, :type, :limit, + :static_concurrent_limit, + :streaming_concurrent_limit, :version, :allowed_domains ] @@ -36,6 +38,16 @@ defmodule ApiWeb.User do """ @type requests_per_day :: integer + @typedoc """ + The max number of event-stream requests that the user can make at once. + """ + @type streaming_concurrent_limit :: integer + + @typedoc """ + The max number of static requests that the user can make at once. + """ + @type static_concurrent_limit :: integer + @typedoc """ Whether the user is an anonymous user or a registered user. @@ -103,6 +115,8 @@ defmodule ApiWeb.User do def from_key(%ApiAccounts.Key{ key: key, daily_limit: limit, + static_concurrent_limit: static_concurrent_limit, + streaming_concurrent_limit: streaming_concurrent_limit, api_version: version, allowed_domains: allowed_domains }) do @@ -111,6 +125,8 @@ defmodule ApiWeb.User do %__MODULE__{ id: key, limit: limit, + static_concurrent_limit: static_concurrent_limit, + streaming_concurrent_limit: streaming_concurrent_limit, type: :registered, version: version, allowed_domains: nil_or_allowed_domains(allowed_domains) diff --git a/apps/api_web/mix.exs b/apps/api_web/mix.exs index dd893d50..a7ee47dc 100644 --- a/apps/api_web/mix.exs +++ b/apps/api_web/mix.exs @@ -85,7 +85,8 @@ defmodule ApiWeb.Mixfile do {:recaptcha, git: "https://github.com/samueljseay/recaptcha.git", tag: "71cd746"}, {:sentry, "~> 8.0"}, {:qr_code, "~> 3.0"}, - {:nimble_totp, "~> 1.0"} + {:nimble_totp, "~> 1.0"}, + {:uuid, "~> 1.1"} ] end end diff --git a/apps/api_web/test/api_web/rate_limiter_concurrent_test.exs b/apps/api_web/test/api_web/rate_limiter_concurrent_test.exs new file mode 100644 index 00000000..8b786405 --- /dev/null +++ b/apps/api_web/test/api_web/rate_limiter_concurrent_test.exs @@ -0,0 +1,43 @@ +defmodule ApiWeb.RateLimiterConcurrentTest do + @moduledoc false + use ExUnit.Case, async: false + use Plug.Test + alias ApiWeb.RateLimiter.RateLimiterConcurrent + + test "start_link/1" do + Application.stop(:api_web) + + on_exit(fn -> + Application.start(:api_web) + end) + + assert {:ok, _pid} = RateLimiterConcurrent.start_link([]) + end + + test "check_concurrent_rate_limit/1" do + {anon_static_at_limit?, anon_static_remaining, anon_static_limit} = + RateLimiterConcurrent.check_concurrent_rate_limit(%ApiWeb.User{type: :anon}, false) + + assert anon_static_limit == ApiWeb.config(:rate_limiter_concurrent, :max_anon_static) + assert anon_static_remaining == anon_static_limit + assert anon_static_at_limit? == false + + {registered_streaming_at_limit?, registered_streaming_remaining, registered_streaming_limit} = + RateLimiterConcurrent.check_concurrent_rate_limit(%ApiWeb.User{type: :registered}, true) + + assert registered_streaming_limit == + ApiWeb.config(:rate_limiter_concurrent, :max_registered_streaming) + + assert registered_streaming_remaining == registered_streaming_limit + assert registered_streaming_at_limit? == false + + {registered_static_at_limit?, registered_static_remaining, registered_static_limit} = + RateLimiterConcurrent.check_concurrent_rate_limit(%ApiWeb.User{type: :registered}, false) + + assert registered_static_limit == + ApiWeb.config(:rate_limiter_concurrent, :max_registered_static) + + assert registered_static_remaining == registered_static_limit + assert registered_static_at_limit? == false + end +end diff --git a/config/runtime.exs b/config/runtime.exs index 680367c7..f9d16464 100644 --- a/config/runtime.exs +++ b/config/runtime.exs @@ -41,7 +41,8 @@ if is_prod? and is_release? do config :api_web, RateLimiter.Memcache, connection_opts: [ namespace: System.fetch_env!("HOST"), - hostname: System.fetch_env!("MEMCACHED_HOST") + hostname: System.fetch_env!("MEMCACHED_HOST"), + coder: Memcache.Coder.JSON ] config :state_mediator, Realtime,