diff --git a/spec/stream_client_spec.erl b/spec/stream_client_spec.erl index 4a29c1f..31a840a 100644 --- a/spec/stream_client_spec.erl +++ b/spec/stream_client_spec.erl @@ -4,6 +4,7 @@ -define(POST_ENDPOINT, {post, "http://localhost:4567/"}). -define(GET_ENDPOINT, {get, "http://localhost:4567/"}). -define(CONTENT_TYPE, "application/x-www-form-urlencoded"). +-define(TEST_AUTH, {basic, ["", ""]}). spec() -> describe("stream client", fun() -> @@ -25,7 +26,7 @@ spec() -> it("should return http errors", fun() -> meck:expect(httpc, request, fun(_, _, _, _) -> {error, something_went_wrong} end), - Result = stream_client:connect(?POST_ENDPOINT, [], "", self()), + Result = stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", self()), Expected = {error, {http_error, something_went_wrong}}, ?assertEqual(Expected, Result) @@ -39,7 +40,7 @@ spec() -> end ), - stream_client:connect(?POST_ENDPOINT, [], "", self()) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct url", fun() -> @@ -52,21 +53,7 @@ spec() -> end ), - stream_client:connect(?POST_ENDPOINT, [], "", self()) - end), - - it("should use the correct headers", fun() -> - Headers = [a, b, c], - - meck:expect(httpc, request, - fun(_, Args, _, _) -> - {_, PassedHeaders, _, _} = Args, - ?assertEqual(Headers, PassedHeaders), - {error, no_continue} % what the client expects - end - ), - - stream_client:connect(?POST_ENDPOINT, Headers, "", self()) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct content type", fun() -> @@ -78,7 +65,7 @@ spec() -> end ), - stream_client:connect(?POST_ENDPOINT, [], "", self()) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct params", fun() -> @@ -92,7 +79,7 @@ spec() -> end ), - stream_client:connect(?POST_ENDPOINT, [], PostData, self()) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, PostData, self()) end), @@ -105,7 +92,7 @@ spec() -> end ), - stream_client:connect(?POST_ENDPOINT, [], "", self()) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct http client arguments for streaming", fun() -> @@ -116,7 +103,7 @@ spec() -> end ), - stream_client:connect(?POST_ENDPOINT, [], "", self()) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", self()) end) end), @@ -124,7 +111,7 @@ spec() -> it("should return http errors", fun() -> meck:expect(httpc, request, fun(_, _, _, _) -> {error, something_went_wrong} end), - Result = stream_client:connect(?GET_ENDPOINT, [], "", self()), + Result = stream_client:connect(?GET_ENDPOINT, ?TEST_AUTH, "", self()), Expected = {error, {http_error, something_went_wrong}}, ?assertEqual(Expected, Result) @@ -138,7 +125,7 @@ spec() -> end ), - stream_client:connect(?GET_ENDPOINT, [], "", self()) + stream_client:connect(?GET_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct url", fun() -> @@ -151,21 +138,7 @@ spec() -> end ), - stream_client:connect(?GET_ENDPOINT, [], "", self()) - end), - - it("should use the correct headers", fun() -> - Headers = [a, b, c], - - meck:expect(httpc, request, - fun(_, Args, _, _) -> - {_, PassedHeaders} = Args, - ?assertEqual(Headers, PassedHeaders), - {error, no_continue} % what the client expects - end - ), - - stream_client:connect(?GET_ENDPOINT, Headers, "", self()) + stream_client:connect(?GET_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct params", fun() -> @@ -180,7 +153,7 @@ spec() -> end ), - stream_client:connect(?GET_ENDPOINT, [], PostData, self()) + stream_client:connect(?GET_ENDPOINT, ?TEST_AUTH, PostData, self()) end), @@ -193,7 +166,7 @@ spec() -> end ), - stream_client:connect(?GET_ENDPOINT, [], "", self()) + stream_client:connect(?GET_ENDPOINT, ?TEST_AUTH, "", self()) end), it("should use the correct http client arguments for streaming", fun() -> @@ -204,7 +177,7 @@ spec() -> end ), - stream_client:connect(?GET_ENDPOINT, [], "", self()) + stream_client:connect(?GET_ENDPOINT, ?TEST_AUTH, "", self()) end) end) end), @@ -234,7 +207,7 @@ spec() -> {ok, test} end), - stream_client:connect(?POST_ENDPOINT, [], "", Callback) + stream_client:connect(?POST_ENDPOINT, ?TEST_AUTH, "", Callback) end) end) end), diff --git a/spec/stream_manager_spec.erl b/spec/stream_manager_spec.erl index 2f6a82e..f85f348 100644 --- a/spec/stream_manager_spec.erl +++ b/spec/stream_manager_spec.erl @@ -197,18 +197,18 @@ spec() -> end) end), - describe("#set_headers", fun() -> - it("sets the headers", fun() -> - Headers = [{"X-Awesome", "true"}], + describe("#set_auth", fun() -> + it("sets the auth", fun() -> + Auth = {basic, ["User1", "Pass1"]}, meck:expect(stream_client, connect, fun(_, _, _, _) -> {ok, terminate} end ), - stream_manager:set_headers(test_stream_manager, Headers), + stream_manager:set_auth(test_stream_manager, Auth), stream_manager:start_stream(test_stream_manager), - meck:wait(stream_client, connect, ['_', Headers, '_', '_'], 100) + meck:wait(stream_client, connect, ['_', Auth, '_', '_'], 100) end), it("restarts the client if connected", fun() -> @@ -232,12 +232,13 @@ spec() -> ?assert(timeout) end, - NewHeaders = [{"X-Awesome", "true"}], - stream_manager:set_headers(test_stream_manager, NewHeaders), + NewAuth = {basic, ["User2", "Pass2"]}, + stream_manager:set_auth(test_stream_manager, NewAuth), % child 1 will be terminated by the manager, and this call will % return so we can wait for it through meck - meck:wait(stream_client, connect, ['_', [], '_', '_'], 100), + OldAuth = {basic, ["", ""]}, + meck:wait(stream_client, connect, ['_', OldAuth, '_', '_'], 100), % starting the client happens async, we need to wait for it % to return to check it was called (meck thing) @@ -248,7 +249,7 @@ spec() -> ?assert(timeout) end, - meck:wait(stream_client, connect, ['_', NewHeaders, '_', '_'], 100), + meck:wait(stream_client, connect, ['_', NewAuth, '_', '_'], 100), % check two seperate processes were started ?assertNotEqual(Child1, Child2) diff --git a/src/stream_client.erl b/src/stream_client.erl index b65e172..0304e84 100644 --- a/src/stream_client.erl +++ b/src/stream_client.erl @@ -7,7 +7,8 @@ -define(CONTENT_TYPE, "application/x-www-form-urlencoded"). -spec connect(string(), list(), string(), fun()) -> ok | {error, reason}. -connect({post, Url}, Headers, Params, Callback) -> +connect({post, Url}, Auth, Params, Callback) -> + Headers = stream_client_util:headers_for_auth(Auth, {post, Url}, Params), case catch httpc:request(post, {Url, Headers, ?CONTENT_TYPE, Params}, [], [{sync, false}, {stream, self}]) of {ok, RequestId} -> ?MODULE:handle_connection(Callback, RequestId); @@ -16,7 +17,8 @@ connect({post, Url}, Headers, Params, Callback) -> {error, {http_error, Reason}} end; -connect({get, BaseUrl}, Headers, Params, Callback) -> +connect({get, BaseUrl}, Auth, Params, Callback) -> + Headers = stream_client_util:headers_for_auth(Auth, {get, BaseUrl}, Params), Url = case Params of "" -> BaseUrl; diff --git a/src/stream_client_util.erl b/src/stream_client_util.erl index 87fe99f..0fb5f62 100644 --- a/src/stream_client_util.erl +++ b/src/stream_client_util.erl @@ -1,5 +1,6 @@ -module(stream_client_util). -export([ + headers_for_auth/3, generate_headers/0, generate_auth_headers/2, generate_auth_headers/3, @@ -9,6 +10,11 @@ decode/1 ]). +% TODO extend for oauth +-spec headers_for_auth(term(), term(), list()) -> list(). +headers_for_auth({basic, [User, Pass]}, _Endpoint, _Params) -> + generate_auth_headers(User, Pass). + -spec generate_headers() -> list(). generate_headers() -> [ diff --git a/src/stream_manager.erl b/src/stream_manager.erl index 78aeb2d..494fc73 100644 --- a/src/stream_manager.erl +++ b/src/stream_manager.erl @@ -15,13 +15,13 @@ stop_stream/1, set_params/2, set_callback/2, - set_headers/2, + set_auth/2, status/1 ]). -record(state, { status = disconnected :: atom(), - headers = [] :: list(), + auth = {basic, ["", ""]} :: list(), params = "" :: string(), callback :: term(), client_pid :: pid() @@ -52,8 +52,8 @@ set_params(ServerRef, Params) -> set_callback(ServerRef, Callback) -> gen_server:call(ServerRef, {set_callback, Callback}). -set_headers(ServerRef, Headers) -> - gen_server:call(ServerRef, {set_headers, Headers}). +set_auth(ServerRef, Auth) -> + gen_server:call(ServerRef, {set_auth, Auth}). status(ServerRef) -> gen_server:call(ServerRef, status). @@ -121,9 +121,9 @@ handle_call({set_params, Params}, _From, State = #state{client_pid = Pid, params {reply, ok, State#state{ params = Params, client_pid = NewPid }} end; -handle_call({set_headers, Headers}, _From, State = #state{client_pid = Pid, headers = OldHeaders}) -> - case Headers of - OldHeaders -> +handle_call({set_auth, Auth}, _From, State = #state{client_pid = Pid, auth = OldAuth}) -> + case Auth of + OldAuth -> % same, don't do anything {reply, ok, State}; _ -> @@ -135,9 +135,9 @@ handle_call({set_headers, Headers}, _From, State = #state{client_pid = Pid, head _ -> % already started, restart ok = client_shutdown(State), - NewPid = client_connect(State#state{ headers = Headers }) + NewPid = client_connect(State#state{ auth = Auth }) end, - {reply, ok, State#state{ headers = Headers, client_pid = NewPid }} + {reply, ok, State#state{ auth = Auth, client_pid = NewPid }} end; handle_call({set_callback, Callback}, _From, State) -> @@ -222,7 +222,7 @@ code_change(_OldVsn, State, _Extra) -> %%% Internal functions %%-------------------------------------------------------------------- -spec client_connect(record()) -> pid(). -client_connect(#state{headers = Headers, params = Params}) -> +client_connect(#state{auth = Auth, params = Params}) -> Parent = self(), % We don't use the callback from the state, as we want to be able to change @@ -234,7 +234,7 @@ client_connect(#state{headers = Headers, params = Params}) -> Endpoint = {post, stream_client_util:filter_url()}, spawn_link(fun() -> - case stream_client:connect(Endpoint, Headers, Params, Callback) of + case stream_client:connect(Endpoint, Auth, Params, Callback) of {error, unauthorised} -> % Didn't connect, unauthorised Parent ! {self(), client_exit, unauthorised};