Skip to content

Commit

Permalink
ip parsing, cli flags, cloudflare proxy support
Browse files Browse the repository at this point in the history
  • Loading branch information
kelindi committed Feb 3, 2025
1 parent a6d63c6 commit 68a459d
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 74 deletions.
13 changes: 12 additions & 1 deletion pkg/options/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package options

import (
"fmt"

"net"
"github.com/lilypad-tech/lilypad/pkg/http"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -69,6 +69,10 @@ func AddServerCliFlags(cmd *cobra.Command, serverOptions *http.ServerOptions) {
&serverOptions.RateLimiter.WindowLength, "server-rate-window-length", serverOptions.RateLimiter.WindowLength,
`The time window over which to limit in seconds (SERVER_RATE_WINDOW_LENGTH).`,
)
cmd.PersistentFlags().StringArrayVar(
&serverOptions.RateLimiter.ExemptedIPs, "server-rate-exempted-ips", serverOptions.RateLimiter.ExemptedIPs,
`The IPs to exempt from rate limiting (SERVER_RATE_EXEMPTED_IPS).`,
)
}

func CheckServerOptions(options http.ServerOptions) error {
Expand All @@ -81,5 +85,12 @@ func CheckServerOptions(options http.ServerOptions) error {
if options.AccessControl.ValidationTokenKid == "" {
return fmt.Errorf("SERVER_VALIDATION_TOKEN_KID is required")
}
if len(options.RateLimiter.ExemptedIPs) > 0 {
for _, ip := range options.RateLimiter.ExemptedIPs {
if net.ParseIP(ip) == nil {
return fmt.Errorf("invalid IP address: %s", ip)
}
}
}
return nil
}
171 changes: 98 additions & 73 deletions pkg/solver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"errors"
"fmt"
"io"
corehttp "net/http"
"net"
corehttp "net/http"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -52,35 +52,16 @@ func NewSolverServer(
}

/*
*
*
*
*
*
*
Routes
# Routes
*
*
*
*
*
*
*/

func exemptIPKeyFunc(exemptIPs []string) func(r *corehttp.Request) (string, error) {
return func(r *corehttp.Request) (string, error) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}


for _, exemptIP := range exemptIPs {
if ip == exemptIP {
return "", nil
}
}

return ip, nil
}
}

func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system.CleanupManager, tracerProvider *trace.TracerProvider) error {
router := mux.NewRouter()

Expand All @@ -89,7 +70,6 @@ func (solverServer *solverServer) ListenAndServe(ctx context.Context, cm *system
subrouter.Use(http.CorsMiddleware)
subrouter.Use(otelmux.Middleware("solver", otelmux.WithTracerProvider(tracerProvider)))


exemptIPs := solverServer.options.RateLimiter.ExemptedIPs

log.Debug().Strs("exemptIPs", exemptIPs).Msg("Loaded rate limit exemptions")
Expand Down Expand Up @@ -208,6 +188,51 @@ func (solverServer *solverServer) disconnectCB(connParams http.WSConnectionParam
}
}

func exemptIPKeyFunc(exemptIPs []string) func(r *corehttp.Request) (string, error) {
return func(r *corehttp.Request) (string, error) {

ip, err := getRealIPWithCloudflare(r)

if err != nil {
log.Error().Err(err).Msgf("error getting real ip")
return httprate.KeyByEndpoint(r)
}

for _, exemptIP := range exemptIPs {
if ip == canonicalizeIP(exemptIP) {
return "", nil
}
}

return httprate.KeyByEndpoint(r)
}
}

func getRealIPWithCloudflare(r *corehttp.Request) (string, error) {
if ip := r.Header.Get("CF-Connecting-IP"); ip != "" {
return canonicalizeIP(ip), nil
}

return httprate.KeyByRealIP(r)
}

func canonicalizeIP(ip string) string {
for _, c := range ip {
switch c {
case '.':
return ip
case ':':
parsed := net.ParseIP(ip)
if parsed == nil {
return ip
}
return parsed.Mask(net.CIDRMask(64, 128)).String()
}
}

return ip
}

/*
*
*
Expand Down Expand Up @@ -655,51 +680,51 @@ func (solverServer *solverServer) jobOfferDownloadFiles(res corehttp.ResponseWri
}

func (solverServer *solverServer) handleFileDownload(dirPath string, res corehttp.ResponseWriter) *http.HTTPError {
// Read directory contents
files, err := os.ReadDir(dirPath)
if err != nil {
return &http.HTTPError{
Message: fmt.Sprintf("error reading directory: %s", err.Error()),
StatusCode: corehttp.StatusNotFound,
}
}

// Find the first regular file
var targetFile os.DirEntry
for _, file := range files {
info, err := file.Info()
if err != nil {
continue
}
if info.Mode().IsRegular() {
targetFile = file
break
}
}

if targetFile == nil {
return &http.HTTPError{
Message: "no regular files found in directory",
StatusCode: corehttp.StatusNotFound,
}
}

// Get the actual filename and path
filename := targetFile.Name()
filePath := filepath.Join(dirPath, filename)

// Open and serve the file
file, err := os.Open(filePath)
if err != nil {
return &http.HTTPError{
Message: err.Error(),
StatusCode: corehttp.StatusInternalServerError,
}
}
defer file.Close()

res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename))
res.Header().Set("Content-Type", "application/x-tar")
// Read directory contents
files, err := os.ReadDir(dirPath)
if err != nil {
return &http.HTTPError{
Message: fmt.Sprintf("error reading directory: %s", err.Error()),
StatusCode: corehttp.StatusNotFound,
}
}

// Find the first regular file
var targetFile os.DirEntry
for _, file := range files {
info, err := file.Info()
if err != nil {
continue
}
if info.Mode().IsRegular() {
targetFile = file
break
}
}

if targetFile == nil {
return &http.HTTPError{
Message: "no regular files found in directory",
StatusCode: corehttp.StatusNotFound,
}
}

// Get the actual filename and path
filename := targetFile.Name()
filePath := filepath.Join(dirPath, filename)

// Open and serve the file
file, err := os.Open(filePath)
if err != nil {
return &http.HTTPError{
Message: err.Error(),
StatusCode: corehttp.StatusInternalServerError,
}
}
defer file.Close()

res.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", filename))
res.Header().Set("Content-Type", "application/x-tar")

_, err = io.Copy(res, file)
if err != nil {
Expand Down

0 comments on commit 68a459d

Please sign in to comment.