Skip to content

Commit

Permalink
Decouple fadeIn from loadbalancer
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Sep 28, 2023
1 parent c823514 commit d51d217
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 579 deletions.
192 changes: 28 additions & 164 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ package loadbalancer
import (
"errors"
"fmt"
"math"
"math/rand"
"net"
"net/url"
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/cespare/xxhash/v2"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -56,126 +54,25 @@ var (
defaultAlgorithm = newRoundRobin
)

func fadeInState(now time.Time, duration time.Duration, detected time.Time) (time.Duration, bool) {
rel := now.Sub(detected)
return rel, rel > 0 && rel < duration
}

func fadeIn(now time.Time, duration time.Duration, exponent float64, detected time.Time) float64 {
rel, fadingIn := fadeInState(now, duration, detected)
if !fadingIn {
return 1
}

return math.Pow(float64(rel)/float64(duration), exponent)
}

func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routing.LBEndpoint {
var sum float64
rt := ctx.Route
ep := ctx.Route.LBEndpoints
for _, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
wi := fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
sum += wi
}

choice := ep[len(ep)-1]
r := rnd.Float64() * sum
var upto float64
for i, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
upto += fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
if upto > r {
choice = ep[i]
break
}
}

return choice
}

func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time.Time) routing.LBEndpoint {
notFadingIndexes := wi
ep := ctx.Route.LBEndpoints

// if all endpoints are fading, the simplest approach is to use the oldest,
// this departs from the desired curve, but guarantees monotonic fade-in. From
// the perspective of the oldest endpoint, this is temporarily the same as if
// there was no fade-in.
if len(notFadingIndexes) == 0 {
return shiftWeighted(rnd, ctx, now)
}

// otherwise equally distribute between the old endpoints
return ep[notFadingIndexes[rnd.Intn(len(notFadingIndexes))]]
}

func withFadeIn(rnd *rand.Rand, ctx *routing.LBContext, choice int, algo routing.LBAlgorithm) routing.LBEndpoint {
ep := ctx.Route.LBEndpoints
now := time.Now()
detected := ctx.Registry.GetMetrics(ctx.Route.LBEndpoints[choice].Host).DetectedTime()
f := fadeIn(
now,
ctx.Route.LBFadeInDuration,
ctx.Route.LBFadeInExponent,
detected,
)

if rnd.Float64() < f {
return ep[choice]
}
notFadingIndexes := make([]int, 0, len(ep))
for i := 0; i < len(ep); i++ {
detected := ctx.Registry.GetMetrics(ep[i].Host).DetectedTime()
if _, fadingIn := fadeInState(now, ctx.Route.LBFadeInDuration, detected); !fadingIn {
notFadingIndexes = append(notFadingIndexes, i)
}
}

switch a := algo.(type) {
case *roundRobin:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *random:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *consistentHash:
// If all endpoints are fading, normal consistent hash result
if len(notFadingIndexes) == 0 {
return ep[choice]
}
// otherwise calculate consistent hash again using endpoints which are not fading
return ep[a.chooseConsistentHashEndpoint(ctx, skipFadingEndpoints(notFadingIndexes))]
default:
return ep[choice]
}
}

type roundRobin struct {
index int64
rnd *rand.Rand
}

func newRoundRobin(endpoints []string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(routing.NewLockedSource()) // #nosec
return &roundRobin{
index: int64(rnd.Intn(len(endpoints))),
rnd: rnd,
}
}

// Apply implements routing.LBAlgorithm with a roundrobin algorithm.
func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.Endpoints) == 1 {
return ctx.Endpoints[0]
}

index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.Route.LBEndpoints)))

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[index]
}

return withFadeIn(r.rnd, ctx, index, r)
choice := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.Endpoints)))
return ctx.Endpoints[choice]
}

type random struct {
Expand All @@ -185,22 +82,18 @@ type random struct {
func newRandom(endpoints []string) routing.LBAlgorithm {
// #nosec
return &random{
rnd: rand.New(newLockedSource()),
rnd: rand.New(routing.NewLockedSource()),
}
}

// Apply implements routing.LBAlgorithm with a stateless random algorithm.
func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.Endpoints) == 1 {
return ctx.Endpoints[0]
}

i := r.rnd.Intn(len(ctx.Route.LBEndpoints))
if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[i]
}

return withFadeIn(r.rnd, ctx, i, r)
choice := r.rnd.Intn(len(ctx.Endpoints))
return ctx.Endpoints[choice]
}

type (
Expand All @@ -210,7 +103,6 @@ type (
}
consistentHash struct {
hashRing []endpointHash // list of endpoints sorted by hash value
rnd *rand.Rand
}
)

Expand All @@ -221,10 +113,8 @@ func (ch *consistentHash) Swap(i, j int) {
}

func newConsistentHashInternal(endpoints []string, hashesPerEndpoint int) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
ch := &consistentHash{
hashRing: make([]endpointHash, hashesPerEndpoint*len(endpoints)),
rnd: rnd,
}
for i, ep := range endpoints {
endpointStartIndex := hashesPerEndpoint * i
Expand All @@ -245,27 +135,24 @@ func hash(s string) uint64 {
}

// Returns index in hash ring with the closest hash to key's hash
func (ch *consistentHash) searchRing(key string, skipEndpoint func(int) bool) int {
func (ch *consistentHash) searchRing(key string) int {
h := hash(key)
i := sort.Search(ch.Len(), func(i int) bool { return ch.hashRing[i].hash >= h })
if i == ch.Len() { // rollover
i = 0
}
for skipEndpoint(ch.hashRing[i].index) {
i = (i + 1) % ch.Len()
}
return i
}

// Returns index of endpoint with closest hash to key's hash
func (ch *consistentHash) search(key string, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) search(key string) int {
ringIndex := ch.searchRing(key)
return ch.hashRing[ringIndex].index
}

func computeLoadAverage(ctx *routing.LBContext) float64 {
sum := 1.0 // add 1 to include the request that just arrived
endpoints := ctx.Route.LBEndpoints
endpoints := ctx.Endpoints
for _, v := range endpoints {
sum += float64(v.Metrics.GetInflightRequests())
}
Expand All @@ -274,17 +161,14 @@ func computeLoadAverage(ctx *routing.LBContext) float64 {

// Returns index of endpoint with closest hash to key's hash, which is also below the target load
// skipEndpoint function is used to skip endpoints we don't want, such as fading endpoints
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key)
averageLoad := computeLoadAverage(ctx)
targetLoad := averageLoad * balanceFactor
// Loop round ring, starting at endpoint with closest hash. Stop when we find one whose load is less than targetLoad.
for i := 0; i < ch.Len(); i++ {
endpointIndex := ch.hashRing[ringIndex].index
if skipEndpoint(endpointIndex) {
continue
}
load := ctx.Route.LBEndpoints[endpointIndex].Metrics.GetInflightRequests()
load := ctx.Endpoints[endpointIndex].Metrics.GetInflightRequests()
// We know there must be an endpoint whose load <= average load.
// Since targetLoad >= average load (balancerFactor >= 1), there must also be an endpoint with load <= targetLoad.
if load <= int(targetLoad) {
Expand All @@ -298,50 +182,30 @@ func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, c

// Apply implements routing.LBAlgorithm with a consistent hash algorithm.
func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
}

choice := ch.chooseConsistentHashEndpoint(ctx, noSkippedEndpoints)

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[choice]
if len(ctx.Endpoints) == 1 {
return ctx.Endpoints[0]
}

return withFadeIn(ch.rnd, ctx, choice, ch)
choice := ch.chooseConsistentHashEndpoint(ctx)
return ctx.Endpoints[choice]
}

func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext, skipEndpoint func(int) bool) int {
func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext) int {
key, ok := ctx.Params[ConsistentHashKey].(string)
if !ok {
key = snet.RemoteHost(ctx.Request).String()
}
balanceFactor, ok := ctx.Params[ConsistentHashBalanceFactor].(float64)
var choice int
if !ok {
choice = ch.search(key, skipEndpoint)
choice = ch.search(key)
} else {
choice = ch.boundedLoadSearch(key, balanceFactor, ctx, skipEndpoint)
choice = ch.boundedLoadSearch(key, balanceFactor, ctx)
}

return choice
}

func skipFadingEndpoints(notFadingEndpoints []int) func(int) bool {
return func(i int) bool {
for _, notFadingEndpoint := range notFadingEndpoints {
if i == notFadingEndpoint {
return false
}
}
return true
}
}

func noSkippedEndpoints(_ int) bool {
return false
}

type powerOfRandomNChoices struct {
mx sync.Mutex
rnd *rand.Rand
Expand All @@ -350,7 +214,7 @@ type powerOfRandomNChoices struct {

// newPowerOfRandomNChoices selects N random backends and picks the one with less outstanding requests.
func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(routing.NewLockedSource()) // #nosec
return &powerOfRandomNChoices{
rnd: rnd,
numberOfChoices: powerOfRandomNChoicesDefaultN,
Expand All @@ -359,15 +223,15 @@ func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {

// Apply implements routing.LBAlgorithm with power of random N choices algorithm.
func (p *powerOfRandomNChoices) Apply(ctx *routing.LBContext) routing.LBEndpoint {
ne := len(ctx.Route.LBEndpoints)
ne := len(ctx.Endpoints)

p.mx.Lock()
defer p.mx.Unlock()

best := ctx.Route.LBEndpoints[p.rnd.Intn(ne)]
best := ctx.Endpoints[p.rnd.Intn(ne)]

for i := 1; i < p.numberOfChoices; i++ {
ce := ctx.Route.LBEndpoints[p.rnd.Intn(ne)]
ce := ctx.Endpoints[p.rnd.Intn(ne)]

if p.getScore(ce) > p.getScore(best) {
best = ce
Expand Down
Loading

0 comments on commit d51d217

Please sign in to comment.