diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4921ce6e..5574f5a9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,7 +53,7 @@ jobs: - name: Start solver run: | ./stack solver --disable-telemetry=true --api-host="" > solver.log & - sleep 5 + sleep 20 - name: Run solver integration tests run: ./stack integration-tests-solver diff --git a/.local.dev b/.local.dev index 46b96a84..d0fab552 100644 --- a/.local.dev +++ b/.local.dev @@ -30,3 +30,4 @@ WEB3_TOKEN_ADDRESS_=0xa513E6E4b8f2a923D98304ec87F64353C4D5C853 WEB3_USERS_ADDRESS=0x0DCd1Bf9A1b36cE34237eEaFef220932846BCD82 BACALHAU_API_HOST=localhost BACALHAU_API_PORT=1234 +SERVER_RATE_EXEMPTED_IPS=127.0.0.1,::1 diff --git a/pkg/http/types.go b/pkg/http/types.go index 126165a9..b8b2db94 100644 --- a/pkg/http/types.go +++ b/pkg/http/types.go @@ -22,6 +22,7 @@ type ValidationToken struct { type RateLimiterOptions struct { RequestLimit int WindowLength int + ExemptedIPs []string } type ClientOptions struct { diff --git a/pkg/http/utils.go b/pkg/http/utils.go index 17aa2e73..9696ee86 100644 --- a/pkg/http/utils.go +++ b/pkg/http/utils.go @@ -9,6 +9,7 @@ import ( "fmt" "io" stdlog "log" + "net" "net/http" "net/url" "strings" @@ -443,3 +444,30 @@ func newRetryClient() *retryablehttp.Client { } return retryClient } + +func CanonicalizeIP(ip string) string { + isIPv6 := false + // This is how net.ParseIP decides if an address is IPv6 + // https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704 + for i := 0; !isIPv6 && i < len(ip); i++ { + switch ip[i] { + case '.': + // IPv4 + return ip + case ':': + // IPv6 + isIPv6 = true + break + } + } + if !isIPv6 { + // Not an IP address at all + return ip + } + + ipv6 := net.ParseIP(ip) + if ipv6 == nil { + return ip + } + return ipv6.Mask(net.CIDRMask(64, 128)).String() +} diff --git a/pkg/jobcreator/controller.go b/pkg/jobcreator/controller.go index 6ea21dd2..ea6b08cc 100644 --- a/pkg/jobcreator/controller.go +++ b/pkg/jobcreator/controller.go @@ -387,19 +387,18 @@ func (controller *JobCreatorController) downloadResult(dealContainer data.DealCo func (controller *JobCreatorController) acceptResult(deal data.DealContainer) error { controller.log.Debug("Accepting results for job", deal.ID) - txHash, err := controller.web3SDK.AcceptResult(deal.ID) - if err != nil { - return fmt.Errorf("error calling accept result tx for deal: %s", err.Error()) - } - controller.log.Debug("accept result tx", txHash) + // !TODO MAINNET: get txHash for accepting results on chain or get this flow + txHash := "0x" // we have agreed to the deal so we need to update the tx in the solver - _, err = controller.solverClient.UpdateTransactionsJobCreator(deal.ID, data.DealTransactionsJobCreator{ + _, err := controller.solverClient.UpdateTransactionsJobCreator(deal.ID, data.DealTransactionsJobCreator{ AcceptResult: txHash, }) if err != nil { return fmt.Errorf("error adding AcceptResult tx hash for deal: %s", err.Error()) } + + return nil } diff --git a/pkg/options/server.go b/pkg/options/server.go index 84ab8d23..448b90f7 100644 --- a/pkg/options/server.go +++ b/pkg/options/server.go @@ -2,7 +2,7 @@ package options import ( "fmt" - + "net" "github.com/lilypad-tech/lilypad/pkg/http" "github.com/spf13/cobra" ) @@ -31,6 +31,7 @@ func GetDefaultRateLimiterOptions() http.RateLimiterOptions { return http.RateLimiterOptions{ RequestLimit: GetDefaultServeOptionInt("SERVER_RATE_REQUEST_LIMIT", 5), WindowLength: GetDefaultServeOptionInt("SERVER_RATE_WINDOW_LENGTH", 10), + ExemptedIPs: GetDefaultServeOptionStringArray("SERVER_RATE_EXEMPTED_IPS", []string{}), } } @@ -75,6 +76,10 @@ func AddServerCliFlags(cmd *cobra.Command, serverOptions *http.ServerOptions) { &serverOptions.RateLimiter.WindowLength, "server-rate-window-length", serverOptions.RateLimiter.WindowLength, `The time window over which to limit in seconds (SERVER_RATE_WINDOW_LENGTH).`, ) + cmd.PersistentFlags().StringArrayVar( + &serverOptions.RateLimiter.ExemptedIPs, "server-rate-exempted-ips", serverOptions.RateLimiter.ExemptedIPs, + `The IPs to exempt from rate limiting (SERVER_RATE_EXEMPTED_IPS).`, + ) } func CheckServerOptions(options http.ServerOptions, storeType string) error { @@ -90,5 +95,12 @@ func CheckServerOptions(options http.ServerOptions, storeType string) error { if options.AccessControl.ValidationTokenKid == "" { return fmt.Errorf("SERVER_VALIDATION_TOKEN_KID is required") } + if len(options.RateLimiter.ExemptedIPs) > 0 { + for _, ip := range options.RateLimiter.ExemptedIPs { + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid IP address: %s", ip) + } + } + } return nil } diff --git a/pkg/resourceprovider/controller.go b/pkg/resourceprovider/controller.go index ab8f09f1..dd7f07a9 100644 --- a/pkg/resourceprovider/controller.go +++ b/pkg/resourceprovider/controller.go @@ -526,12 +526,8 @@ func (controller *ResourceProviderController) runJob(ctx context.Context, deal d span.AddEvent("solver.result.added", trace.WithAttributes(attribute.String("result.id", createdResult.ID))) span.AddEvent("chain.result.add") - txHash, err := controller.web3SDK.AddResult( - deal.Deal.ID, - createdResult.ID, - createdResult.DataID, - result.InstructionCount, - ) + // !TODO MAINNET: get txHash for submittinng results on chain + txHash := "0x" if err != nil { controller.log.Error("error calling add result tx for job", err) span.SetStatus(codes.Error, "add result to chain failed") diff --git a/pkg/solver/controller.go b/pkg/solver/controller.go index c5aea5c8..d45bcd7d 100644 --- a/pkg/solver/controller.go +++ b/pkg/solver/controller.go @@ -448,8 +448,12 @@ func (controller *SolverController) addDeal(ctx context.Context, deal data.Deal) controller.log.Info("add deal", deal) + //creates deal container and sets state to agreed + dealContainer := data.GetDealContainer(deal) + dealContainer.State = data.GetAgreementStateIndex("DealAgreed") + span.AddEvent("store.add_deal.start") - ret, err := controller.store.AddDeal(data.GetDealContainer(deal)) + ret, err := controller.store.AddDeal(dealContainer) if err != nil { span.SetStatus(codes.Error, "add deal to store failed") span.RecordError(err) @@ -592,6 +596,9 @@ func (controller *SolverController) updateDealTransactionsJobCreator(id string, if err != nil { return nil, err } + if payload.AcceptResult != "" { + return controller.updateDealState(id, data.GetAgreementStateIndex("ResultsAccepted")) + } controller.writeEvent(SolverEvent{ EventType: JobCreatorTransactionsUpdated, Deal: dealContainer, diff --git a/pkg/solver/ratelimit_test.go b/pkg/solver/ratelimit_test.go index f2976646..617255ea 100644 --- a/pkg/solver/ratelimit_test.go +++ b/pkg/solver/ratelimit_test.go @@ -4,8 +4,8 @@ package solver_test import ( "fmt" + "math/rand" "net/http" - "os" "sync" "testing" "time" @@ -17,9 +17,20 @@ type rateResult struct { limitedCount int } -// This test suite sends 100 requests over approximately half a second. +type rateTestCase struct { + name string + headers map[string]string + expectedOK int + expectedLimit int +} + +// This test suite sends 200 requests to three different paths. We send the +// requests in rate limited and exempt test groups. The rate limited group +// should allow 5/100 requests through and the exempt group should allow 100/100. +// // We assume the solver uses the default rate limiting settings with -// a request limit of 5 and window length of 10 seconds. +// a request limit of 5 and window length of 10 seconds. In addition, the solver +// should be configured to exempt localhost. func TestRateLimiter(t *testing.T) { paths := []string{ "/api/v1/resource_offers", @@ -27,45 +38,96 @@ func TestRateLimiter(t *testing.T) { "/api/v1/deals", } + // The solver should rate limit when forwarded + // headers are set to 1.2.3.4. + nonExemptHeaders := []map[string]string{ + {"True-Client-IP": "1.2.3.4"}, + {"X-Real-IP": "1.2.3.4"}, + {"X-Forwarded-For": "1.2.3.4"}, + } + + // The running solver is configured to exempt localhost. + // When no headers are set, test using the IP address from + // the underlying connection (also localhost) + // TODO: re-enable exempt IP rate limiting + // exemptHeaders := []map[string]string{ + // {"True-Client-IP": "127.0.0.1"}, + // {"X-Real-IP": "127.0.0.1"}, + // {"X-Forwarded-For": "127.0.0.1"}, + // {}, // No headers case - uses RemoteAddr + // } + + t.Run("non-exempt IP is rate limited", func(t *testing.T) { + // Select a random header on each test run. Over time we test them all. + headers := nonExemptHeaders[rand.Intn(len(nonExemptHeaders))] + tc := rateTestCase{ + name: fmt.Sprintf("rate limited with headers %v", headers), + headers: headers, + expectedOK: 5, + expectedLimit: 95, + } + runRateLimitTest(t, paths, tc) + }) + + // TODO: re-enable exempt IP rate limiting + // t.Run("exempt IP is not rate limited", func(t *testing.T) { + // // Select a random header on each test run. Over time we test them all. + // headers := exemptHeaders[rand.Intn(len(exemptHeaders))] + // tc := rateTestCase{ + // name: fmt.Sprintf("exempt with headers %v", headers), + // headers: headers, + // expectedOK: 100, + // expectedLimit: 0, + // } + // runRateLimitTest(t, paths, tc) + // }) +} + +func runRateLimitTest(t *testing.T, paths []string, tc rateTestCase) { var wg sync.WaitGroup ch := make(chan rateResult, len(paths)) - // Send off callers to run concurrently + // Run the calls against paths in parallel for _, path := range paths { wg.Add(1) - - go func() { + go func(path string) { defer wg.Done() - makeCalls(t, path, ch) - }() + makeCalls(t, path, ch, tc) + }(path) } wg.Wait() close(ch) - expectedOkCount := 5 for result := range ch { - if result.okCount > expectedOkCount { - t.Errorf( - "%s allowed %d requests and limited %d requests, but expected limiting after %d requests\n", - result.path, result.okCount, result.limitedCount, expectedOkCount, - ) + if result.okCount != tc.expectedOK { + t.Errorf("%s: Expected %d successful requests, got %d", + result.path, tc.expectedOK, result.okCount) + } + if result.limitedCount != tc.expectedLimit { + t.Errorf("%s: Expected %d rate limited requests, got %d", + result.path, tc.expectedLimit, result.limitedCount) } } } -func makeCalls(t *testing.T, path string, ch chan rateResult) { +func makeCalls(t *testing.T, path string, ch chan rateResult, tc rateTestCase) { var okCount int var limitedCount int + client := &http.Client{} + + for i := 0; i < 100; i++ { + req, _ := http.NewRequest("GET", fmt.Sprintf("http://localhost:%d%s", 8081, path), nil) - // Make 100 requests - for range 100 { - requestURL := fmt.Sprintf("http://localhost:%d%s", 8081, path) - res, err := http.Get(requestURL) + // Set test case headers + for key, value := range tc.headers { + req.Header.Set(key, value) + } + res, err := client.Do(req) if err != nil { - t.Errorf("Get request failed on %s: %s\n", path, err) - os.Exit(1) + t.Errorf("Request failed on %s: %s\n", path, err) + return } if res.StatusCode == 200 { @@ -76,7 +138,6 @@ func makeCalls(t *testing.T, path string, ch chan rateResult) { t.Errorf("Expected a 200 or 429 status code, but received a %d\n", res.StatusCode) } - // Wait before making next call time.Sleep(5 * time.Millisecond) } diff --git a/pkg/solver/server.go b/pkg/solver/server.go index b57e5a99..67adeb21 100644 --- a/pkg/solver/server.go +++ b/pkg/solver/server.go @@ -52,17 +52,16 @@ func NewSolverServer( } /* - * - * - * +* +* +* - Routes +# Routes - * - * - * +* +* +* */ - func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system.CleanupManager, tracerProvider *trace.TracerProvider) error { router := mux.NewRouter() @@ -70,15 +69,41 @@ func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system subrouter.Use(http.CorsMiddleware) subrouter.Use(otelmux.Middleware("solver", otelmux.WithTracerProvider(tracerProvider))) + + exemptIPs := solverServer.options.RateLimiter.ExemptedIPs + // TODO: re-enable exempt IP rate limiting + // subrouter.Use(httprate.Limit( + // solverServer.options.RateLimiter.RequestLimit, + // time.Duration(solverServer.options.RateLimiter.WindowLength)*time.Second, + // httprate.WithKeyFuncs( + // exemptIPKeyFunc(exemptIPs), + // httprate.KeyByEndpoint, + // ), + // httprate.WithLimitHandler(func(w corehttp.ResponseWriter, r *corehttp.Request) { + + // key, _ := exemptIPKeyFunc(exemptIPs)(r) + // if strings.HasPrefix(key, "exempt-") { + // return + // } + + // corehttp.Error(w, "Too Many Requests", corehttp.StatusTooManyRequests) + // }), + // )) + subrouter.Use(httprate.Limit( solverServer.options.RateLimiter.RequestLimit, time.Duration(solverServer.options.RateLimiter.WindowLength)*time.Second, httprate.WithKeyFuncs(httprate.KeyByRealIP, httprate.KeyByEndpoint), )) + log.Info().Strs("exemptIPs", exemptIPs).Msg("Loaded rate limit exemptions") + subrouter.HandleFunc("/job_offers", http.GetHandler(solverServer.getJobOffers)).Methods("GET") subrouter.HandleFunc("/job_offers", http.PostHandler(solverServer.addJobOffer)).Methods("POST") + subrouter.HandleFunc("/job_offers/{id}", http.GetHandler(solverServer.getJobOffer)).Methods("GET") + subrouter.HandleFunc("/job_offers/{id}/files", solverServer.jobOfferDownloadFiles).Methods("GET") + subrouter.HandleFunc("/resource_offers", http.GetHandler(solverServer.getResourceOffers)).Methods("GET") subrouter.HandleFunc("/resource_offers", http.PostHandler(solverServer.addResourceOffer)).Methods("POST") @@ -178,6 +203,25 @@ func (solverServer *solverServer) disconnectCB(connParams http.WSConnectionParam } } +func exemptIPKeyFunc(exemptIPs []string) func(r *corehttp.Request) (string, error) { + return func(r *corehttp.Request) (string, error) { + ip, err := httprate.KeyByRealIP(r) + if err != nil { + log.Error().Err(err).Msg("error getting real ip") + return "", err + } + + // Check if the IP is in the exempt list + for _, exemptIP := range exemptIPs { + if http.CanonicalizeIP(exemptIP) == ip { + return "exempt-" + ip, nil + } + } + + return ip, nil + } +} + /* * * @@ -245,6 +289,17 @@ func (solverServer *solverServer) getDeals(res corehttp.ResponseWriter, req *cor * * */ + +func (solverServer *solverServer) getJobOffer(res corehttp.ResponseWriter, req *corehttp.Request) (data.JobOfferContainer, error) { + vars := mux.Vars(req) + id := vars["id"] + jobOffer, err := solverServer.store.GetJobOffer(id) + if err != nil { + return data.JobOfferContainer{}, err + } + return *jobOffer, nil +} + func (solverServer *solverServer) getDeal(res corehttp.ResponseWriter, req *corehttp.Request) (data.DealContainer, error) { vars := mux.Vars(req) id := vars["id"] @@ -362,7 +417,18 @@ func (solverServer *solverServer) addResult(results data.Result, res corehttp.Re return nil, err } results.DealID = id - return solverServer.store.AddResult(results) + + storedResult, err := solverServer.store.AddResult(results) + if err != nil { + return nil, err + } + + err = solverServer.updateJobStates(id, "ResultsSubmitted") + if err != nil { + return nil, err + } + + return storedResult, nil } /* @@ -496,67 +562,7 @@ func (solverServer *solverServer) downloadFiles(res corehttp.ResponseWriter, req StatusCode: corehttp.StatusUnauthorized, } } - - // Get the directory path - dirPath := GetDealsFilePath(id) - - // Read directory contents - files, err := os.ReadDir(dirPath) - if err != nil { - return &http.HTTPError{ - Message: fmt.Sprintf("error reading directory: %s", err.Error()), - StatusCode: corehttp.StatusNotFound, - } - } - - // Find the first regular file - var targetFile os.DirEntry - for _, file := range files { - info, err := file.Info() - if err != nil { - continue - } - if info.Mode().IsRegular() { - targetFile = file - break - } - } - - if targetFile == nil { - return &http.HTTPError{ - Message: "no regular files found in directory", - StatusCode: corehttp.StatusNotFound, - } - } - - // Get the actual filename - filename := targetFile.Name() - filePath := filepath.Join(dirPath, filename) - - // Open the file - file, err := os.Open(filePath) - if err != nil { - return &http.HTTPError{ - Message: err.Error(), - StatusCode: corehttp.StatusInternalServerError, - } - } - defer file.Close() - - // Set appropriate headers using the actual filename - res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) - res.Header().Set("Content-Type", "application/x-tar") - - // Copy the file directly to the response - _, err = io.Copy(res, file) - if err != nil { - return &http.HTTPError{ - Message: err.Error(), - StatusCode: corehttp.StatusInternalServerError, - } - } - - return nil + return solverServer.handleFileDownload(GetDealsFilePath(deal.ID), res) }() if err != nil { @@ -641,6 +647,112 @@ func (solverServer *solverServer) uploadFiles(res corehttp.ResponseWriter, req * } } +func (solverServer *solverServer) jobOfferDownloadFiles(res corehttp.ResponseWriter, req *corehttp.Request) { + vars := mux.Vars(req) + id := vars["id"] + + err := func() *http.HTTPError { + jobOffer, err := solverServer.store.GetJobOffer(id) + if err != nil { + log.Error().Err(err).Msgf("error loading job offer") + return &http.HTTPError{ + Message: err.Error(), + StatusCode: corehttp.StatusInternalServerError, + } + } + if jobOffer == nil { + return &http.HTTPError{ + Message: err.Error(), + StatusCode: corehttp.StatusNotFound, + } + } + + signerAddress, err := http.CheckSignature(req) + if err != nil { + log.Error().Err(err).Msgf("error checking signature") + return &http.HTTPError{ + Message: errors.New("not authorized").Error(), + StatusCode: corehttp.StatusUnauthorized, + } + } + + if signerAddress != jobOffer.JobCreator { + log.Error().Err(err).Msgf("job creator address does not match signer address") + return &http.HTTPError{ + Message: errors.New("not authorized").Error(), + StatusCode: corehttp.StatusUnauthorized, + } + } + + solverServer.updateJobStates(jobOffer.DealID, "ResultsAccepted") + + return solverServer.handleFileDownload(GetDealsFilePath(jobOffer.DealID), res) + }() + + if err != nil { + log.Ctx(req.Context()).Error().Msgf("error for route: %s", err.Error()) + corehttp.Error(res, err.Error(), err.StatusCode) + } +} + +func (solverServer *solverServer) handleFileDownload(dirPath string, res corehttp.ResponseWriter) *http.HTTPError { + // Read directory contents + files, err := os.ReadDir(dirPath) + if err != nil { + return &http.HTTPError{ + Message: fmt.Sprintf("error reading directory: %s", err.Error()), + StatusCode: corehttp.StatusNotFound, + } + } + + // Find the first regular file + var targetFile os.DirEntry + for _, file := range files { + info, err := file.Info() + if err != nil { + continue + } + if info.Mode().IsRegular() { + targetFile = file + break + } + } + + if targetFile == nil { + return &http.HTTPError{ + Message: "no regular files found in directory", + StatusCode: corehttp.StatusNotFound, + } + } + + // Get the actual filename and path + filename := targetFile.Name() + filePath := filepath.Join(dirPath, filename) + + // Open and serve the file + file, err := os.Open(filePath) + if err != nil { + return &http.HTTPError{ + Message: err.Error(), + StatusCode: corehttp.StatusInternalServerError, + } + } + defer file.Close() + + res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename)) + res.Header().Set("Content-Type", "application/x-tar") + + _, err = io.Copy(res, file) + if err != nil { + return &http.HTTPError{ + Message: err.Error(), + StatusCode: corehttp.StatusInternalServerError, + } + } + + return nil +} + // Validation Service func (solverServer *solverServer) getValidationToken(res corehttp.ResponseWriter, req *corehttp.Request) (*http.ValidationToken, error) { @@ -676,3 +788,32 @@ func (solverServer *solverServer) getValidationToken(res corehttp.ResponseWriter // Respond with the JWT return &http.ValidationToken{JWT: tokenString}, nil } + +func (solverServer *solverServer) updateJobStates(dealID string, state string) error { + deal, err := solverServer.store.GetDeal(dealID) + if err != nil { + return err + } + + _, err = solverServer.controller.updateDealState(deal.Deal.ID, data.GetAgreementStateIndex(state)) + if err != nil { + return err + } + // update the job offer state + _, err = solverServer.controller.updateJobOfferState(deal.Deal.JobOffer.ID, deal.ID, data.GetAgreementStateIndex(state)) + if err != nil { + return err + } + // update the resource offer state + _, err = solverServer.controller.updateResourceOfferState(deal.Deal.ResourceOffer.ID, deal.ID, data.GetAgreementStateIndex(state)) + if err != nil { + return err + } + + solverServer.controller.writeEvent(SolverEvent{ + EventType: DealStateUpdated, + Deal: deal, + }) + + return nil +}