Skip to content

Commit

Permalink
Decouple fadeIn from loadbalancer
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Oct 6, 2023
1 parent 40f4634 commit e0db556
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 555 deletions.
170 changes: 17 additions & 153 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ package loadbalancer
import (
"errors"
"fmt"
"math"
"math/rand"
"net"
"net/url"
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/cespare/xxhash/v2"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -56,110 +54,14 @@ var (
defaultAlgorithm = newRoundRobin
)

func fadeInState(now time.Time, duration time.Duration, detected time.Time) (time.Duration, bool) {
rel := now.Sub(detected)
return rel, rel > 0 && rel < duration
}

func fadeIn(now time.Time, duration time.Duration, exponent float64, detected time.Time) float64 {
rel, fadingIn := fadeInState(now, duration, detected)
if !fadingIn {
return 1
}

return math.Pow(float64(rel)/float64(duration), exponent)
}

func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routing.LBEndpoint {
var sum float64
rt := ctx.Route
ep := ctx.LBEndpoints
for _, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
wi := fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
sum += wi
}

choice := ep[len(ep)-1]
r := rnd.Float64() * sum
var upto float64
for i, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
upto += fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
if upto > r {
choice = ep[i]
break
}
}

return choice
}

func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time.Time) routing.LBEndpoint {
notFadingIndexes := wi
ep := ctx.LBEndpoints

// if all endpoints are fading, the simplest approach is to use the oldest,
// this departs from the desired curve, but guarantees monotonic fade-in. From
// the perspective of the oldest endpoint, this is temporarily the same as if
// there was no fade-in.
if len(notFadingIndexes) == 0 {
return shiftWeighted(rnd, ctx, now)
}

// otherwise equally distribute between the old endpoints
return ep[notFadingIndexes[rnd.Intn(len(notFadingIndexes))]]
}

func withFadeIn(rnd *rand.Rand, ctx *routing.LBContext, choice int, algo routing.LBAlgorithm) routing.LBEndpoint {
ep := ctx.LBEndpoints
now := time.Now()
detected := ctx.Registry.GetMetrics(ctx.LBEndpoints[choice].Host).DetectedTime()
f := fadeIn(
now,
ctx.Route.LBFadeInDuration,
ctx.Route.LBFadeInExponent,
detected,
)

if rnd.Float64() < f {
return ep[choice]
}
notFadingIndexes := make([]int, 0, len(ep))
for i := 0; i < len(ep); i++ {
detected := ctx.Registry.GetMetrics(ep[i].Host).DetectedTime()
if _, fadingIn := fadeInState(now, ctx.Route.LBFadeInDuration, detected); !fadingIn {
notFadingIndexes = append(notFadingIndexes, i)
}
}

switch a := algo.(type) {
case *roundRobin:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *random:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *consistentHash:
// If all endpoints are fading, normal consistent hash result
if len(notFadingIndexes) == 0 {
return ep[choice]
}
// otherwise calculate consistent hash again using endpoints which are not fading
return ep[a.chooseConsistentHashEndpoint(ctx, skipFadingEndpoints(notFadingIndexes))]
default:
return ep[choice]
}
}

type roundRobin struct {
index int64
rnd *rand.Rand
}

func newRoundRobin(endpoints []string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(NewLockedSource()) // #nosec
return &roundRobin{
index: int64(rnd.Intn(len(endpoints))),
rnd: rnd,
}
}

Expand All @@ -169,13 +71,8 @@ func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint {
return ctx.LBEndpoints[0]
}

index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints)))

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.LBEndpoints[index]
}

return withFadeIn(r.rnd, ctx, index, r)
choice := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints)))
return ctx.LBEndpoints[choice]
}

type random struct {
Expand All @@ -185,7 +82,7 @@ type random struct {
func newRandom(endpoints []string) routing.LBAlgorithm {
// #nosec
return &random{
rnd: rand.New(newLockedSource()),
rnd: rand.New(NewLockedSource()),
}
}

Expand All @@ -195,12 +92,8 @@ func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint {
return ctx.LBEndpoints[0]
}

i := r.rnd.Intn(len(ctx.LBEndpoints))
if ctx.Route.LBFadeInDuration <= 0 {
return ctx.LBEndpoints[i]
}

return withFadeIn(r.rnd, ctx, i, r)
choice := r.rnd.Intn(len(ctx.LBEndpoints))
return ctx.LBEndpoints[choice]
}

type (
Expand All @@ -210,7 +103,6 @@ type (
}
consistentHash struct {
hashRing []endpointHash // list of endpoints sorted by hash value
rnd *rand.Rand
}
)

Expand All @@ -221,10 +113,8 @@ func (ch *consistentHash) Swap(i, j int) {
}

func newConsistentHashInternal(endpoints []string, hashesPerEndpoint int) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
ch := &consistentHash{
hashRing: make([]endpointHash, hashesPerEndpoint*len(endpoints)),
rnd: rnd,
}
for i, ep := range endpoints {
endpointStartIndex := hashesPerEndpoint * i
Expand All @@ -245,21 +135,18 @@ func hash(s string) uint64 {
}

// Returns index in hash ring with the closest hash to key's hash
func (ch *consistentHash) searchRing(key string, skipEndpoint func(int) bool) int {
func (ch *consistentHash) searchRing(key string) int {
h := hash(key)
i := sort.Search(ch.Len(), func(i int) bool { return ch.hashRing[i].hash >= h })
if i == ch.Len() { // rollover
i = 0
}
for skipEndpoint(ch.hashRing[i].index) {
i = (i + 1) % ch.Len()
}
return i
}

// Returns index of endpoint with closest hash to key's hash
func (ch *consistentHash) search(key string, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) search(key string) int {
ringIndex := ch.searchRing(key)
return ch.hashRing[ringIndex].index
}

Expand All @@ -274,16 +161,13 @@ func computeLoadAverage(ctx *routing.LBContext) float64 {

// Returns index of endpoint with closest hash to key's hash, which is also below the target load
// skipEndpoint function is used to skip endpoints we don't want, such as fading endpoints
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key)
averageLoad := computeLoadAverage(ctx)
targetLoad := averageLoad * balanceFactor
// Loop round ring, starting at endpoint with closest hash. Stop when we find one whose load is less than targetLoad.
for i := 0; i < ch.Len(); i++ {
endpointIndex := ch.hashRing[ringIndex].index
if skipEndpoint(endpointIndex) {
continue
}
load := ctx.Registry.GetMetrics(ctx.LBEndpoints[endpointIndex].Host).InflightRequests()
// We know there must be an endpoint whose load <= average load.
// Since targetLoad >= average load (balancerFactor >= 1), there must also be an endpoint with load <= targetLoad.
Expand All @@ -302,46 +186,26 @@ func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint {
return ctx.LBEndpoints[0]
}

choice := ch.chooseConsistentHashEndpoint(ctx, noSkippedEndpoints)

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.LBEndpoints[choice]
}

return withFadeIn(ch.rnd, ctx, choice, ch)
choice := ch.chooseConsistentHashEndpoint(ctx)
return ctx.LBEndpoints[choice]
}

func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext, skipEndpoint func(int) bool) int {
func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext) int {
key, ok := ctx.Params[ConsistentHashKey].(string)
if !ok {
key = snet.RemoteHost(ctx.Request).String()
}
balanceFactor, ok := ctx.Params[ConsistentHashBalanceFactor].(float64)
var choice int
if !ok {
choice = ch.search(key, skipEndpoint)
choice = ch.search(key)
} else {
choice = ch.boundedLoadSearch(key, balanceFactor, ctx, skipEndpoint)
choice = ch.boundedLoadSearch(key, balanceFactor, ctx)
}

return choice
}

func skipFadingEndpoints(notFadingEndpoints []int) func(int) bool {
return func(i int) bool {
for _, notFadingEndpoint := range notFadingEndpoints {
if i == notFadingEndpoint {
return false
}
}
return true
}
}

func noSkippedEndpoints(_ int) bool {
return false
}

type powerOfRandomNChoices struct {
mx sync.Mutex
rnd *rand.Rand
Expand All @@ -350,7 +214,7 @@ type powerOfRandomNChoices struct {

// newPowerOfRandomNChoices selects N random backends and picks the one with less outstanding requests.
func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(NewLockedSource()) // #nosec
return &powerOfRandomNChoices{
rnd: rnd,
numberOfChoices: powerOfRandomNChoicesDefaultN,
Expand Down
2 changes: 1 addition & 1 deletion loadbalancer/algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func TestApply(t *testing.T) {
func TestConsistentHashSearch(t *testing.T) {
apply := func(key string, endpoints []string) string {
ch := newConsistentHash(endpoints).(*consistentHash)
return endpoints[ch.search(key, noSkippedEndpoints)]
return endpoints[ch.search(key)]
}

endpoints := []string{"http://127.0.0.1:8080", "http://127.0.0.2:8080", "http://127.0.0.3:8080"}
Expand Down
Loading

0 comments on commit e0db556

Please sign in to comment.