From e0db556dd49287e2463a7d764aa35e89a8edd797 Mon Sep 17 00:00:00 2001 From: Roman Zavodskikh Date: Tue, 26 Sep 2023 17:59:48 +0200 Subject: [PATCH] Decouple fadeIn from loadbalancer Signed-off-by: Roman Zavodskikh --- loadbalancer/algorithm.go | 170 ++----------- loadbalancer/algorithm_test.go | 2 +- loadbalancer/fadein_test.go | 388 ----------------------------- loadbalancer/locked_source.go | 2 +- loadbalancer/locked_source_test.go | 2 +- proxy/fadein_internal_test.go | 281 +++++++++++++++++++++ proxy/proxy.go | 57 ++++- routing/routing.go | 2 +- skptesting/run_fadein_test.sh | 4 +- 9 files changed, 353 insertions(+), 555 deletions(-) delete mode 100644 loadbalancer/fadein_test.go create mode 100644 proxy/fadein_internal_test.go diff --git a/loadbalancer/algorithm.go b/loadbalancer/algorithm.go index 3fd60a88c2..e57d5efda7 100644 --- a/loadbalancer/algorithm.go +++ b/loadbalancer/algorithm.go @@ -3,7 +3,6 @@ package loadbalancer import ( "errors" "fmt" - "math" "math/rand" "net" "net/url" @@ -11,7 +10,6 @@ import ( "strings" "sync" "sync/atomic" - "time" "github.com/cespare/xxhash/v2" log "github.com/sirupsen/logrus" @@ -56,110 +54,14 @@ var ( defaultAlgorithm = newRoundRobin ) -func fadeInState(now time.Time, duration time.Duration, detected time.Time) (time.Duration, bool) { - rel := now.Sub(detected) - return rel, rel > 0 && rel < duration -} - -func fadeIn(now time.Time, duration time.Duration, exponent float64, detected time.Time) float64 { - rel, fadingIn := fadeInState(now, duration, detected) - if !fadingIn { - return 1 - } - - return math.Pow(float64(rel)/float64(duration), exponent) -} - -func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routing.LBEndpoint { - var sum float64 - rt := ctx.Route - ep := ctx.LBEndpoints - for _, epi := range ep { - detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime() - wi := fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected) - sum += wi - } - - choice := ep[len(ep)-1] - r := 
rnd.Float64() * sum - var upto float64 - for i, epi := range ep { - detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime() - upto += fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected) - if upto > r { - choice = ep[i] - break - } - } - - return choice -} - -func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time.Time) routing.LBEndpoint { - notFadingIndexes := wi - ep := ctx.LBEndpoints - - // if all endpoints are fading, the simplest approach is to use the oldest, - // this departs from the desired curve, but guarantees monotonic fade-in. From - // the perspective of the oldest endpoint, this is temporarily the same as if - // there was no fade-in. - if len(notFadingIndexes) == 0 { - return shiftWeighted(rnd, ctx, now) - } - - // otherwise equally distribute between the old endpoints - return ep[notFadingIndexes[rnd.Intn(len(notFadingIndexes))]] -} - -func withFadeIn(rnd *rand.Rand, ctx *routing.LBContext, choice int, algo routing.LBAlgorithm) routing.LBEndpoint { - ep := ctx.LBEndpoints - now := time.Now() - detected := ctx.Registry.GetMetrics(ctx.LBEndpoints[choice].Host).DetectedTime() - f := fadeIn( - now, - ctx.Route.LBFadeInDuration, - ctx.Route.LBFadeInExponent, - detected, - ) - - if rnd.Float64() < f { - return ep[choice] - } - notFadingIndexes := make([]int, 0, len(ep)) - for i := 0; i < len(ep); i++ { - detected := ctx.Registry.GetMetrics(ep[i].Host).DetectedTime() - if _, fadingIn := fadeInState(now, ctx.Route.LBFadeInDuration, detected); !fadingIn { - notFadingIndexes = append(notFadingIndexes, i) - } - } - - switch a := algo.(type) { - case *roundRobin: - return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now) - case *random: - return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now) - case *consistentHash: - // If all endpoints are fading, normal consistent hash result - if len(notFadingIndexes) == 0 { - return ep[choice] - } - // otherwise calculate consistent hash again using endpoints which are not 
fading - return ep[a.chooseConsistentHashEndpoint(ctx, skipFadingEndpoints(notFadingIndexes))] - default: - return ep[choice] - } -} - type roundRobin struct { index int64 - rnd *rand.Rand } func newRoundRobin(endpoints []string) routing.LBAlgorithm { - rnd := rand.New(newLockedSource()) // #nosec + rnd := rand.New(NewLockedSource()) // #nosec return &roundRobin{ index: int64(rnd.Intn(len(endpoints))), - rnd: rnd, } } @@ -169,13 +71,8 @@ func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint { return ctx.LBEndpoints[0] } - index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints))) - - if ctx.Route.LBFadeInDuration <= 0 { - return ctx.LBEndpoints[index] - } - - return withFadeIn(r.rnd, ctx, index, r) + choice := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints))) + return ctx.LBEndpoints[choice] } type random struct { @@ -185,7 +82,7 @@ type random struct { func newRandom(endpoints []string) routing.LBAlgorithm { // #nosec return &random{ - rnd: rand.New(newLockedSource()), + rnd: rand.New(NewLockedSource()), } } @@ -195,12 +92,8 @@ func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint { return ctx.LBEndpoints[0] } - i := r.rnd.Intn(len(ctx.LBEndpoints)) - if ctx.Route.LBFadeInDuration <= 0 { - return ctx.LBEndpoints[i] - } - - return withFadeIn(r.rnd, ctx, i, r) + choice := r.rnd.Intn(len(ctx.LBEndpoints)) + return ctx.LBEndpoints[choice] } type ( @@ -210,7 +103,6 @@ type ( } consistentHash struct { hashRing []endpointHash // list of endpoints sorted by hash value - rnd *rand.Rand } ) @@ -221,10 +113,8 @@ func (ch *consistentHash) Swap(i, j int) { } func newConsistentHashInternal(endpoints []string, hashesPerEndpoint int) routing.LBAlgorithm { - rnd := rand.New(newLockedSource()) // #nosec ch := &consistentHash{ hashRing: make([]endpointHash, hashesPerEndpoint*len(endpoints)), - rnd: rnd, } for i, ep := range endpoints { endpointStartIndex := hashesPerEndpoint * i @@ -245,21 +135,18 @@ func hash(s string) uint64 
{ } // Returns index in hash ring with the closest hash to key's hash -func (ch *consistentHash) searchRing(key string, skipEndpoint func(int) bool) int { +func (ch *consistentHash) searchRing(key string) int { h := hash(key) i := sort.Search(ch.Len(), func(i int) bool { return ch.hashRing[i].hash >= h }) if i == ch.Len() { // rollover i = 0 } - for skipEndpoint(ch.hashRing[i].index) { - i = (i + 1) % ch.Len() - } return i } // Returns index of endpoint with closest hash to key's hash -func (ch *consistentHash) search(key string, skipEndpoint func(int) bool) int { - ringIndex := ch.searchRing(key, skipEndpoint) +func (ch *consistentHash) search(key string) int { + ringIndex := ch.searchRing(key) return ch.hashRing[ringIndex].index } @@ -274,16 +161,13 @@ func computeLoadAverage(ctx *routing.LBContext) float64 { // Returns index of endpoint with closest hash to key's hash, which is also below the target load // skipEndpoint function is used to skip endpoints we don't want, such as fading endpoints -func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext, skipEndpoint func(int) bool) int { - ringIndex := ch.searchRing(key, skipEndpoint) +func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext) int { + ringIndex := ch.searchRing(key) averageLoad := computeLoadAverage(ctx) targetLoad := averageLoad * balanceFactor // Loop round ring, starting at endpoint with closest hash. Stop when we find one whose load is less than targetLoad. for i := 0; i < ch.Len(); i++ { endpointIndex := ch.hashRing[ringIndex].index - if skipEndpoint(endpointIndex) { - continue - } load := ctx.Registry.GetMetrics(ctx.LBEndpoints[endpointIndex].Host).InflightRequests() // We know there must be an endpoint whose load <= average load. // Since targetLoad >= average load (balancerFactor >= 1), there must also be an endpoint with load <= targetLoad. 
@@ -302,16 +186,11 @@ func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint { return ctx.LBEndpoints[0] } - choice := ch.chooseConsistentHashEndpoint(ctx, noSkippedEndpoints) - - if ctx.Route.LBFadeInDuration <= 0 { - return ctx.LBEndpoints[choice] - } - - return withFadeIn(ch.rnd, ctx, choice, ch) + choice := ch.chooseConsistentHashEndpoint(ctx) + return ctx.LBEndpoints[choice] } -func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext, skipEndpoint func(int) bool) int { +func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext) int { key, ok := ctx.Params[ConsistentHashKey].(string) if !ok { key = snet.RemoteHost(ctx.Request).String() @@ -319,29 +198,14 @@ func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext, s balanceFactor, ok := ctx.Params[ConsistentHashBalanceFactor].(float64) var choice int if !ok { - choice = ch.search(key, skipEndpoint) + choice = ch.search(key) } else { - choice = ch.boundedLoadSearch(key, balanceFactor, ctx, skipEndpoint) + choice = ch.boundedLoadSearch(key, balanceFactor, ctx) } return choice } -func skipFadingEndpoints(notFadingEndpoints []int) func(int) bool { - return func(i int) bool { - for _, notFadingEndpoint := range notFadingEndpoints { - if i == notFadingEndpoint { - return false - } - } - return true - } -} - -func noSkippedEndpoints(_ int) bool { - return false -} - type powerOfRandomNChoices struct { mx sync.Mutex rnd *rand.Rand @@ -350,7 +214,7 @@ type powerOfRandomNChoices struct { // newPowerOfRandomNChoices selects N random backends and picks the one with less outstanding requests. 
func newPowerOfRandomNChoices([]string) routing.LBAlgorithm { - rnd := rand.New(newLockedSource()) // #nosec + rnd := rand.New(NewLockedSource()) // #nosec return &powerOfRandomNChoices{ rnd: rnd, numberOfChoices: powerOfRandomNChoicesDefaultN, diff --git a/loadbalancer/algorithm_test.go b/loadbalancer/algorithm_test.go index 3db8e1fc19..9d4a54e68b 100644 --- a/loadbalancer/algorithm_test.go +++ b/loadbalancer/algorithm_test.go @@ -281,7 +281,7 @@ func TestApply(t *testing.T) { func TestConsistentHashSearch(t *testing.T) { apply := func(key string, endpoints []string) string { ch := newConsistentHash(endpoints).(*consistentHash) - return endpoints[ch.search(key, noSkippedEndpoints)] + return endpoints[ch.search(key)] } endpoints := []string{"http://127.0.0.1:8080", "http://127.0.0.2:8080", "http://127.0.0.3:8080"} diff --git a/loadbalancer/fadein_test.go b/loadbalancer/fadein_test.go deleted file mode 100644 index c05311fe23..0000000000 --- a/loadbalancer/fadein_test.go +++ /dev/null @@ -1,388 +0,0 @@ -package loadbalancer - -import ( - "fmt" - "math/rand" - "strconv" - "strings" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/zalando/skipper/routing" -) - -const ( - fadeInDuration = 500 * time.Millisecond - fadeInDurationHuge = 24 * time.Hour // we need this to be sure we're at the very beginning of fading in - bucketCount = 20 - monotonyTolerance = 0.4 // we need to use a high tolerance for CI testing -) - -func absint(i int) int { - if i < 0 { - return -i - } - - return i -} - -func tolerance(prev, next int) int { - return int(float64(prev+next) * monotonyTolerance / 2) -} - -func checkMonotony(direction, prev, next int) bool { - t := tolerance(prev, next) - switch direction { - case 1: - return next-prev >= -t - case -1: - return next-prev <= t - default: - return absint(next-prev) < t - } -} - -func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Duration) (*routing.LBContext, []string) { - var 
detectionTimes []time.Time - now := time.Now() - for _, ea := range endpointAges { - detectionTimes = append(detectionTimes, now.Add(-ea)) - } - - var eps []string - for i := range endpointAges { - eps = append(eps, fmt.Sprintf("ep-%d-%s.test", i, endpointAges[i])) - } - - ctx := &routing.LBContext{ - Params: map[string]interface{}{}, - Route: &routing.Route{ - LBFadeInDuration: fadeInDuration, - LBFadeInExponent: 1, - }, - Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}), - } - - for i := range eps { - ctx.Route.LBEndpoints = append(ctx.Route.LBEndpoints, routing.LBEndpoint{ - Host: eps[i], - Detected: detectionTimes[i], - }) - ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i]) - } - ctx.LBEndpoints = ctx.Route.LBEndpoints - - return ctx, eps -} - -func testFadeIn( - t *testing.T, - name string, - algorithm func([]string) routing.LBAlgorithm, - endpointAges ...time.Duration, -) { - t.Run(name, func(t *testing.T) { - ctx, eps := initializeEndpoints(endpointAges, fadeInDuration) - - a := algorithm(eps) - rnd := rand.New(rand.NewSource(time.Now().UnixNano())) - t.Log("test start", time.Now()) - var stats []string - stop := time.After(fadeInDuration) - // Emulate the load balancer loop, sending requests to it with random hash keys - // over and over again till fadeIn period is over. - func() { - for { - ctx.Params[ConsistentHashKey] = strconv.Itoa(rnd.Intn(100000)) - ep := a.Apply(ctx) - stats = append(stats, ep.Host) - select { - case <-stop: - return - default: - } - } - }() - - // Split fade-in period into buckets and count how many times each endpoint was selected. 
- t.Log("test done", time.Now()) - t.Log("CSV timestamp," + strings.Join(eps, ",")) - bucketSize := len(stats) / bucketCount - var allBuckets []map[string]int - for i := 0; i < bucketCount; i++ { - bucketStats := make(map[string]int) - for j := i * bucketSize; j < (i+1)*bucketSize; j++ { - bucketStats[stats[j]]++ - } - - allBuckets = append(allBuckets, bucketStats) - } - - directions := make(map[string]int) - for _, epi := range eps { - first := allBuckets[0][epi] - last := allBuckets[len(allBuckets)-1][epi] - t := tolerance(first, last) - switch { - case last-first > t: - directions[epi] = 1 - case last-first < t: - directions[epi] = -1 - } - } - - for i := range allBuckets { - // trim first and last (warmup and settling) - if i < 2 || i == len(allBuckets)-1 { - continue - } - - for _, epi := range eps { - if !checkMonotony( - directions[epi], - allBuckets[i-1][epi], - allBuckets[i][epi], - ) { - t.Error("non-monotonic change", epi, i) - } - } - } - - for i, bucketStats := range allBuckets { - var showStats []string - for _, epi := range eps { - showStats = append(showStats, fmt.Sprintf("%d", bucketStats[epi])) - } - - // Print CSV-like output for, where row number represents time and - // column represents endpoint. You can visualize it using - // ./skptesting/run_fadein_test.sh from the skipper repo root. 
- t.Log("CSV " + fmt.Sprintf("%d,", i) + strings.Join(showStats, ",")) - } - }) -} - -func newConsistentHashForTest(endpoints []string) routing.LBAlgorithm { - // The default parameter 100 is too small to get even distribution - return newConsistentHashInternal(endpoints, 1000) -} - -func TestFadeIn(t *testing.T) { - old := 2 * fadeInDuration - testFadeIn(t, "round-robin, 0", newRoundRobin, old, old) - testFadeIn(t, "round-robin, 1", newRoundRobin, 0, old) - testFadeIn(t, "round-robin, 2", newRoundRobin, 0, 0) - testFadeIn(t, "round-robin, 3", newRoundRobin, old, 0) - testFadeIn(t, "round-robin, 4", newRoundRobin, old, old, old, 0) - testFadeIn(t, "round-robin, 5", newRoundRobin, old, old, old, 0, 0, 0) - testFadeIn(t, "round-robin, 6", newRoundRobin, old, 0, 0, 0) - testFadeIn(t, "round-robin, 7", newRoundRobin, old, 0, 0, 0, 0, 0, 0) - testFadeIn(t, "round-robin, 8", newRoundRobin, 0, 0, 0, 0, 0, 0) - testFadeIn(t, "round-robin, 9", newRoundRobin, fadeInDuration/2, fadeInDuration/3, fadeInDuration/4) - - testFadeIn(t, "random, 0", newRandom, old, old) - testFadeIn(t, "random, 1", newRandom, 0, old) - testFadeIn(t, "random, 2", newRandom, 0, 0) - testFadeIn(t, "random, 3", newRandom, old, 0) - testFadeIn(t, "random, 4", newRandom, old, old, old, 0) - testFadeIn(t, "random, 5", newRandom, old, old, old, 0, 0, 0) - testFadeIn(t, "random, 6", newRandom, old, 0, 0, 0) - testFadeIn(t, "random, 7", newRandom, old, 0, 0, 0, 0, 0, 0) - testFadeIn(t, "random, 8", newRandom, 0, 0, 0, 0, 0, 0) - testFadeIn(t, "random, 9", newRandom, fadeInDuration/2, fadeInDuration/3, fadeInDuration/4) - - testFadeIn(t, "consistent-hash, 0", newConsistentHashForTest, old, old) - testFadeIn(t, "consistent-hash, 1", newConsistentHashForTest, 0, old) - testFadeIn(t, "consistent-hash, 2", newConsistentHashForTest, 0, 0) - testFadeIn(t, "consistent-hash, 3", newConsistentHashForTest, old, 0) - testFadeIn(t, "consistent-hash, 4", newConsistentHashForTest, old, old, old, 0) - testFadeIn(t, 
"consistent-hash, 5", newConsistentHashForTest, old, old, old, 0, 0, 0) - testFadeIn(t, "consistent-hash, 6", newConsistentHashForTest, old, 0, 0, 0) - testFadeIn(t, "consistent-hash, 7", newConsistentHashForTest, old, 0, 0, 0, 0, 0, 0) - testFadeIn(t, "consistent-hash, 8", newConsistentHashForTest, 0, 0, 0, 0, 0, 0) - testFadeIn(t, "consistent-hash, 9", newConsistentHashForTest, fadeInDuration/2, fadeInDuration/3, fadeInDuration/4) -} - -func testFadeInLoadBetweenOldEps( - t *testing.T, - name string, - algorithm func([]string) routing.LBAlgorithm, - nOld int, nNew int, -) { - t.Run(name, func(t *testing.T) { - const ( - numberOfReqs = 100000 - acceptableErrorNearZero = 10 - old = fadeInDurationHuge - new = time.Duration(0) - ) - endpointAges := []time.Duration{} - for i := 0; i < nOld; i++ { - endpointAges = append(endpointAges, old) - } - for i := 0; i < nNew; i++ { - endpointAges = append(endpointAges, new) - } - - ctx, eps := initializeEndpoints(endpointAges, fadeInDurationHuge) - - a := algorithm(eps) - rnd := rand.New(rand.NewSource(time.Now().UnixNano())) - nReqs := map[string]int{} - - t.Log("test start", time.Now()) - // Emulate the load balancer loop, sending requests to it with random hash keys - // over and over again till fadeIn period is over. 
- for i := 0; i < numberOfReqs; i++ { - ctx.Params[ConsistentHashKey] = strconv.Itoa(rnd.Intn(100000)) - ep := a.Apply(ctx) - nReqs[ep.Host]++ - } - - expectedReqsPerOldEndpoint := numberOfReqs / nOld - for idx, ep := range eps { - if endpointAges[idx] == old { - assert.InEpsilon(t, expectedReqsPerOldEndpoint, nReqs[ep], 0.2) - } - if endpointAges[idx] == new { - assert.InDelta(t, 0, nReqs[ep], acceptableErrorNearZero) - } - } - }) -} - -func TestFadeInLoadBetweenOldEps(t *testing.T) { - for nOld := 1; nOld < 6; nOld++ { - for nNew := 0; nNew < 6; nNew++ { - testFadeInLoadBetweenOldEps(t, fmt.Sprintf("consistent-hash, %d old, %d new", nOld, nNew), newConsistentHash, nOld, nNew) - testFadeInLoadBetweenOldEps(t, fmt.Sprintf("random, %d old, %d new", nOld, nNew), newRandom, nOld, nNew) - testFadeInLoadBetweenOldEps(t, fmt.Sprintf("round-robin, %d old, %d new", nOld, nNew), newRoundRobin, nOld, nNew) - } - } -} - -func testApplyEndsWhenAllEndpointsAreFading( - t *testing.T, - name string, - algorithm func([]string) routing.LBAlgorithm, - nEndpoints int, -) { - t.Run(name, func(t *testing.T) { - // Initialize every endpoint with zero: every endpoint is new - endpointAges := make([]time.Duration, nEndpoints) - - ctx, eps := initializeEndpoints(endpointAges, fadeInDurationHuge) - - a := algorithm(eps) - ctx.Params[ConsistentHashKey] = "someConstantString" - applied := make(chan struct{}) - - go func() { - a.Apply(ctx) - close(applied) - }() - - select { - case <-time.After(time.Second): - t.Errorf("a.Apply has not finished in a reasonable time") - case <-applied: - break - } - }) -} - -func TestApplyEndsWhenAllEndpointsAreFading(t *testing.T) { - for nEndpoints := 1; nEndpoints < 10; nEndpoints++ { - testApplyEndsWhenAllEndpointsAreFading(t, "consistent-hash", newConsistentHash, nEndpoints) - testApplyEndsWhenAllEndpointsAreFading(t, "random", newRandom, nEndpoints) - testApplyEndsWhenAllEndpointsAreFading(t, "round-robin", newRoundRobin, nEndpoints) - } -} - -func 
benchmarkFadeIn( - b *testing.B, - name string, - algorithm func([]string) routing.LBAlgorithm, - clients int, - endpointAges ...time.Duration, -) { - b.Run(name, func(b *testing.B) { - var detectionTimes []time.Time - now := time.Now() - for _, ea := range endpointAges { - detectionTimes = append(detectionTimes, now.Add(-ea)) - } - - var eps []string - for i := range endpointAges { - eps = append(eps, string('a'+rune(i))) - } - - a := algorithm(eps) - - route := &routing.Route{ - LBFadeInDuration: fadeInDuration, - LBFadeInExponent: 1, - } - registry := routing.NewEndpointRegistry(routing.RegistryOptions{}) - for i := range eps { - route.LBEndpoints = append(route.LBEndpoints, routing.LBEndpoint{ - Host: eps[i], - Detected: detectionTimes[i], - }) - registry.SetDetectedTime(eps[i], detectionTimes[i]) - } - - var wg sync.WaitGroup - - // Emulate the load balancer loop, sending requests to it with random hash keys - // over and over again till fadeIn period is over. - b.ResetTimer() - for i := 0; i < clients; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - - rnd := rand.New(rand.NewSource(time.Now().UnixNano())) - ctx := &routing.LBContext{ - Params: map[string]interface{}{}, - Route: route, - Registry: registry, - } - - for j := 0; j < b.N/clients; j++ { - ctx.Params[ConsistentHashKey] = strconv.Itoa(rnd.Intn(100000)) - _ = a.Apply(ctx) - } - }(i) - } - - wg.Wait() - }) -} - -func repeatedSlice(v time.Duration, n int) []time.Duration { - var s []time.Duration - for i := 0; i < n; i++ { - s = append(s, v) - } - return s -} - -func BenchmarkFadeIn(b *testing.B) { - old := 2 * fadeInDuration - clients := []int{1, 4, 16, 64, 256} - for _, c := range clients { - benchmarkFadeIn(b, fmt.Sprintf("random, 11, %d clients", c), newRandom, c, repeatedSlice(old, 200)...) - } - - for _, c := range clients { - benchmarkFadeIn(b, fmt.Sprintf("round-robin, 11, %d clients", c), newRoundRobin, c, repeatedSlice(old, 200)...) 
- } - - for _, c := range clients { - benchmarkFadeIn(b, fmt.Sprintf("consistent-hash, 11, %d clients", c), newConsistentHash, c, repeatedSlice(old, 200)...) - } -} diff --git a/loadbalancer/locked_source.go b/loadbalancer/locked_source.go index a662a35c4d..7a54bd980c 100644 --- a/loadbalancer/locked_source.go +++ b/loadbalancer/locked_source.go @@ -11,7 +11,7 @@ type lockedSource struct { r rand.Source } -func newLockedSource() *lockedSource { +func NewLockedSource() *lockedSource { return &lockedSource{r: rand.NewSource(time.Now().UnixNano())} } diff --git a/loadbalancer/locked_source_test.go b/loadbalancer/locked_source_test.go index c3cfa68df9..fb512b5b98 100644 --- a/loadbalancer/locked_source_test.go +++ b/loadbalancer/locked_source_test.go @@ -12,7 +12,7 @@ func loadTestLockedSource(s *lockedSource, n int) { } func TestLockedSourceForConcurrentUse(t *testing.T) { - s := newLockedSource() + s := NewLockedSource() var wg sync.WaitGroup for i := 0; i < 10; i++ { diff --git a/proxy/fadein_internal_test.go b/proxy/fadein_internal_test.go new file mode 100644 index 0000000000..0de6d95d88 --- /dev/null +++ b/proxy/fadein_internal_test.go @@ -0,0 +1,281 @@ +package proxy + +import ( + "fmt" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "github.com/zalando/skipper/eskip" + "github.com/zalando/skipper/loadbalancer" + "github.com/zalando/skipper/routing" +) + +const ( + fadeInDuration = 500 * time.Millisecond + fadeInDurationHuge = 24 * time.Hour // we need this to be sure we're at the very beginning of fading in + bucketCount = 20 + monotonyTolerance = 0.4 // we need to use a high tolerance for CI testing +) + +func absint(i int) int { + if i < 0 { + return -i + } + + return i +} + +func tolerance(prev, next int) int { + return int(float64(prev+next) * monotonyTolerance / 2) +} + +func checkMonotony(direction, prev, next int) bool { + t := tolerance(prev, next) + switch direction { + case 1: + return next-prev >= -t + case -1: + return next-prev <= t + 
default: + return absint(next-prev) < t + } +} + +func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Duration) (*routing.Route, *Proxy, []string) { + var detectionTimes []time.Time + now := time.Now() + for _, ea := range endpointAges { + detectionTimes = append(detectionTimes, now.Add(-ea)) + } + + var eps []string + for i := range endpointAges { + eps = append(eps, fmt.Sprintf("http://ep-%d-%s.test", i, endpointAges[i])) + } + + registry := routing.NewEndpointRegistry(routing.RegistryOptions{}) + eskipRoute := eskip.Route{BackendType: eskip.LBBackend} + for i := range eps { + eskipRoute.LBEndpoints = append(eskipRoute.LBEndpoints, eps[i]) + registry.SetDetectedTime(eps[i], detectionTimes[i]) + } + + route := &routing.Route{ + Route: eskipRoute, + LBFadeInDuration: fadeInDuration, + LBFadeInExponent: 1, + LBEndpoints: []routing.LBEndpoint{}, + } + + rt := loadbalancer.NewAlgorithmProvider().Do([]*routing.Route{route}) + route = rt[0] + + eps = []string{} + for i := range endpointAges { + eps = append(eps, fmt.Sprintf("ep-%d-%s.test:80", i, endpointAges[i])) + registry.SetDetectedTime(eps[i], detectionTimes[i]) + } + + proxy := &Proxy{registry: registry, rnd: rand.New(loadbalancer.NewLockedSource())} + return route, proxy, eps +} + +func testFadeIn( + t *testing.T, + name string, + endpointAges ...time.Duration, +) { + t.Run(name, func(t *testing.T) { + route, proxy, eps := initializeEndpoints(endpointAges, fadeInDuration) + + t.Log("test start", time.Now()) + var stats []string + stop := time.After(fadeInDuration) + // Emulate the load balancer loop, selecting an endpoint through the proxy + // over and over again till the fadeIn period is over. + func() { + for { + ep := proxy.selectEndpoint(&context{route: route}) + stats = append(stats, ep.Host) + select { + case <-stop: + return + default: + } + } + }() + + // Split fade-in period into buckets and count how many times each endpoint was selected. 
+ t.Log("test done", time.Now()) + t.Log("CSV timestamp," + strings.Join(eps, ",")) + bucketSize := len(stats) / bucketCount + var allBuckets []map[string]int + for i := 0; i < bucketCount; i++ { + bucketStats := make(map[string]int) + for j := i * bucketSize; j < (i+1)*bucketSize; j++ { + bucketStats[stats[j]]++ + } + + allBuckets = append(allBuckets, bucketStats) + } + + directions := make(map[string]int) + for _, epi := range eps { + first := allBuckets[0][epi] + last := allBuckets[len(allBuckets)-1][epi] + t := tolerance(first, last) + switch { + case last-first > t: + directions[epi] = 1 + case last-first < t: + directions[epi] = -1 + } + } + + for i := range allBuckets { + // trim first and last (warmup and settling) + if i < 2 || i == len(allBuckets)-1 { + continue + } + + for _, epi := range eps { + if !checkMonotony( + directions[epi], + allBuckets[i-1][epi], + allBuckets[i][epi], + ) { + t.Error("non-monotonic change", epi, i) + } + } + } + + for i, bucketStats := range allBuckets { + var showStats []string + for _, epi := range eps { + showStats = append(showStats, fmt.Sprintf("%d", bucketStats[epi])) + } + + // Print CSV-like output, where the row number represents time and + // each column represents an endpoint. You can visualize it using + // ./skptesting/run_fadein_test.sh from the skipper repo root. 
+ t.Log("CSV " + fmt.Sprintf("%d,", i) + strings.Join(showStats, ",")) + } + }) +} + +func TestFadeInMonotony(t *testing.T) { + old := 2 * fadeInDuration + testFadeIn(t, "round-robin, 0", old, old) + testFadeIn(t, "round-robin, 1", 0, old) + testFadeIn(t, "round-robin, 2", 0, 0) + testFadeIn(t, "round-robin, 3", old, 0) + testFadeIn(t, "round-robin, 4", old, old, old, 0) + testFadeIn(t, "round-robin, 5", old, old, old, 0, 0, 0) + testFadeIn(t, "round-robin, 6", old, 0, 0, 0) + testFadeIn(t, "round-robin, 7", old, 0, 0, 0, 0, 0, 0) + testFadeIn(t, "round-robin, 8", 0, 0, 0, 0, 0, 0) + testFadeIn(t, "round-robin, 9", fadeInDuration/2, fadeInDuration/3, fadeInDuration/4) +} + +func testSetRequestURLEndsWhenAllEndpointsAreFading( + t *testing.T, + name string, + nEndpoints int, +) { + t.Run(name, func(t *testing.T) { + // Initialize every endpoint with zero: every endpoint is new + endpointAges := make([]time.Duration, nEndpoints) + + route, proxy, _ := initializeEndpoints(endpointAges, fadeInDurationHuge) + + applied := make(chan struct{}) + go func() { + _ = proxy.selectEndpoint(&context{route: route}) + close(applied) + }() + + select { + case <-time.After(time.Second): + t.Errorf("a.Apply has not finished in a reasonable time") + case <-applied: + break + } + }) +} + +func TestSetRequestURLEndsWhenAllEndpointsAreFading(t *testing.T) { + for nEndpoints := 1; nEndpoints < 10; nEndpoints++ { + testSetRequestURLEndsWhenAllEndpointsAreFading(t, "round-robin", nEndpoints) + } +} + +func benchmarkFadeIn( + b *testing.B, + name string, + clients int, + endpointAges ...time.Duration, +) { + b.Run(name, func(b *testing.B) { + var detectionTimes []time.Time + now := time.Now() + for _, ea := range endpointAges { + detectionTimes = append(detectionTimes, now.Add(-ea)) + } + + var eps []string + for i := range endpointAges { + eps = append(eps, string('a'+rune(i))) + } + + route := &routing.Route{ + LBFadeInDuration: fadeInDuration, + LBFadeInExponent: 1, + } + registry := 
routing.NewEndpointRegistry(routing.RegistryOptions{}) + for i := range eps { + route.LBEndpoints = append(route.LBEndpoints, routing.LBEndpoint{ + Host: eps[i], + Detected: detectionTimes[i], + }) + registry.SetDetectedTime(eps[i], detectionTimes[i]) + } + proxy := &Proxy{rnd: rand.New(loadbalancer.NewLockedSource()), registry: registry} + + var wg sync.WaitGroup + + // Emulate concurrent clients of the load balancer loop, each selecting an + // endpoint through the proxy b.N/clients times. + b.ResetTimer() + for i := 0; i < clients; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + for j := 0; j < b.N/clients; j++ { + _ = proxy.selectEndpoint(&context{route: route}) + } + }(i) + } + + wg.Wait() + }) +} + +func repeatedSlice(v time.Duration, n int) []time.Duration { + var s []time.Duration + for i := 0; i < n; i++ { + s = append(s, v) + } + return s +} + +func BenchmarkFadeIn(b *testing.B) { + old := 2 * fadeInDuration + clients := []int{1, 4, 16, 64, 256} + for _, c := range clients { + benchmarkFadeIn(b, fmt.Sprintf("round-robin, 11, %d clients", c), c, repeatedSlice(old, 200)...) 
+ } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 923ad8c472..fb356c0ffc 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "io" + "math" + "math/rand" "net" "net/http" "net/http/httptrace" @@ -329,6 +331,7 @@ type Proxy struct { defaultHTTPStatus int routing *routing.Routing registry *routing.EndpointRegistry + rnd *rand.Rand roundTripper http.RoundTripper priorityRoutes []PriorityRoute flags Flags @@ -468,16 +471,52 @@ func setRequestURLForDynamicBackend(u *url.URL, stateBag map[string]interface{}) } } -func selectEndpoint(ctx *context, registry *routing.EndpointRegistry) *routing.LBEndpoint { +func fadeIn(now time.Time, duration time.Duration, exponent float64, detected time.Time) float64 { + rel := now.Sub(detected) + fadingIn := rel > 0 && rel < duration + if !fadingIn { + return 1 + } + + return math.Pow(float64(rel)/float64(duration), exponent) +} + +func (p *Proxy) filterFadeIn(endpoints []routing.LBEndpoint, rt *routing.Route) []routing.LBEndpoint { + now := time.Now() + threshold := p.rnd.Float64() + + filtered := make([]routing.LBEndpoint, 0, len(endpoints)) + for _, e := range endpoints { + detected := p.registry.GetMetrics(e.Host).DetectedTime() + f := fadeIn( + now, + rt.LBFadeInDuration, + rt.LBFadeInExponent, + detected, + ) + if threshold < f { + filtered = append(filtered, e) + } + } + return filtered +} + +func (p *Proxy) selectEndpoint(ctx *context) *routing.LBEndpoint { rt := ctx.route + endpoints := rt.LBEndpoints + endpoints = p.filterFadeIn(endpoints, rt) + if len(endpoints) == 0 { + endpoints = rt.LBEndpoints + } lbctx := &routing.LBContext{ Request: ctx.request, Route: rt, - LBEndpoints: rt.LBEndpoints, + LBEndpoints: endpoints, Params: ctx.StateBag(), - Registry: registry, + Registry: p.registry, } + e := rt.LBAlgorithm.Apply(lbctx) return &e @@ -485,7 +524,7 @@ func selectEndpoint(ctx *context, registry *routing.EndpointRegistry) *routing.L // creates an outgoing http request to be forwarded 
to the route endpoint // based on the augmented incoming request -func mapRequest(ctx *context, requestContext stdlibcontext.Context, removeHopHeaders bool, registry *routing.EndpointRegistry) (*http.Request, *routing.LBEndpoint, error) { +func (p *Proxy) mapRequest(ctx *context, requestContext stdlibcontext.Context) (*http.Request, *routing.LBEndpoint, error) { var endpoint *routing.LBEndpoint r := ctx.request rt := ctx.route @@ -498,7 +537,7 @@ func mapRequest(ctx *context, requestContext stdlibcontext.Context, removeHopHea setRequestURLFromRequest(u, r) setRequestURLForDynamicBackend(u, stateBag) case eskip.LBBackend: - endpoint = selectEndpoint(ctx, registry) + endpoint = p.selectEndpoint(ctx) u.Scheme = endpoint.Scheme u.Host = endpoint.Host default: @@ -517,7 +556,7 @@ func mapRequest(ctx *context, requestContext stdlibcontext.Context, removeHopHea } rr.ContentLength = r.ContentLength - if removeHopHeaders { + if p.flags.HopHeadersRemoval() { rr.Header = cloneHeaderExcluding(r.Header, hopHeaders) } else { rr.Header = cloneHeader(r.Header) @@ -742,6 +781,8 @@ func WithParams(p Params) *Proxy { clientTLS: tr.TLSClientConfig, hostname: hostname, onPanicSometimes: rate.Sometimes{First: 3, Interval: 1 * time.Minute}, + /* #nosec */ + rnd: rand.New(loadbalancer.NewLockedSource()), } } @@ -837,7 +878,7 @@ func (p *Proxy) makeUpgradeRequest(ctx *context, req *http.Request) { } func (p *Proxy) makeBackendRequest(ctx *context, requestContext stdlibcontext.Context) (*http.Response, *proxyError) { - req, endpoint, err := mapRequest(ctx, requestContext, p.flags.HopHeadersRemoval(), p.registry) + req, endpoint, err := p.mapRequest(ctx, requestContext) if err != nil { return nil, &proxyError{err: fmt.Errorf("could not map backend request: %w", err)} } @@ -1125,7 +1166,7 @@ func (p *Proxy) do(ctx *context) (err error) { ctx.setResponse(loopCTX.response, p.flags.PreserveOriginal()) ctx.proxySpan = loopCTX.proxySpan } else if p.flags.Debug() { - debugReq, _, err := 
mapRequest(ctx, ctx.request.Context(), p.flags.HopHeadersRemoval(), p.registry) + debugReq, _, err := p.mapRequest(ctx, ctx.request.Context()) if err != nil { perr := &proxyError{err: err} p.makeErrorResponse(ctx, perr) diff --git a/routing/routing.go b/routing/routing.go index c6eb6f75a2..8b8b5dc6a1 100644 --- a/routing/routing.go +++ b/routing/routing.go @@ -191,7 +191,7 @@ type LBAlgorithm interface { // on that data which endpoint to call from the backends type LBContext struct { Request *http.Request - Route *Route + Route *Route // Deprecated LBEndpoints []LBEndpoint Params map[string]interface{} Registry *EndpointRegistry diff --git a/skptesting/run_fadein_test.sh b/skptesting/run_fadein_test.sh index 77b7c521af..baa9f78769 100755 --- a/skptesting/run_fadein_test.sh +++ b/skptesting/run_fadein_test.sh @@ -1,7 +1,7 @@ #!/bin/bash function run_test() { - GO111MODULE=on go test ./loadbalancer -run="$1" -count=1 -v | awk '/fadein_test.go:[0-9]+: CSV/ {print $3}' + GO111MODULE=on go test ./proxy -run="$1" -count=1 -v | awk '/fadein_internal_test.go:[0-9]+: CSV/ {print $3}' } cwd=$( dirname "${BASH_SOURCE[0]}" ) @@ -10,7 +10,7 @@ if [ -z "${1+x}" ] then echo "$0 [...]" echo "Example:" - echo "$0 TestFadeIn/round-robin,_4 TestFadeIn/round-robin,_3" + echo "$0 TestFadeInMonotony/round-robin,_4 TestFadeInMonotony/round-robin,_3" else d=$(mktemp -d) for t in "$@"