Skip to content

Commit

Permalink
Decouple fadeIn from loadbalancer
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Sep 28, 2023
1 parent c823514 commit ca7934c
Show file tree
Hide file tree
Showing 10 changed files with 377 additions and 575 deletions.
189 changes: 29 additions & 160 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ package loadbalancer
import (
"errors"
"fmt"
"math"
"math/rand"
"net"
"net/url"
"sort"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/cespare/xxhash/v2"
log "github.com/sirupsen/logrus"
Expand Down Expand Up @@ -56,107 +54,13 @@ var (
defaultAlgorithm = newRoundRobin
)

func fadeInState(now time.Time, duration time.Duration, detected time.Time) (time.Duration, bool) {
rel := now.Sub(detected)
return rel, rel > 0 && rel < duration
}

func fadeIn(now time.Time, duration time.Duration, exponent float64, detected time.Time) float64 {
rel, fadingIn := fadeInState(now, duration, detected)
if !fadingIn {
return 1
}

return math.Pow(float64(rel)/float64(duration), exponent)
}

func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routing.LBEndpoint {
var sum float64
rt := ctx.Route
ep := ctx.Route.LBEndpoints
for _, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
wi := fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
sum += wi
}

choice := ep[len(ep)-1]
r := rnd.Float64() * sum
var upto float64
for i, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
upto += fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
if upto > r {
choice = ep[i]
break
}
}

return choice
}

func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time.Time) routing.LBEndpoint {
notFadingIndexes := wi
ep := ctx.Route.LBEndpoints

// if all endpoints are fading, the simplest approach is to use the oldest,
// this departs from the desired curve, but guarantees monotonic fade-in. From
// the perspective of the oldest endpoint, this is temporarily the same as if
// there was no fade-in.
if len(notFadingIndexes) == 0 {
return shiftWeighted(rnd, ctx, now)
}

// otherwise equally distribute between the old endpoints
return ep[notFadingIndexes[rnd.Intn(len(notFadingIndexes))]]
}

func withFadeIn(rnd *rand.Rand, ctx *routing.LBContext, choice int, algo routing.LBAlgorithm) routing.LBEndpoint {
ep := ctx.Route.LBEndpoints
now := time.Now()
detected := ctx.Registry.GetMetrics(ctx.Route.LBEndpoints[choice].Host).DetectedTime()
f := fadeIn(
now,
ctx.Route.LBFadeInDuration,
ctx.Route.LBFadeInExponent,
detected,
)

if rnd.Float64() < f {
return ep[choice]
}
notFadingIndexes := make([]int, 0, len(ep))
for i := 0; i < len(ep); i++ {
detected := ctx.Registry.GetMetrics(ep[i].Host).DetectedTime()
if _, fadingIn := fadeInState(now, ctx.Route.LBFadeInDuration, detected); !fadingIn {
notFadingIndexes = append(notFadingIndexes, i)
}
}

switch a := algo.(type) {
case *roundRobin:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *random:
return shiftToRemaining(a.rnd, ctx, notFadingIndexes, now)
case *consistentHash:
// If all endpoints are fading, normal consistent hash result
if len(notFadingIndexes) == 0 {
return ep[choice]
}
// otherwise calculate consistent hash again using endpoints which are not fading
return ep[a.chooseConsistentHashEndpoint(ctx, skipFadingEndpoints(notFadingIndexes))]
default:
return ep[choice]
}
}

type roundRobin struct {
index int64
rnd *rand.Rand
}

func newRoundRobin(endpoints []string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(routing.NewLockedSource()) // #nosec
return &roundRobin{
index: int64(rnd.Intn(len(endpoints))),
rnd: rnd,
Expand All @@ -165,17 +69,12 @@ func newRoundRobin(endpoints []string) routing.LBAlgorithm {

// Apply implements routing.LBAlgorithm with a roundrobin algorithm.
func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.Endpoints) == 1 {
return ctx.Endpoints[0]
}

index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.Route.LBEndpoints)))

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[index]
}

return withFadeIn(r.rnd, ctx, index, r)
choice := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.Endpoints)))
return ctx.Endpoints[choice]
}

type random struct {
Expand All @@ -185,22 +84,18 @@ type random struct {
func newRandom(endpoints []string) routing.LBAlgorithm {
// #nosec
return &random{
rnd: rand.New(newLockedSource()),
rnd: rand.New(routing.NewLockedSource()),
}
}

// Apply implements routing.LBAlgorithm with a stateless random algorithm.
func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.Endpoints) == 1 {
return ctx.Endpoints[0]
}

i := r.rnd.Intn(len(ctx.Route.LBEndpoints))
if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[i]
}

return withFadeIn(r.rnd, ctx, i, r)
choice := r.rnd.Intn(len(ctx.Endpoints))
return ctx.Endpoints[choice]
}

type (
Expand All @@ -221,7 +116,7 @@ func (ch *consistentHash) Swap(i, j int) {
}

func newConsistentHashInternal(endpoints []string, hashesPerEndpoint int) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(routing.NewLockedSource()) // #nosec
ch := &consistentHash{
hashRing: make([]endpointHash, hashesPerEndpoint*len(endpoints)),
rnd: rnd,
Expand All @@ -245,27 +140,24 @@ func hash(s string) uint64 {
}

// Returns index in hash ring with the closest hash to key's hash
func (ch *consistentHash) searchRing(key string, skipEndpoint func(int) bool) int {
func (ch *consistentHash) searchRing(key string) int {
h := hash(key)
i := sort.Search(ch.Len(), func(i int) bool { return ch.hashRing[i].hash >= h })
if i == ch.Len() { // rollover
i = 0
}
for skipEndpoint(ch.hashRing[i].index) {
i = (i + 1) % ch.Len()
}
return i
}

// Returns index of endpoint with closest hash to key's hash
func (ch *consistentHash) search(key string, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) search(key string) int {
ringIndex := ch.searchRing(key)
return ch.hashRing[ringIndex].index
}

func computeLoadAverage(ctx *routing.LBContext) float64 {
sum := 1.0 // add 1 to include the request that just arrived
endpoints := ctx.Route.LBEndpoints
endpoints := ctx.Endpoints
for _, v := range endpoints {
sum += float64(v.Metrics.GetInflightRequests())
}
Expand All @@ -274,17 +166,14 @@ func computeLoadAverage(ctx *routing.LBContext) float64 {

// Returns index of endpoint with closest hash to key's hash, which is also below the target load
// skipEndpoint function is used to skip endpoints we don't want, such as fading endpoints
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext, skipEndpoint func(int) bool) int {
ringIndex := ch.searchRing(key, skipEndpoint)
func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, ctx *routing.LBContext) int {
ringIndex := ch.searchRing(key)
averageLoad := computeLoadAverage(ctx)
targetLoad := averageLoad * balanceFactor
// Loop round ring, starting at endpoint with closest hash. Stop when we find one whose load is less than targetLoad.
for i := 0; i < ch.Len(); i++ {
endpointIndex := ch.hashRing[ringIndex].index
if skipEndpoint(endpointIndex) {
continue
}
load := ctx.Route.LBEndpoints[endpointIndex].Metrics.GetInflightRequests()
load := ctx.Endpoints[endpointIndex].Metrics.GetInflightRequests()
// We know there must be an endpoint whose load <= average load.
// Since targetLoad >= average load (balancerFactor >= 1), there must also be an endpoint with load <= targetLoad.
if load <= int(targetLoad) {
Expand All @@ -298,50 +187,30 @@ func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, c

// Apply implements routing.LBAlgorithm with a consistent hash algorithm.
func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
}

choice := ch.chooseConsistentHashEndpoint(ctx, noSkippedEndpoints)

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[choice]
if len(ctx.Endpoints) == 1 {
return ctx.Endpoints[0]
}

return withFadeIn(ch.rnd, ctx, choice, ch)
choice := ch.chooseConsistentHashEndpoint(ctx)
return ctx.Endpoints[choice]
}

func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext, skipEndpoint func(int) bool) int {
func (ch *consistentHash) chooseConsistentHashEndpoint(ctx *routing.LBContext) int {
key, ok := ctx.Params[ConsistentHashKey].(string)
if !ok {
key = snet.RemoteHost(ctx.Request).String()
}
balanceFactor, ok := ctx.Params[ConsistentHashBalanceFactor].(float64)
var choice int
if !ok {
choice = ch.search(key, skipEndpoint)
choice = ch.search(key)
} else {
choice = ch.boundedLoadSearch(key, balanceFactor, ctx, skipEndpoint)
choice = ch.boundedLoadSearch(key, balanceFactor, ctx)
}

return choice
}

func skipFadingEndpoints(notFadingEndpoints []int) func(int) bool {
return func(i int) bool {
for _, notFadingEndpoint := range notFadingEndpoints {
if i == notFadingEndpoint {
return false
}
}
return true
}
}

func noSkippedEndpoints(_ int) bool {
return false
}

type powerOfRandomNChoices struct {
mx sync.Mutex
rnd *rand.Rand
Expand All @@ -350,7 +219,7 @@ type powerOfRandomNChoices struct {

// newPowerOfRandomNChoices selects N random backends and picks the one with less outstanding requests.
func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {
rnd := rand.New(newLockedSource()) // #nosec
rnd := rand.New(routing.NewLockedSource()) // #nosec
return &powerOfRandomNChoices{
rnd: rnd,
numberOfChoices: powerOfRandomNChoicesDefaultN,
Expand All @@ -359,15 +228,15 @@ func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {

// Apply implements routing.LBAlgorithm with power of random N choices algorithm.
func (p *powerOfRandomNChoices) Apply(ctx *routing.LBContext) routing.LBEndpoint {
ne := len(ctx.Route.LBEndpoints)
ne := len(ctx.Endpoints)

p.mx.Lock()
defer p.mx.Unlock()

best := ctx.Route.LBEndpoints[p.rnd.Intn(ne)]
best := ctx.Endpoints[p.rnd.Intn(ne)]

for i := 1; i < p.numberOfChoices; i++ {
ce := ctx.Route.LBEndpoints[p.rnd.Intn(ne)]
ce := ctx.Endpoints[p.rnd.Intn(ne)]

if p.getScore(ce) > p.getScore(best) {
best = ce
Expand Down
32 changes: 16 additions & 16 deletions loadbalancer/algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ func TestApply(t *testing.T) {
rt := p.Do([]*routing.Route{r})

lbctx := &routing.LBContext{
Request: req,
Route: rt[0],
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Request: req,
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Endpoints: rt[0].LBEndpoints,
}

h := make(map[string]int)
Expand All @@ -279,7 +279,7 @@ func TestApply(t *testing.T) {
func TestConsistentHashSearch(t *testing.T) {
apply := func(key string, endpoints []string) string {
ch := newConsistentHash(endpoints).(*consistentHash)
return endpoints[ch.search(key, noSkippedEndpoints)]
return endpoints[ch.search(key)]
}

endpoints := []string{"http://127.0.0.1:8080", "http://127.0.0.2:8080", "http://127.0.0.3:8080"}
Expand Down Expand Up @@ -316,13 +316,13 @@ func TestConsistentHashBoundedLoadSearch(t *testing.T) {
}})[0]
ch := route.LBAlgorithm.(*consistentHash)
ctx := &routing.LBContext{
Request: r,
Route: route,
Params: map[string]interface{}{ConsistentHashBalanceFactor: 1.25},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Request: r,
Params: map[string]interface{}{ConsistentHashBalanceFactor: 1.25},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Endpoints: route.LBEndpoints,
}
noLoad := ch.Apply(ctx)
nonBounded := ch.Apply(&routing.LBContext{Request: r, Route: route, Params: map[string]interface{}{}})
nonBounded := ch.Apply(&routing.LBContext{Request: r, Params: map[string]interface{}{}, Endpoints: route.LBEndpoints})

if noLoad != nonBounded {
t.Error("When no endpoints are overloaded, the chosen endpoint should be the same as standard consistentHash")
Expand Down Expand Up @@ -364,16 +364,16 @@ func TestConsistentHashKey(t *testing.T) {
},
}})[0]

defaultEndpoint := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: make(map[string]interface{})})
remoteHostEndpoint := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: map[string]interface{}{ConsistentHashKey: net.RemoteHost(r).String()}})
defaultEndpoint := ch.Apply(&routing.LBContext{Request: r, Params: make(map[string]interface{}), Endpoints: rt.LBEndpoints})
remoteHostEndpoint := ch.Apply(&routing.LBContext{Request: r, Params: map[string]interface{}{ConsistentHashKey: net.RemoteHost(r).String()}, Endpoints: rt.LBEndpoints})

if defaultEndpoint != remoteHostEndpoint {
t.Error("remote host should be used as a default key")
}

for i, ep := range endpoints {
key := fmt.Sprintf("%s-%d", ep, 1) // "ep-0" to "ep-99" is the range of keys for this endpoint. If we use this as the hash key it should select endpoint ep.
selected := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: map[string]interface{}{ConsistentHashKey: key}})
selected := ch.Apply(&routing.LBContext{Request: r, Params: map[string]interface{}{ConsistentHashKey: key}, Endpoints: rt.LBEndpoints})
if selected != rt.LBEndpoints[i] {
t.Errorf("expected: %v, got %v", rt.LBEndpoints[i], selected)
}
Expand All @@ -393,10 +393,10 @@ func TestConsistentHashBoundedLoadDistribution(t *testing.T) {
ch := route.LBAlgorithm.(*consistentHash)
balanceFactor := 1.25
ctx := &routing.LBContext{
Request: r,
Route: route,
Params: map[string]interface{}{ConsistentHashBalanceFactor: balanceFactor},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Request: r,
Params: map[string]interface{}{ConsistentHashBalanceFactor: balanceFactor},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Endpoints: route.LBEndpoints,
}

for i := 0; i < 100; i++ {
Expand Down
Loading

0 comments on commit ca7934c

Please sign in to comment.