From a476140a3a454eeedc0d95d7b78e20e693f7aca6 Mon Sep 17 00:00:00 2001 From: Brian Yeh Date: Tue, 19 Sep 2023 17:35:07 -0700 Subject: [PATCH] use data structure instead of route matching to register endpoints. --- nakama/main.go | 130 ++++++++++++++++++++++++++----------------------- 1 file changed, 69 insertions(+), 61 deletions(-) diff --git a/nakama/main.go b/nakama/main.go index 92400fb..91a280f 100644 --- a/nakama/main.go +++ b/nakama/main.go @@ -10,7 +10,6 @@ import ( "io" "net/http" "os" - "regexp" "sync" "github.com/heroiclabs/nakama-common/api" @@ -274,77 +273,86 @@ func handleShowPersona(ctx context.Context, logger runtime.Logger, db *sql.DB, n // initCardinalEndpoints queries the cardinal server to find the list of existing endpoints, and attempts to // set up RPC wrappers around each one. func initCardinalEndpoints(logger runtime.Logger, initializer runtime.Initializer) error { - endpoints, err := cardinalListAllEndpoints() + //endpoints, err := cardinalListAllEndpoints() + + endpointStruct, err := cardinalGetEndpointsStruct() if err != nil { - return fmt.Errorf("failed to get list of cardinal endpoints: %w", err) - } - matchQueryRoutes := regexp.MustCompile("^query/*") - for _, e := range endpoints { - logger.Debug("registering: %v", e) - currEndpoint := e - if currEndpoint[0] == '/' { - currEndpoint = currEndpoint[1:] + return err + } + + txEndpoints := endpointStruct.TxEndpoints + queryEndpoints := endpointStruct.QueryEndpoints + + createPayloadSigned := func(payload string, endpoint string, nk runtime.NakamaModule, ctx context.Context) (io.Reader, error) { + logger.Debug("The %s endpoint requires a signed payload", endpoint) + signedPayload, err := makeSignedPayload(ctx, nk, payload) + if err != nil { + return nil, err } - err := initializer.RegisterRpc(currEndpoint, func(ctx context.Context, logger runtime.Logger, db *sql.DB, nk runtime.NakamaModule, payload string) (string, error) { - logger.Debug("Got request for %q", currEndpoint) - var resultPayload io.Reader - if !matchQueryRoutes.MatchString(currEndpoint) { //queries are not signed. - logger.Debug("The %s endpoint requires a signed payload", currEndpoint) - signedPayload, err := makeSignedPayload(ctx, nk, payload) + return signedPayload, nil + } + + createUnsignedPayload := func(payload string, endpoint string, _ runtime.NakamaModule, _ context.Context) (io.Reader, error) { + payloadBytes := []byte(payload) + formattedPayloadBuffer := bytes.NewBuffer([]byte{}) + if !json.Valid(payloadBytes) { + return nil, fmt.Errorf("data %q is not valid json", string(payloadBytes)) + } + err = json.Compact(formattedPayloadBuffer, payloadBytes) + if err != nil { + return nil, err + } + return formattedPayloadBuffer, nil + } + + registerEndpoints := func(endpoints []string, createPayload func(string, string, runtime.NakamaModule, context.Context) (io.Reader, error)) error { + for _, e := range endpoints { + logger.Debug("registering: %v", e) + currEndpoint := e + if currEndpoint[0] == '/' { + currEndpoint = currEndpoint[1:] + } + err := initializer.RegisterRpc(currEndpoint, func(ctx context.Context, logger runtime.Logger, db *sql.DB, nk runtime.NakamaModule, payload string) (string, error) { + logger.Debug("Got request for %q", currEndpoint) + var resultPayload io.Reader + resultPayload, err = createPayload(payload, currEndpoint, nk, ctx) if err != nil { - return logError(logger, "unable to make signed payload: %v", err) + return logError(logger, "unable to make payload: %v", err) } - resultPayload = signedPayload - } else { - logger.Debug("The %s endpoint requires an unsigned payload", currEndpoint) - // Make sure the given string/[]byte is valid json - payloadBytes := []byte(payload) - if !json.Valid(payloadBytes) { - return "", fmt.Errorf("data %q is not valid json", string(payloadBytes)) + + req, err := http.NewRequestWithContext(ctx, "POST", makeURL(currEndpoint), resultPayload) + req.Header.Set("Content-Type", "application/json") + if err != nil { + return logError(logger, "request setup failed for endpoint %q: %v", currEndpoint, err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return logError(logger, "request failed for endpoint %q: %v", currEndpoint, err) } - // Unmarshal, then marshal the data to normalize it. For example, extra spaces will be removed. - // This is required because when the signed payload is serialized/deserialized those spaces will also - // be lost. If they are not removed beforehand, the hashes of the message before serialization and after - // will be different. - // update on the above comment 9/18/2023: - // This json normalization logic was extracted from signed payload. The payload is no longer signed in this branch - // of logic but keeping the normalization here in case there's any crypto stuff down the line. - m := map[string]any{} - if err := json.Unmarshal(payloadBytes, &m); err != nil { - return "", err + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + return logError(logger, "bad status code: %v: %s", resp.Status, body) } - formattedPayloadBytes, err := json.Marshal(m) + str, err := io.ReadAll(resp.Body) if err != nil { - return "", err + return logError(logger, "can't read body: %v", err) } - resultPayload = bytes.NewReader(formattedPayloadBytes) - } + return string(str), nil + }) if err != nil { - return logError(logger, "unable to make signed payload: %v", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", makeURL(currEndpoint), resultPayload) - req.Header.Set("Content-Type", "application/json") - if err != nil { - return logError(logger, "request setup failed for endpoint %q: %v", currEndpoint, err) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return logError(logger, "request failed for endpoint %q: %v", currEndpoint, err) - } - if resp.StatusCode != 200 { - body, _ := io.ReadAll(resp.Body) - return logError(logger, "bad status code: %v: %s", resp.Status, body) - } - str, err := io.ReadAll(resp.Body) - if err != nil { - return logError(logger, "can't read body: %v", err) + return err } - return string(str), nil - }) - if err != nil { - return err } + return nil + } + + err = registerEndpoints(txEndpoints, createPayloadSigned) + if err != nil { + return err + } + err = registerEndpoints(queryEndpoints, createUnsignedPayload) + if err != nil { + return err } return nil }