Skip to content

Commit

Permalink
Decouple fadeIn from loadbalancer
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Jan 17, 2024
1 parent ce65352 commit b975b7d
Show file tree
Hide file tree
Showing 9 changed files with 504 additions and 605 deletions.
164 changes: 19 additions & 145 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package loadbalancer
import (
"errors"
"fmt"
"math"
"math/rand"
"sort"
"sync"
"sync/atomic"
"time"

"github.com/cespare/xxhash/v2"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -53,106 +51,14 @@ var (
defaultAlgorithm = newRoundRobin
)

func fadeInState(now time.Time, duration time.Duration, detected time.Time) (time.Duration, bool) {
rel := now.Sub(detected)
return rel, rel > 0 && rel < duration
}

func fadeIn(now time.Time, duration time.Duration, exponent float64, detected time.Time) float64 {
rel, fadingIn := fadeInState(now, duration, detected)
if !fadingIn {
return 1
}

return math.Pow(float64(rel)/float64(duration), exponent)
}

func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routing.LBEndpoint {
var sum float64
rt := ctx.Route
ep := ctx.LBEndpoints
for _, epi := range ep {
wi := fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, epi.Metrics.DetectedTime())
sum += wi
}

choice := ep[len(ep)-1]
r := rnd.Float64() * sum
var upto float64
for i, epi := range ep {
upto += fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, epi.Metrics.DetectedTime())
if upto > r {
choice = ep[i]
break
}
}

return choice
}

func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time.Time) routing.LBEndpoint {
notFadingIndexes := wi
ep := ctx.LBEndpoints

// if all endpoints are fading, the simplest approach is to use the oldest,
// this departs from the desired curve, but guarantees monotonic fade-in. From
// the perspective of the oldest endpoint, this is temporarily the same as if
// there was no fade-in.
if len(notFadingIndexes) == 0 {
return shiftWeighted(rnd, ctx, now)
}

// otherwise equally distribute between the old endpoints
return ep[notFadingIndexes[rnd.Intn(len(notFadingIndexes))]]
}

func withFadeIn(rnd *rand.Rand, ctx *routing.LBContext, choice int, algo routing.LBAlgorithm) routing.LBEndpoint {
ep := ctx.LBEndpoints
now := time.Now()
f := fadeIn(
now,
ctx.Route.LBFadeInDuration,
ctx.Route.LBFadeInExponent,
ctx.LBEndpoints[choice].Metrics.DetectedTime(),
)

if rnd.Float64() < f {
return ep[choice]
}
notFadingIndexes := make([]int, 0, len(ep))
for i := 0; i < len(ep); i++ {
if _, fadingIn := fadeInState(now, ctx.Route.LBFadeInDuration, ep[i].Metrics.DetectedTime()); !fadingIn {
notFadingIndexes = append(notFadingIndexes, i)
}
}

switch a := algo.(type) {
case *roundRobin:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *random:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *consistentHash:
// If all endpoints are fading, normal consistent hash result
if len(notFadingIndexes) == 0 {
return ep[choice]
}
// otherwise calculate consistent hash again using endpoints which are not fading
return ep[a.chooseConsistentHashEndpoint(ctx, skipFadingEndpoints(notFadingIndexes))]
default:
return ep[choice]
}
}

type roundRobin struct {
index int64
rnd *rand.Rand
}

func newRoundRobin(endpoints []string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(NewLockedSource()) // #nosec
return &roundRobin{
index: int64(rnd.Intn(len(endpoints))),
rnd: rnd,
}
}

Expand All @@ -162,13 +68,8 @@ func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint {
return ctx.LBEndpoints[0]
}

index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints)))

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.LBEndpoints[index]
}

return withFadeIn(r.rnd, ctx, index, r)
choice := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints)))
return ctx.LBEndpoints[choice]
}

type random struct {
Expand All @@ -178,7 +79,7 @@ type random struct {
func newRandom(endpoints []string) routing.LBAlgorithm {
// #nosec
return &random{
rnd: rand.New(newLockedSource()),
rnd: rand.New(NewLockedSource()),
}
}

Expand All @@ -188,12 +89,8 @@ func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint {
return ctx.LBEndpoints[0]
}

i := r.rnd.Intn(len(ctx.LBEndpoints))
if ctx.Route.LBFadeInDuration <= 0 {
return ctx.LBEndpoints[i]
}

return withFadeIn(r.rnd, ctx, i, r)
choice := r.rnd.Intn(len(ctx.LBEndpoints))
return ctx.LBEndpoints[choice]
}

type (
Expand All @@ -203,7 +100,6 @@ type (
}
consistentHash struct {
hashRing []endpointHash // list of endpoints sorted by hash value
rnd *rand.Rand
}
)

Expand All @@ -214,10 +110,8 @@ func (ch *consistentHash) Swap(i, j int) {
}

func newConsistentHashInternal(endpoints []string, hashesPerEndpoint int) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
ch := &consistentHash{
hashRing: make([]endpointHash, hashesPerEndpoint*len(endpoints)),
rnd: rnd,
}
for i, ep := range endpoints {
endpointStartIndex := hashesPerEndpoint * i
Expand All @@ -238,21 +132,21 @@ func hash(s string) uint64 {
}

// Returns index in hash ring with the closest hash to key's hash
func (ch *consistentHash) searchRing(key string, skipEndpoint func(int) bool) int {
func (ch *consistentHash) searchRing(key string, ctx *routing.LBContext) int {
h := hash(key)
i := sort.Search(ch.Len(), func(i int) bool { return ch.hashRing[i].hash >= h })
if i == ch.Len() { // rollover
i = 0
}
for skipEndpoint(ch.hashRing[i].index) {
for ctx.SkipEndpoint(ch.hashRing[i].index) {
i = (i + 1) % ch.Len()
}
return i
}

// Returns index of endpoint with closest hash to key's hash
func (ch *consistentHash) search(key string, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) search(key string, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key, ctx)
return ch.hashRing[ringIndex].index
}

Expand All @@ -267,14 +161,14 @@ func computeLoadAverage(ctx *routing.LBContext) float64 {

// Returns index of endpoint with closest hash to key's hash, which is also below the target load
// skipEndpoint function is used to skip endpoints we don't want, such as fading endpoints
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key, ctx)
averageLoad := computeLoadAverage(ctx)
targetLoad := averageLoad * balanceFactor
// Loop round ring, starting at endpoint with closest hash. Stop when we find one whose load is less than targetLoad.
for i := 0; i < ch.Len(); i++ {
endpointIndex := ch.hashRing[ringIndex].index
if skipEndpoint(endpointIndex) {
if ctx.SkipEndpoint(endpointIndex) {
continue
}
load := ctx.LBEndpoints[endpointIndex].Metrics.InflightRequests()
Expand All @@ -295,46 +189,26 @@ func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint {
return ctx.LBEndpoints[0]
}

choice := ch.chooseConsistentHashEndpoint(ctx, noSkippedEndpoints)

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.LBEndpoints[choice]
}

return withFadeIn(ch.rnd, ctx, choice, ch)
choice := ch.chooseConsistentHashEndpoint(ctx)
return ctx.LBEndpoints[choice]
}

func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext, skipEndpoint func(int) bool) int {
func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext) int {
key, ok := ctx.Params[ConsistentHashKey].(string)
if !ok {
key = snet.RemoteHost(ctx.Request).String()
}
balanceFactor, ok := ctx.Params[ConsistentHashBalanceFactor].(float64)
var choice int
if !ok {
choice = ch.search(key, skipEndpoint)
choice = ch.search(key, ctx)
} else {
choice = ch.boundedLoadSearch(key, balanceFactor, ctx, skipEndpoint)
choice = ch.boundedLoadSearch(key, balanceFactor, ctx)
}

return choice
}

func skipFadingEndpoints(notFadingEndpoints []int) func(int) bool {
return func(i int) bool {
for _, notFadingEndpoint := range notFadingEndpoints {
if i == notFadingEndpoint {
return false
}
}
return true
}
}

func noSkippedEndpoints(_ int) bool {
return false
}

type powerOfRandomNChoices struct {
mu sync.Mutex
rnd *rand.Rand
Expand All @@ -343,7 +217,7 @@ type powerOfRandomNChoices struct {

// newPowerOfRandomNChoices selects N random backends and picks the one with less outstanding requests.
func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(NewLockedSource()) // #nosec
return &powerOfRandomNChoices{
rnd: rnd,
numberOfChoices: powerOfRandomNChoicesDefaultN,
Expand Down
11 changes: 10 additions & 1 deletion loadbalancer/algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,16 @@ func TestApply(t *testing.T) {
func TestConsistentHashSearch(t *testing.T) {
apply := func(key string, endpoints []string) string {
ch := newConsistentHash(endpoints).(*consistentHash)
return endpoints[ch.search(key, noSkippedEndpoints)]
ctx := &routing.LBContext{Route: &routing.Route{LBEndpoints: make([]routing.LBEndpoint, 0, len(endpoints))}}
for _, ep := range endpoints {
scheme, host, err := net.SchemeHost(ep)
if err != nil {
t.Fatal(err)
}
ctx.LBEndpoints = append(ctx.LBEndpoints, routing.LBEndpoint{Host: host, Scheme: scheme})
ctx.Route.LBEndpoints = append(ctx.Route.LBEndpoints, routing.LBEndpoint{Host: host, Scheme: scheme})
}
return endpoints[ch.search(key, ctx)]
}

endpoints := []string{"http://127.0.0.1:8080", "http://127.0.0.2:8080", "http://127.0.0.3:8080"}
Expand Down
Loading

0 comments on commit b975b7d

Please sign in to comment.