Skip to content

Commit

Permalink
Added LBEndpoint field in routing.LBContext struct
Browse files Browse the repository at this point in the history
The routing.Route field of routing.LBContext struct will be
eventually removed because loadbalancer in general does not
require the full information about route, only about endpoints
to balance load between.

Signed-off-by: Roman Zavodskikh <[email protected]>
  • Loading branch information
Roman Zavodskikh committed Oct 5, 2023
1 parent faa363c commit 6369d6e
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 39 deletions.
40 changes: 20 additions & 20 deletions loadbalancer/algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func fadeIn(now time.Time, duration time.Duration, exponent float64, detected ti
func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routing.LBEndpoint {
var sum float64
rt := ctx.Route
ep := ctx.Route.LBEndpoints
ep := ctx.LBEndpoints
for _, epi := range ep {
detected := ctx.Registry.GetMetrics(epi.Host).DetectedTime()
wi := fadeIn(now, rt.LBFadeInDuration, rt.LBFadeInExponent, detected)
Expand All @@ -97,7 +97,7 @@ func shiftWeighted(rnd *rand.Rand, ctx *routing.LBContext, now time.Time) routin

func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time.Time) routing.LBEndpoint {
notFadingIndexes := wi
ep := ctx.Route.LBEndpoints
ep := ctx.LBEndpoints

// if all endpoints are fading, the simplest approach is to use the oldest,
// this departs from the desired curve, but guarantees monotonic fade-in. From
Expand All @@ -112,9 +112,9 @@ func shiftToRemaining(rnd *rand.Rand, ctx *routing.LBContext, wi []int, now time
}

func withFadeIn(rnd *rand.Rand, ctx *routing.LBContext, choice int, algo routing.LBAlgorithm) routing.LBEndpoint {
ep := ctx.Route.LBEndpoints
ep := ctx.LBEndpoints
now := time.Now()
detected := ctx.Registry.GetMetrics(ctx.Route.LBEndpoints[choice].Host).DetectedTime()
detected := ctx.Registry.GetMetrics(ctx.LBEndpoints[choice].Host).DetectedTime()
f := fadeIn(
now,
ctx.Route.LBFadeInDuration,
Expand Down Expand Up @@ -165,14 +165,14 @@ func newRoundRobin(endpoints []string) routing.LBAlgorithm {

// Apply implements routing.LBAlgorithm with a roundrobin algorithm.
func (r *roundRobin) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.LBEndpoints) == 1 {
return ctx.LBEndpoints[0]
}

index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.Route.LBEndpoints)))
index := int(atomic.AddInt64(&r.index, 1) % int64(len(ctx.LBEndpoints)))

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[index]
return ctx.LBEndpoints[index]
}

return withFadeIn(r.rnd, ctx, index, r)
Expand All @@ -191,13 +191,13 @@ func newRandom(endpoints []string) routing.LBAlgorithm {

// Apply implements routing.LBAlgorithm with a stateless random algorithm.
func (r *random) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.LBEndpoints) == 1 {
return ctx.LBEndpoints[0]
}

i := r.rnd.Intn(len(ctx.Route.LBEndpoints))
i := r.rnd.Intn(len(ctx.LBEndpoints))
if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[i]
return ctx.LBEndpoints[i]
}

return withFadeIn(r.rnd, ctx, i, r)
Expand Down Expand Up @@ -265,7 +265,7 @@ func (ch *consistentHash) search(key string, skipEndpoint func(int) bool) int {

func computeLoadAverage(ctx *routing.LBContext) float64 {
sum := 1.0 // add 1 to include the request that just arrived
endpoints := ctx.Route.LBEndpoints
endpoints := ctx.LBEndpoints
for _, v := range endpoints {
sum += float64(v.Metrics.GetInflightRequests())
}
Expand All @@ -284,7 +284,7 @@ func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, c
if skipEndpoint(endpointIndex) {
continue
}
load := ctx.Route.LBEndpoints[endpointIndex].Metrics.GetInflightRequests()
load := ctx.LBEndpoints[endpointIndex].Metrics.GetInflightRequests()
// We know there must be an endpoint whose load <= average load.
// Since targetLoad >= average load (balancerFactor >= 1), there must also be an endpoint with load <= targetLoad.
if load <= int(targetLoad) {
Expand All @@ -298,14 +298,14 @@ func (ch *consistentHash) boundedLoadSearch(key string, balanceFactor float64, c

// Apply implements routing.LBAlgorithm with a consistent hash algorithm.
func (ch *consistentHash) Apply(ctx *routing.LBContext) routing.LBEndpoint {
if len(ctx.Route.LBEndpoints) == 1 {
return ctx.Route.LBEndpoints[0]
if len(ctx.LBEndpoints) == 1 {
return ctx.LBEndpoints[0]
}

choice := ch.chooseConsistentHashEndpoint(ctx, noSkippedEndpoints)

if ctx.Route.LBFadeInDuration <= 0 {
return ctx.Route.LBEndpoints[choice]
return ctx.LBEndpoints[choice]
}

return withFadeIn(ch.rnd, ctx, choice, ch)
Expand Down Expand Up @@ -359,15 +359,15 @@ func newPowerOfRandomNChoices([]string) routing.LBAlgorithm {

// Apply implements routing.LBAlgorithm with power of random N choices algorithm.
func (p *powerOfRandomNChoices) Apply(ctx *routing.LBContext) routing.LBEndpoint {
ne := len(ctx.Route.LBEndpoints)
ne := len(ctx.LBEndpoints)

p.mx.Lock()
defer p.mx.Unlock()

best := ctx.Route.LBEndpoints[p.rnd.Intn(ne)]
best := ctx.LBEndpoints[p.rnd.Intn(ne)]

for i := 1; i < p.numberOfChoices; i++ {
ce := ctx.Route.LBEndpoints[p.rnd.Intn(ne)]
ce := ctx.LBEndpoints[p.rnd.Intn(ne)]

if p.getScore(ce) > p.getScore(best) {
best = ce
Expand Down
33 changes: 18 additions & 15 deletions loadbalancer/algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,10 @@ func TestApply(t *testing.T) {
rt := p.Do([]*routing.Route{r})

lbctx := &routing.LBContext{
Request: req,
Route: rt[0],
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Request: req,
Route: rt[0],
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
LBEndpoints: rt[0].LBEndpoints,
}

h := make(map[string]int)
Expand Down Expand Up @@ -316,13 +317,14 @@ func TestConsistentHashBoundedLoadSearch(t *testing.T) {
}})[0]
ch := route.LBAlgorithm.(*consistentHash)
ctx := &routing.LBContext{
Request: r,
Route: route,
Params: map[string]interface{}{ConsistentHashBalanceFactor: 1.25},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Request: r,
Route: route,
Params: map[string]interface{}{ConsistentHashBalanceFactor: 1.25},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
LBEndpoints: route.LBEndpoints,
}
noLoad := ch.Apply(ctx)
nonBounded := ch.Apply(&routing.LBContext{Request: r, Route: route, Params: map[string]interface{}{}})
nonBounded := ch.Apply(&routing.LBContext{Request: r, Route: route, Params: map[string]interface{}{}, LBEndpoints: route.LBEndpoints})

if noLoad != nonBounded {
t.Error("When no endpoints are overloaded, the chosen endpoint should be the same as standard consistentHash")
Expand Down Expand Up @@ -364,16 +366,16 @@ func TestConsistentHashKey(t *testing.T) {
},
}})[0]

defaultEndpoint := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: make(map[string]interface{})})
remoteHostEndpoint := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: map[string]interface{}{ConsistentHashKey: net.RemoteHost(r).String()}})
defaultEndpoint := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: make(map[string]interface{}), LBEndpoints: rt.LBEndpoints})
remoteHostEndpoint := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: map[string]interface{}{ConsistentHashKey: net.RemoteHost(r).String()}, LBEndpoints: rt.LBEndpoints})

if defaultEndpoint != remoteHostEndpoint {
t.Error("remote host should be used as a default key")
}

for i, ep := range endpoints {
key := fmt.Sprintf("%s-%d", ep, 1) // "ep-0" to "ep-99" is the range of keys for this endpoint. If we use this as the hash key it should select endpoint ep.
selected := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: map[string]interface{}{ConsistentHashKey: key}})
selected := ch.Apply(&routing.LBContext{Request: r, Route: rt, Params: map[string]interface{}{ConsistentHashKey: key}, LBEndpoints: rt.LBEndpoints})
if selected != rt.LBEndpoints[i] {
t.Errorf("expected: %v, got %v", rt.LBEndpoints[i], selected)
}
Expand All @@ -393,10 +395,11 @@ func TestConsistentHashBoundedLoadDistribution(t *testing.T) {
ch := route.LBAlgorithm.(*consistentHash)
balanceFactor := 1.25
ctx := &routing.LBContext{
Request: r,
Route: route,
Params: map[string]interface{}{ConsistentHashBalanceFactor: balanceFactor},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
Request: r,
Route: route,
Params: map[string]interface{}{ConsistentHashBalanceFactor: balanceFactor},
Registry: routing.NewEndpointRegistry(routing.RegistryOptions{}),
LBEndpoints: route.LBEndpoints,
}

for i := 0; i < 100; i++ {
Expand Down
1 change: 1 addition & 0 deletions loadbalancer/fadein_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func initializeEndpoints(endpointAges []time.Duration, fadeInDuration time.Durat
})
ctx.Registry.SetDetectedTime(eps[i], detectionTimes[i])
}
ctx.LBEndpoints = ctx.Route.LBEndpoints

return ctx, eps
}
Expand Down
2 changes: 2 additions & 0 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ func setRequestURLForDynamicBackend(u *url.URL, stateBag map[string]interface{})
}

func setRequestURLForLoadBalancedBackend(u *url.URL, rt *routing.Route, lbctx *routing.LBContext) *routing.LBEndpoint {
lbctx.LBEndpoints = rt.LBEndpoints

e := rt.LBAlgorithm.Apply(lbctx)
u.Scheme = e.Scheme
u.Host = e.Host
Expand Down
9 changes: 5 additions & 4 deletions routing/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,11 @@ type LBAlgorithm interface {
// LBContext is used to pass data to the load balancer to decide based
// on that data which endpoint to call from the backends
type LBContext struct {
Request *http.Request
Route *Route
Params map[string]interface{}
Registry *EndpointRegistry
Request *http.Request
Route *Route
Params map[string]interface{}
Registry *EndpointRegistry
LBEndpoints []LBEndpoint
}

// NewLBContext is used to create a new LBContext, to pass data to the
Expand Down

0 comments on commit 6369d6e

Please sign in to comment.