diff --git a/api/gateway/bindings.go b/api/gateway/bindings.go new file mode 100644 index 0000000000..c01bd2da47 --- /dev/null +++ b/api/gateway/bindings.go @@ -0,0 +1,73 @@ +package gateway + +import ( + "fmt" + "net/http" +) + +func (h *Handler) RegisterEndpoints(rpc *Server) { + // state endpoints + rpc.RegisterHandlerFunc( + fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey), + h.handleBalanceRequest, + http.MethodGet, + ) + + rpc.RegisterHandlerFunc( + submitTxEndpoint, + h.handleSubmitTx, + http.MethodPost, + ) + + rpc.RegisterHandlerFunc( + healthEndpoint, + h.handleHealthRequest, + http.MethodGet, + ) + + // share endpoints + rpc.RegisterHandlerFunc( + fmt.Sprintf( + "%s/{%s}/height/{%s}", + namespacedSharesEndpoint, + namespaceKey, + heightKey, + ), + h.handleSharesByNamespaceRequest, + http.MethodGet, + ) + + rpc.RegisterHandlerFunc( + fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey), + h.handleSharesByNamespaceRequest, + http.MethodGet, + ) + + rpc.RegisterHandlerFunc( + fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey), + h.handleDataByNamespaceRequest, + http.MethodGet, + ) + + rpc.RegisterHandlerFunc( + fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey), + h.handleDataByNamespaceRequest, + http.MethodGet, + ) + + // DAS endpoints + rpc.RegisterHandlerFunc( + fmt.Sprintf("%s/{%s}", heightAvailabilityEndpoint, heightKey), + h.handleHeightAvailabilityRequest, + http.MethodGet, + ) + + // header endpoints + rpc.RegisterHandlerFunc( + fmt.Sprintf("%s/{%s}", headerByHeightEndpoint, heightKey), + h.handleHeaderRequest, + http.MethodGet, + ) + + rpc.RegisterHandlerFunc(headEndpoint, h.handleHeadRequest, http.MethodGet) +} diff --git a/api/gateway/bindings_test.go b/api/gateway/bindings_test.go new file mode 100644 index 0000000000..5d27d5e4c7 --- /dev/null +++ b/api/gateway/bindings_test.go @@ -0,0 +1,119 @@ +package gateway + +import ( + "fmt" + "net/http" + "testing" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/require" +) + +func TestRegisterEndpoints(t *testing.T) { + handler := &Handler{} + rpc := NewServer("localhost", "6969") + + handler.RegisterEndpoints(rpc) + + testCases := []struct { + name string + path string + method string + expected bool + }{ + { + name: "Get balance endpoint", + path: fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey), + method: http.MethodGet, + expected: true, + }, + { + name: "Submit transaction endpoint", + path: submitTxEndpoint, + method: http.MethodPost, + expected: true, + }, + { + name: "Get namespaced shares by height endpoint", + path: fmt.Sprintf("%s/{%s}/height/{%s}", namespacedSharesEndpoint, namespaceKey, heightKey), + method: http.MethodGet, + expected: true, + }, + { + name: "Get namespaced shares endpoint", + path: fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey), + method: http.MethodGet, + expected: true, + }, + { + name: "Get namespaced data by height endpoint", + path: fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey), + method: http.MethodGet, + expected: true, + }, + { + name: "Get namespaced data endpoint", + path: fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey), + method: http.MethodGet, + expected: true, + }, + { + name: "Get health endpoint", + path: "/status/health", + method: http.MethodGet, + expected: true, + }, + + // Going forward, we can add previously deprecated and since + // removed endpoints here to ensure we don't accidentally re-enable + // them in the future and accidentally expand surface area + { + name: "example totally bogus endpoint", + path: fmt.Sprintf("/wutang/{%s}/%s", "chambers", "36"), + method: http.MethodGet, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal( + t, + tc.expected, + hasEndpointRegistered(rpc.Router(), tc.path, tc.method), + "Endpoint registration mismatch for: %s %s %s", tc.name, tc.method, tc.path) + }) + } +} + +func hasEndpointRegistered(router *mux.Router, path string, method string) bool { + var registered bool + err := router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { + template, err := route.GetPathTemplate() + if err != nil { + return err + } + + if template == path { + methods, err := route.GetMethods() + if err != nil { + return err + } + + for _, m := range methods { + if m == method { + registered = true + return nil + } + } + } + return nil + }) + + if err != nil { + fmt.Println("Error walking through routes:", err) + return false + } + + return registered +} diff --git a/api/gateway/endpoints.go b/api/gateway/endpoints.go deleted file mode 100644 index 104d01b053..0000000000 --- a/api/gateway/endpoints.go +++ /dev/null @@ -1,32 +0,0 @@ -package gateway - -import ( - "fmt" - "net/http" -) - -func (h *Handler) RegisterEndpoints(rpc *Server) { - // state endpoints - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey), h.handleBalanceRequest, - http.MethodGet) - rpc.RegisterHandlerFunc(submitTxEndpoint, h.handleSubmitTx, http.MethodPost) - - // share endpoints - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}/height/{%s}", namespacedSharesEndpoint, namespaceKey, heightKey), - h.handleSharesByNamespaceRequest, http.MethodGet) - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey), - h.handleSharesByNamespaceRequest, http.MethodGet) - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey), - h.handleDataByNamespaceRequest, http.MethodGet) - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey), - h.handleDataByNamespaceRequest, http.MethodGet) - - // DAS endpoints - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}", heightAvailabilityEndpoint, heightKey), - h.handleHeightAvailabilityRequest, http.MethodGet) - - // header endpoints - rpc.RegisterHandlerFunc(fmt.Sprintf("%s/{%s}", headerByHeightEndpoint, heightKey), h.handleHeaderRequest, - http.MethodGet) - rpc.RegisterHandlerFunc(headEndpoint, h.handleHeadRequest, http.MethodGet) -} diff --git a/api/gateway/health.go b/api/gateway/health.go new file mode 100644 index 0000000000..0ddb56bb67 --- /dev/null +++ b/api/gateway/health.go @@ -0,0 +1,16 @@ +package gateway + +import "net/http" + +const ( + healthEndpoint = "/status/health" +) + +func (h *Handler) handleHealthRequest(w http.ResponseWriter, _ *http.Request) { + _, err := w.Write([]byte("ok")) + if err != nil { + log.Errorw("serving request", "endpoint", healthEndpoint, "err", err) + writeError(w, http.StatusBadGateway, healthEndpoint, err) + return + } +} diff --git a/api/gateway/server.go b/api/gateway/server.go index 181bfdfe55..7eab7c7bf9 100644 --- a/api/gateway/server.go +++ b/api/gateway/server.go @@ -36,6 +36,10 @@ func NewServer(address, port string) *Server { return server } +func (s *Server) Router() *mux.Router { + return s.srvMux +} + // Start starts the gateway Server, listening on the given address. func (s *Server) Start(context.Context) error { couldStart := s.started.CompareAndSwap(false, true) diff --git a/api/gateway/util.go b/api/gateway/util.go index bffd7ebc88..d3739f9e9c 100644 --- a/api/gateway/util.go +++ b/api/gateway/util.go @@ -1,7 +1,6 @@ package gateway import ( - "encoding/json" "net/http" ) @@ -9,13 +8,12 @@ func writeError(w http.ResponseWriter, statusCode int, endpoint string, err erro log.Debugw("serving request", "endpoint", endpoint, "err", err) w.WriteHeader(statusCode) - errBody, jerr := json.Marshal(err.Error()) - if jerr != nil { - log.Errorw("serializing error", "endpoint", endpoint, "err", jerr) - return - } - _, werr := w.Write(errBody) - if werr != nil { - log.Errorw("writing error response", "endpoint", endpoint, "err", werr) + + errorMessage := err.Error() // Get the error message as a string + errorBytes := []byte(errorMessage) + + _, err = w.Write(errorBytes) + if err != nil { + log.Errorw("writing error response", "endpoint", endpoint, "err", err) } } diff --git a/api/gateway/util_test.go b/api/gateway/util_test.go new file mode 100644 index 0000000000..d41b0918d2 --- /dev/null +++ b/api/gateway/util_test.go @@ -0,0 +1,24 @@ +package gateway + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteError(t *testing.T) { + t.Run("writeError", func(t *testing.T) { + // Create a mock HTTP response writer + w := httptest.NewRecorder() + + testErr := errors.New("test error") + + writeError(w, http.StatusInternalServerError, "/api/endpoint", testErr) + assert.Equal(t, http.StatusInternalServerError, w.Code) + responseBody := w.Body.Bytes() + assert.Equal(t, testErr.Error(), string(responseBody)) + }) +} diff --git a/cmd/celestia/cmd_test.go b/cmd/celestia/cmd_test.go index 9c26489e14..94dd3625b8 100644 --- a/cmd/celestia/cmd_test.go +++ b/cmd/celestia/cmd_test.go @@ -33,7 +33,7 @@ func TestCompletionHelpString(t *testing.T) { } methods := reflect.VisibleFields(reflect.TypeOf(TestFields{})) for i, method := range methods { - require.Equal(t, testOutputs[i], parseSignatureForHelpstring(method)) + require.Equal(t, testOutputs[i], parseSignatureForHelpString(method)) } } @@ -129,7 +129,7 @@ func TestBridge(t *testing.T) { */ } -func parseSignatureForHelpstring(methodSig reflect.StructField) string { +func parseSignatureForHelpString(methodSig reflect.StructField) string { simplifiedSignature := "(" in, out := methodSig.Type.NumIn(), methodSig.Type.NumOut() for i := 1; i < in; i++ {