diff --git a/rpc/http_handler.go b/rpc/http_handler.go index 2bfae50..b7ac0d1 100644 --- a/rpc/http_handler.go +++ b/rpc/http_handler.go @@ -18,16 +18,17 @@ var ( ErrRpcResponseUnmarshal = errors.New("failed to unmarshal rpc response") ) -// httpHandler implements Handler interface using the HTTP protocol under the implementation -type httpHandler struct { - httpClient *http.Client - endpoint string +// HttpHandler implements Handler interface using the HTTP protocol under the implementation +type HttpHandler struct { + httpClient *http.Client + endpoint string + CustomHeaders map[string]string } -// NewHttpHandler is a constructor for httpHandler that suppose to configure http.Client +// NewHttpHandler is a constructor for HttpHandler that suppose to configure http.Client // examples of usage can be found here [Test_ConfigurableClient_GetDeploy] -func NewHttpHandler(endpoint string, client *http.Client) Handler { - return &httpHandler{ +func NewHttpHandler(endpoint string, client *http.Client) *HttpHandler { + return &HttpHandler{ httpClient: client, endpoint: endpoint, } @@ -36,7 +37,7 @@ func NewHttpHandler(endpoint string, client *http.Client) Handler { // ProcessCall operates with an external RPC server through HTTP. It builds and processes the request, // reads a response and handles errors. All logic with HTTP interaction is isolated here and can be replaced with // other (more efficient) protocols. -func (c *httpHandler) ProcessCall(ctx context.Context, params RpcRequest) (RpcResponse, error) { +func (c *HttpHandler) ProcessCall(ctx context.Context, params RpcRequest) (RpcResponse, error) { body, err := json.Marshal(params) if err != nil { return RpcResponse{}, fmt.Errorf("%w, details: %s", ErrParamsUnmarshalHandler, err.Error()) @@ -47,6 +48,9 @@ func (c *httpHandler) ProcessCall(ctx context.Context, params RpcRequest) (RpcRe return RpcResponse{}, fmt.Errorf("%w, details: %s", ErrBuildHttpRequestHandler, err.Error()) } request.Header.Add("Content-Type", "application/json") + for name, val := range c.CustomHeaders { + request.Header.Add(name, val) + } request = request.WithContext(ctx) resp, err := c.httpClient.Do(request) diff --git a/tests/rpc/client_example_test.go b/tests/rpc/client_example_test.go index 1aa51bf..8cf709c 100644 --- a/tests/rpc/client_example_test.go +++ b/tests/rpc/client_example_test.go @@ -84,3 +84,25 @@ func Test_SpeculativeExec(t *testing.T) { require.NoError(t, err) assert.Equal(t, uint64(100000000), result.ExecutionResult.Success.Cost) } + +func Test_Client_RPCGetStatus_WithAuthorizationHeader(t *testing.T) { + authToken := "1234567890" + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + auth := req.Header.Get("Authorization") + if auth != authToken { + rw.WriteHeader(http.StatusUnauthorized) + return + } + fixture, err := os.ReadFile("../data/rpc_response/get_status.json") + require.NoError(t, err) + rw.Write(fixture) + rw.WriteHeader(http.StatusOK) + })) + handler := casper.NewRPCHandler(server.URL, http.DefaultClient) + handler.CustomHeaders = map[string]string{"Authorization": authToken} + client := casper.NewRPCClient(handler) + + status, err := client.GetStatus(context.Background()) + require.NoError(t, err) + assert.Equal(t, "1.0.0", status.APIVersion) +} diff --git a/tests/sse/example_test.go b/tests/sse/example_test.go index 46a9401..e6d0bef 100644 --- a/tests/sse/example_test.go +++ b/tests/sse/example_test.go @@ -5,14 +5,18 @@ package sse import ( "context" + "encoding/json" "errors" "fmt" + "io" "log" "net/http" + "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/make-software/casper-go-sdk/sse" ) @@ -141,3 +145,32 @@ func Test_SSE_WithConfigurations(t *testing.T) { _ = streamer _ = consumer } + +func Test_Client_WithAuthorizationHeader(t *testing.T) { + authToken := "1234567890" + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + auth := request.Header.Get("Authorization") + if auth != authToken { + writer.WriteHeader(http.StatusUnauthorized) + return + } + _, err := writer.Write(json.RawMessage(`data: {"ApiVersion":"1.0.0"}`)) + require.NoError(t, err) + })) + + client := sse.NewClient(server.URL) + client.Streamer.Connection.Headers = map[string]string{"Authorization": authToken} + ctx, cancel := context.WithCancel(context.Background()) + client.RegisterHandler(sse.APIVersionEventType, func(ctx context.Context, event sse.RawEvent) error { + data, err := event.ParseAsAPIVersionEvent() + require.NoError(t, err) + assert.Equal(t, "1.0.0", data.APIVersion) + cancel() + return nil + }) + err := client.Start(context.Background(), 123) + if err != io.EOF { + require.NoError(t, err) + } + <-ctx.Done() +}