Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/go_modules/golang.org/x/sync-0.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
ramin authored Jan 23, 2024
2 parents 1e33c75 + c173a1e commit ec47ff7
Show file tree
Hide file tree
Showing 31 changed files with 476 additions and 255 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/go-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ jobs:
needs: [lint, go_mod_tidy_check]
name: Run Unit Tests with Race Detector
runs-on: ubuntu-latest
continue-on-error: true

steps:
- uses: actions/checkout@v4
Expand All @@ -124,6 +125,7 @@ jobs:
needs: [lint, go_mod_tidy_check]
name: Run Integration Tests
runs-on: ubuntu-latest
continue-on-error: true

steps:
- uses: actions/checkout@v4
Expand Down
73 changes: 73 additions & 0 deletions api/gateway/bindings.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package gateway

import (
"fmt"
"net/http"
)

func (h *Handler) RegisterEndpoints(rpc *Server) {
// state endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey),
h.handleBalanceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
submitTxEndpoint,
h.handleSubmitTx,
http.MethodPost,
)

rpc.RegisterHandlerFunc(
healthEndpoint,
h.handleHealthRequest,
http.MethodGet,
)

// share endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf(
"%s/{%s}/height/{%s}",
namespacedSharesEndpoint,
namespaceKey,
heightKey,
),
h.handleSharesByNamespaceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey),
h.handleSharesByNamespaceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey),
h.handleDataByNamespaceRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey),
h.handleDataByNamespaceRequest,
http.MethodGet,
)

// DAS endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", heightAvailabilityEndpoint, heightKey),
h.handleHeightAvailabilityRequest,
http.MethodGet,
)

// header endpoints
rpc.RegisterHandlerFunc(
fmt.Sprintf("%s/{%s}", headerByHeightEndpoint, heightKey),
h.handleHeaderRequest,
http.MethodGet,
)

rpc.RegisterHandlerFunc(headEndpoint, h.handleHeadRequest, http.MethodGet)
}
119 changes: 119 additions & 0 deletions api/gateway/bindings_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package gateway

import (
"fmt"
"net/http"
"testing"

"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
)

func TestRegisterEndpoints(t *testing.T) {
handler := &Handler{}
rpc := NewServer("localhost", "6969")

handler.RegisterEndpoints(rpc)

testCases := []struct {
name string
path string
method string
expected bool
}{
{
name: "Get balance endpoint",
path: fmt.Sprintf("%s/{%s}", balanceEndpoint, addrKey),
method: http.MethodGet,
expected: true,
},
{
name: "Submit transaction endpoint",
path: submitTxEndpoint,
method: http.MethodPost,
expected: true,
},
{
name: "Get namespaced shares by height endpoint",
path: fmt.Sprintf("%s/{%s}/height/{%s}", namespacedSharesEndpoint, namespaceKey, heightKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get namespaced shares endpoint",
path: fmt.Sprintf("%s/{%s}", namespacedSharesEndpoint, namespaceKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get namespaced data by height endpoint",
path: fmt.Sprintf("%s/{%s}/height/{%s}", namespacedDataEndpoint, namespaceKey, heightKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get namespaced data endpoint",
path: fmt.Sprintf("%s/{%s}", namespacedDataEndpoint, namespaceKey),
method: http.MethodGet,
expected: true,
},
{
name: "Get health endpoint",
path: "/status/health",
method: http.MethodGet,
expected: true,
},

// Going forward, we can add previously deprecated and since
// removed endpoints here to ensure we don't accidentally re-enable
// them in the future and accidentally expand surface area
{
name: "example totally bogus endpoint",
path: fmt.Sprintf("/wutang/{%s}/%s", "chambers", "36"),
method: http.MethodGet,
expected: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
require.Equal(
t,
tc.expected,
hasEndpointRegistered(rpc.Router(), tc.path, tc.method),
"Endpoint registration mismatch for: %s %s %s", tc.name, tc.method, tc.path)
})
}
}

func hasEndpointRegistered(router *mux.Router, path string, method string) bool {
var registered bool
err := router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
template, err := route.GetPathTemplate()
if err != nil {
return err
}

if template == path {
methods, err := route.GetMethods()
if err != nil {
return err
}

for _, m := range methods {
if m == method {
registered = true
return nil
}
}
}
return nil
})

if err != nil {
fmt.Println("Error walking through routes:", err)
return false
}

return registered
}
32 changes: 0 additions & 32 deletions api/gateway/endpoints.go

This file was deleted.

16 changes: 16 additions & 0 deletions api/gateway/health.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package gateway

import "net/http"

const (
healthEndpoint = "/status/health"
)

func (h *Handler) handleHealthRequest(w http.ResponseWriter, _ *http.Request) {
_, err := w.Write([]byte("ok"))
if err != nil {
log.Errorw("serving request", "endpoint", healthEndpoint, "err", err)
writeError(w, http.StatusBadGateway, healthEndpoint, err)
return
}
}
20 changes: 0 additions & 20 deletions api/gateway/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@ package gateway

import (
"context"
"errors"
"net/http"
"time"

"github.com/gorilla/mux"

"github.com/celestiaorg/celestia-node/nodebuilder/state"
)

const timeout = time.Minute

func (h *Handler) RegisterMiddleware(srv *Server) {
srv.RegisterMiddleware(
setContentType,
checkPostDisabled(h.state),
wrapRequestContext,
enableCors,
)
Expand All @@ -36,20 +30,6 @@ func setContentType(next http.Handler) http.Handler {
})
}

// checkPostDisabled ensures that context was canceled and prohibit POST requests.
func checkPostDisabled(state state.Module) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check if state service was halted and deny the transaction
if r.Method == http.MethodPost && state.IsStopped(r.Context()) {
writeError(w, http.StatusMethodNotAllowed, r.URL.Path, errors.New("not possible to submit data"))
return
}
next.ServeHTTP(w, r)
})
}
}

// wrapRequestContext ensures we implement a deadline on serving requests
// via the gateway server-side to prevent context leaks.
func wrapRequestContext(next http.Handler) http.Handler {
Expand Down
4 changes: 4 additions & 0 deletions api/gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func NewServer(address, port string) *Server {
return server
}

func (s *Server) Router() *mux.Router {
return s.srvMux
}

// Start starts the gateway Server, listening on the given address.
func (s *Server) Start(context.Context) error {
couldStart := s.started.CompareAndSwap(false, true)
Expand Down
16 changes: 7 additions & 9 deletions api/gateway/util.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
package gateway

import (
"encoding/json"
"net/http"
)

func writeError(w http.ResponseWriter, statusCode int, endpoint string, err error) {
log.Debugw("serving request", "endpoint", endpoint, "err", err)

w.WriteHeader(statusCode)
errBody, jerr := json.Marshal(err.Error())
if jerr != nil {
log.Errorw("serializing error", "endpoint", endpoint, "err", jerr)
return
}
_, werr := w.Write(errBody)
if werr != nil {
log.Errorw("writing error response", "endpoint", endpoint, "err", werr)

errorMessage := err.Error() // Get the error message as a string
errorBytes := []byte(errorMessage)

_, err = w.Write(errorBytes)
if err != nil {
log.Errorw("writing error response", "endpoint", endpoint, "err", err)
}
}
24 changes: 24 additions & 0 deletions api/gateway/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package gateway

import (
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestWriteError(t *testing.T) {
t.Run("writeError", func(t *testing.T) {
// Create a mock HTTP response writer
w := httptest.NewRecorder()

testErr := errors.New("test error")

writeError(w, http.StatusInternalServerError, "/api/endpoint", testErr)
assert.Equal(t, http.StatusInternalServerError, w.Code)
responseBody := w.Body.Bytes()
assert.Equal(t, testErr.Error(), string(responseBody))
})
}
9 changes: 6 additions & 3 deletions api/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,15 @@ func TestAuthedRPC(t *testing.T) {
// 2. Test method with write-level permissions
expectedResp := &state.TxResponse{}
if tt.perm > 2 {
server.State.EXPECT().SubmitTx(gomock.Any(), gomock.Any()).Return(expectedResp, nil)
txResp, err := rpcClient.State.SubmitTx(ctx, []byte{})
server.State.EXPECT().Delegate(gomock.Any(), gomock.Any(),
gomock.Any(), gomock.Any(), gomock.Any()).Return(expectedResp, nil)
txResp, err := rpcClient.State.Delegate(ctx,
state.ValAddress{}, state.Int{}, state.Int{}, 0)
require.NoError(t, err)
require.Equal(t, expectedResp, txResp)
} else {
_, err := rpcClient.State.SubmitTx(ctx, []byte{})
_, err := rpcClient.State.Delegate(ctx,
state.ValAddress{}, state.Int{}, state.Int{}, 0)
require.Error(t, err)
require.ErrorContains(t, err, "missing permission")
}
Expand Down
Loading

0 comments on commit ec47ff7

Please sign in to comment.