diff --git a/src/mochiweb_http.erl b/src/mochiweb_http.erl index 568019ff..e7f38ff2 100644 --- a/src/mochiweb_http.erl +++ b/src/mochiweb_http.erl @@ -23,7 +23,7 @@ -module(mochiweb_http). -author('bob@mochimedia.com'). --export([start/1, start_link/1, stop/0, stop/1]). +-export([start/1, start_link/1, stop/0, stop/1, stop/2]). -export([loop/3]). -export([after_response/2, reentry/1]). -export([parse_range_request/1, range_skip_length/2]). @@ -53,6 +53,9 @@ stop() -> stop(Name) -> mochiweb_socket_server:stop(Name). +stop(Name, Timeout) -> + mochiweb_socket_server:stop(Name, Timeout). + %% @spec start(Options) -> ServerRet %% Options = [option()] %% Option = {name, atom()} | {ip, string() | tuple()} | {backlog, integer()} diff --git a/src/mochiweb_socket_server.erl b/src/mochiweb_socket_server.erl index fd5e3824..36687c63 100644 --- a/src/mochiweb_socket_server.erl +++ b/src/mochiweb_socket_server.erl @@ -9,7 +9,7 @@ -include("internal.hrl"). --export([start/1, start_link/1, stop/1]). +-export([start/1, start_link/1, stop/1, stop/2]). -export([init/1, handle_call/3, handle_cast/2, terminate/2, code_change/3, handle_info/2]). -export([get/2, set/3]). @@ -29,7 +29,8 @@ ssl=false, ssl_opts=[{ssl_imp, new}], acceptor_pool=sets:new(), - profile_fun=undefined}). + profile_fun=undefined, + shutdown_notify_pid=undefined}). -define(is_old_state(State), not is_record(State, mochiweb_socket_server)). @@ -60,12 +61,15 @@ set(Name, Property, _Value) -> [Name, Property]). stop(Name) when is_atom(Name) orelse is_pid(Name) -> - gen_server:call(Name, stop); + gen_server:call(Name, stop); stop({Scope, Name}) when Scope =:= local orelse Scope =:= global -> stop(Name); stop(Options) -> State = parse_options(Options), stop(State#mochiweb_socket_server.name). +stop(Name, Timeout) when is_atom(Name) orelse is_pid(Name) andalso is_integer(Timeout) -> + gen_server:call(Name, prep_stop, Timeout), + gen_server:call(Name, stop). %% Internal API @@ -145,7 +149,9 @@ parse_options([{ssl_opts, SslOpts} | Rest], State) when is_list(SslOpts) -> SslOpts1 = [{ssl_imp, new} | proplists:delete(ssl_imp, SslOpts)], parse_options(Rest, State#mochiweb_socket_server{ssl_opts=SslOpts1}); parse_options([{profile_fun, ProfileFun} | Rest], State) when is_function(ProfileFun) -> - parse_options(Rest, State#mochiweb_socket_server{profile_fun=ProfileFun}). + parse_options(Rest, State#mochiweb_socket_server{profile_fun=ProfileFun}); +parse_options([{shutdown_notify_pid, NotifyPid} | Rest], State) when is_pid(NotifyPid) -> + parse_options(Rest, State#mochiweb_socket_server{shutdown_notify_pid=NotifyPid}). start_server(F, State=#mochiweb_socket_server{ssl=Ssl, name=Name}) -> @@ -265,6 +271,11 @@ handle_call({get, Property}, _From, State) -> {reply, Res, State}; handle_call(stop, _From, State) -> {stop, normal, ok, State}; +handle_call(prep_stop, From, State) -> + close_listen_socket(State), + State1 = State#mochiweb_socket_server{shutdown_notify_pid=From, acceptor_pool_size=0}, + % Reply will be given when active_socket count goes to 0 + {noreply, State1}; handle_call(_Message, _From, State) -> Res = error, {reply, Res, State}. @@ -294,7 +305,10 @@ handle_cast({set, profile_fun, ProfileFun}, State) -> terminate(Reason, State) when ?is_old_state(State) -> terminate(Reason, upgrade_state(State)); -terminate(_Reason, #mochiweb_socket_server{listen=Listen}) -> +terminate(_Reason, State) -> + close_listen_socket(State). + +close_listen_socket(#mochiweb_socket_server{listen=Listen}) -> mochiweb_socket:close(Listen). code_change(_OldVsn, State, _Extra) -> @@ -304,7 +318,8 @@ recycle_acceptor(Pid, State=#mochiweb_socket_server{ acceptor_pool=Pool, acceptor_pool_size=PoolSize, max=Max, - active_sockets=ActiveSockets}) -> + active_sockets=ActiveSockets, + shutdown_notify_pid=NotifyPid}) -> %% A socket is considered to be active from immediately after it %% has been accepted (see the {accepted, Pid, Timing} cast above). %% This function will be called when an acceptor is transitioning @@ -322,6 +337,12 @@ recycle_acceptor(Pid, State=#mochiweb_socket_server{ State1 = State#mochiweb_socket_server{ acceptor_pool=Pool1, active_sockets=ActiveSockets1}, + case NotifyPid of + undefined -> ok; + _ -> if ActiveSockets1 =< 0 -> gen_server:reply(NotifyPid, ok); + true -> error_logger:info_msg("~p clients outstanding",[ActiveSockets1]) + end + end, %% Spawn a new acceptor only if it will not overrun the maximum socket %% count or the maximum pool size. case NewSize + ActiveSockets1 < Max andalso NewSize < PoolSize of @@ -363,8 +384,6 @@ handle_info(Info, State) -> error_logger:info_report([{'INFO', Info}, {'State', State}]), {noreply, State}. - - %% %% Tests %% @@ -388,7 +407,8 @@ upgrade_state_test() -> acceptor_pool_size=acceptor_pool_size, ssl=ssl, ssl_opts=ssl_opts, acceptor_pool=acceptor_pool, - profile_fun=undefined}, + profile_fun=undefined, + shutdown_notify_pid=undefined}, ?assertEqual(CmpState, State). -endif. diff --git a/test/mochiweb_socket_server_tests.erl b/test/mochiweb_socket_server_tests.erl index c64f5b72..c483e0ec 100644 --- a/test/mochiweb_socket_server_tests.erl +++ b/test/mochiweb_socket_server_tests.erl @@ -140,10 +140,69 @@ normal_acceptor_test_fun() -> ?assertEqual(Expected, Result) end || {Max, PoolSize, NumClients, Expected} <- Tests]. +graceful_shutdown_test_fun(ShutDownDelay) -> + Tester = self(), + NumClients = 2, + ServerOpts = [{max, NumClients}, {acceptor_pool_size, NumClients}], + ServerLoop = + fun (Socket, _Opts) -> + Tester ! {server_accepted, self()}, + mochiweb_socket:setopts(Socket, [{packet, 1}]), + echo_loop(Socket) + end, + {Server, Port} = socket_server(ServerOpts, ServerLoop), + Data = <<"data">>, + ClientCmds = [{send_pid, Tester}, {wait_msg, go}, + {send, Data, Tester}, + {close_sock}, {send_msg, done, Tester}], + start_client_conns(Port, NumClients, fun client_fun/2, ClientCmds, Tester), + + ConnectLoop = + fun (Loop, Connected, Accepted, Errors) -> + case (length(Accepted) + Errors >= NumClients + andalso length(Connected) + Errors >= NumClients) of + true -> {Connected, Accepted}; + false -> + receive + {server_accepted, ServerPid} -> + Loop(Loop, Connected, [ServerPid | Accepted], Errors); + {client, ClientPid} -> + Loop(Loop, Connected ++ [ClientPid], Accepted, Errors); + {client_conn_error, _E} -> + Loop(Loop, Connected, Accepted, Errors + 1) + end + end + end, + {Connected, _} = ConnectLoop(ConnectLoop, [], [], 0), + + spawn(mochiweb_socket_server, stop, [Server, ShutDownDelay]), + + WaitLoop = + fun (_Loop, Done, Error, []) -> + {Done, Error}; + (Loop, Done, Error, [NextClient | Rest]) -> + NextClient ! go, + receive + {done, From} -> + Loop(Loop, [From | Done], Error, Rest); + E -> + Loop(Loop, Done, [E | Error], Rest) + end + end, + + {Done, Error} = WaitLoop(WaitLoop, [], [], Connected), + ?assertEqual(NumClients, length(Done)), + ?assertEqual([], Error). + + -define(LARGE_TIMEOUT, 40). normal_acceptor_test_() -> Tests = normal_acceptor_test_fun(), {timeout, ?LARGE_TIMEOUT, Tests}. + +graceful_shutdown_test_() -> + {timeout, ?LARGE_TIMEOUT, [fun() -> graceful_shutdown_test_fun(?LARGE_TIMEOUT - 1) end]}. + -endif.