Skip to content

Commit

Permalink
Guilherme/ora 1168 do stake weighted binary selection then sort of re…
Browse files Browse the repository at this point in the history
…puter (#115)

Co-authored-by: Kenny <[email protected]>
  • Loading branch information
guilherme-brandao and kpeluso authored Apr 11, 2024
1 parent 3e4da21 commit 31df69f
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 42 deletions.
29 changes: 27 additions & 2 deletions math/dec.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import (
// but when copying the big.Int structure can be shared between Decimal instances causing corruption.
// This was originally discovered in regen0-network/mainnet#15.
type Dec struct {
dec apd.Decimal
dec apd.Decimal
isNaN bool
}

// constants for more convenient intent behind dec.Cmp values.
Expand Down Expand Up @@ -61,6 +62,10 @@ var dec128Context = apd.Context{
Traps: apd.DefaultTraps,
}

func NewNaN() Dec {
return Dec{apd.Decimal{}, true}
}

func NewDecFromString(s string) (Dec, error) {
if s == "" {
s = "0"
Expand All @@ -70,7 +75,7 @@ func NewDecFromString(s string) (Dec, error) {
return Dec{}, ErrInvalidDecString.Wrap(err.Error())
}

d1 := Dec{*d}
d1 := Dec{*d, false}
if d1.dec.Form == apd.Infinite {
return d1, ErrInfiniteString.Wrapf(s)
}
Expand Down Expand Up @@ -439,10 +444,30 @@ func (x Dec) Cmp(y Dec) int {
return x.dec.Cmp(&y.dec)
}

func (x Dec) Gt(y Dec) bool {
return x.dec.Cmp(&y.dec) == 1
}

func (x Dec) Gte(y Dec) bool {
return x.dec.Cmp(&y.dec) == 1 || x.dec.Cmp(&y.dec) == 0
}

func (x Dec) Lt(y Dec) bool {
return x.dec.Cmp(&y.dec) == -1
}

func (x Dec) Lte(y Dec) bool {
return x.dec.Cmp(&y.dec) == -1 || x.dec.Cmp(&y.dec) == 0
}

func (x Dec) Equal(y Dec) bool {
return x.dec.Cmp(&y.dec) == 0
}

func (x Dec) IsNaN() bool {
return x.isNaN
}

// IsZero returns true if the decimal is zero.
func (x Dec) IsZero() bool {
return x.dec.IsZero()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func CalcForcastImpliedInferences(
maxjRijk = R_ik[j]
first = false
} else {
if R_ik[j].Cmp(maxjRijk) == alloraMath.GreaterThan {
if R_ik[j].Gt(maxjRijk) {
maxjRijk = R_ik[j]
}
}
Expand Down
12 changes: 6 additions & 6 deletions x/emissions/keeper/inference_synthesis/network_inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func FindMaxRegretAmongWorkersWithLosses(
fmt.Println("Error getting inferer regret: ", err)
return MaximalRegrets{}, err // TODO: THIS OR continue ??
}
if maxInfererRegret.Cmp(infererRegret.Value) == alloraMath.LessThan {
if maxInfererRegret.Lt(infererRegret.Value) {
maxInfererRegret = infererRegret.Value
}
}
Expand All @@ -54,7 +54,7 @@ func FindMaxRegretAmongWorkersWithLosses(
fmt.Println("Error getting forecaster regret: ", err)
return MaximalRegrets{}, err // TODO: THIS OR continue ??
}
if maxForecasterRegret.Cmp(forecasterRegret.Value) == alloraMath.LessThan {
if maxForecasterRegret.Lt(forecasterRegret.Value) {
maxForecasterRegret = forecasterRegret.Value
}
}
Expand All @@ -67,7 +67,7 @@ func FindMaxRegretAmongWorkersWithLosses(
fmt.Println("Error getting forecaster regret: ", err)
return MaximalRegrets{}, err // TODO: THIS OR continue ??
}
if maxOneInForecasterRegret[forecaster].Cmp(oneInForecasterRegret.Value) == alloraMath.LessThan {
if maxOneInForecasterRegret[forecaster].Lt(oneInForecasterRegret.Value) {
maxOneInForecasterRegret[forecaster] = oneInForecasterRegret.Value
}
}
Expand All @@ -76,7 +76,7 @@ func FindMaxRegretAmongWorkersWithLosses(
fmt.Println("Error getting one-in forecaster self regret: ", err)
return MaximalRegrets{}, err // TODO: THIS OR continue ??
}
if maxOneInForecasterRegret[forecaster].Cmp(oneInForecasterSelfRegret.Value) == alloraMath.LessThan {
if maxOneInForecasterRegret[forecaster].Lt(oneInForecasterSelfRegret.Value) {
maxOneInForecasterRegret[forecaster] = oneInForecasterSelfRegret.Value
}
}
Expand All @@ -103,7 +103,7 @@ func CalcWeightedInference(
epsilon alloraMath.Dec,
pInferenceSynthesis alloraMath.Dec,
) (InferenceValue, error) {
if maxRegret.Cmp(epsilon) == alloraMath.LessThan {
if maxRegret.Lt(epsilon) {
fmt.Println("Error maxRegret < epsilon: ", maxRegret, epsilon)
return InferenceValue{}, emissions.ErrFractionDivideByZero
}
Expand Down Expand Up @@ -185,7 +185,7 @@ func CalcWeightedInference(
}

// Normalize the network combined inference
if sumWeights.Cmp(epsilon) == alloraMath.LessThan {
if sumWeights.Lt(epsilon) {
return InferenceValue{}, emissions.ErrSumWeightsLessThanEta
}
ret, err := unnormalizedI_i.Quo(sumWeights)
Expand Down
2 changes: 1 addition & 1 deletion x/emissions/keeper/inference_synthesis/network_losses.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func RunningWeightedAvgUpdate(
if err != nil {
return WorkerRunningWeightedLoss{}, err
}
if runningWeightedAvg.SumWeight.Cmp(epsilon) == alloraMath.LessThan {
if runningWeightedAvg.SumWeight.Lt(epsilon) {
return *runningWeightedAvg, emissions.ErrFractionDivideByZero
}
weightFrac, err := weight.Quo(runningWeightedAvg.SumWeight)
Expand Down
25 changes: 25 additions & 0 deletions x/emissions/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,31 @@ func (k *Keeper) GetWorkerLatestInferenceByTopicId(
return k.inferences.Get(ctx, key)
}

// GetTopicWorkers returns a list of workers registered for a given topic ID.
func (k *Keeper) GetTopicWorkers(ctx context.Context, topicId TopicId) ([]sdk.AccAddress, error) {
var workers []sdk.AccAddress

rng := collections.NewPrefixedPairRange[TopicId, Worker](topicId)

// Iterate over the workers registered for the given topic ID
iter, err := k.topicWorkers.Iterate(ctx, rng)
if err != nil {
return nil, err
}
defer iter.Close()

for ; iter.Valid(); iter.Next() {
pair, err := iter.Key()
if err != nil {
return nil, err
}
workerAddr := pair.K2()
workers = append(workers, workerAddr)
}

return workers, nil
}

// Returns the last block height at which rewards emissions were updated
func (k *Keeper) GetLastRewardsUpdate(ctx context.Context) (int64, error) {
lastRewardsUpdate, err := k.lastRewardsUpdate.Get(ctx)
Expand Down
114 changes: 85 additions & 29 deletions x/emissions/module/rewards/rewards_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rewards

import (
"math"
"sort"

alloraMath "github.com/allora-network/allora-chain/math"
"github.com/allora-network/allora-chain/x/emissions/types"
Expand Down Expand Up @@ -277,57 +278,94 @@ func GetStakeWeightedLoss(reputersStakes, reputersReportedLosses []alloraMath.De
func GetStakeWeightedLossMatrix(
reputersAdjustedStakes []alloraMath.Dec,
reputersReportedLosses [][]alloraMath.Dec,
) ([]alloraMath.Dec, error) {
) ([]alloraMath.Dec, []alloraMath.Dec, error) {
if len(reputersAdjustedStakes) == 0 || len(reputersReportedLosses) == 0 {
return nil, types.ErrInvalidSliceLength
return nil, nil, types.ErrInvalidSliceLength
}
var err error = nil

// Calculate total stake for normalization
totalStake := alloraMath.ZeroDec()
for _, stake := range reputersAdjustedStakes {
totalStake, err = totalStake.Add(stake)
if err != nil {
return nil, err
}
}

// Ensure every loss array is non-empty and calculate geometric mean
stakeWeightedLoss := make([]alloraMath.Dec, len(reputersReportedLosses[0]))
mostDistantValues := make([]alloraMath.Dec, len(reputersReportedLosses[0]))
for j := 0; j < len(reputersReportedLosses[0]); j++ {
// Calculate total stake to consider
// Skip stakes of reputers with NaN losses
totalStakeToConsider := alloraMath.ZeroDec()
for i, losses := range reputersReportedLosses {
// Skip if loss is NaN
if losses[j].IsNaN() {
continue
}

totalStakeToConsider, err = totalStakeToConsider.Add(reputersAdjustedStakes[i])
if err != nil {
return nil, nil, err
}
}

logSum := alloraMath.ZeroDec()
for i, losses := range reputersReportedLosses {
// Skip if loss is NaN
if losses[j].IsNaN() {
continue
}

logLosses, err := alloraMath.Log10(losses[j])
if err != nil {
return nil, err
return nil, nil, err
}
logLossesTimesStake, err := logLosses.Mul(reputersAdjustedStakes[i])
if err != nil {
return nil, err
return nil, nil, err
}
logLossesTimesStakeOverTotalStake, err := logLossesTimesStake.Quo(totalStake)
logLossesTimesStakeOverTotalStake, err := logLossesTimesStake.Quo(totalStakeToConsider)
if err != nil {
return nil, err
return nil, nil, err
}
logSum, err = logSum.Add(logLossesTimesStakeOverTotalStake)
if err != nil {
return nil, err
return nil, nil, err
}
}
ten := alloraMath.NewDecFromInt64(10)
stakeWeightedLoss[j], err = alloraMath.Pow(ten, logSum)
if err != nil {
return nil, err
return nil, nil, err
}

// Find most distant value from consensus value
maxDistance, err := alloraMath.OneDec().Mul(alloraMath.MustNewDecFromString("-1")) // Initialize with an impossible value
if err != nil {
return nil, nil, err
}
for _, losses := range reputersReportedLosses {
// Skip if loss is NaN
if losses[j].IsNaN() {
continue
}

logLosses, err := alloraMath.Log10(losses[j])
if err != nil {
return nil, nil, err
}
distance, err := logSum.Sub(logLosses)
if err != nil {
return nil, nil, err
}
if distance.Gt(maxDistance) {
maxDistance = distance
mostDistantValues[j] = losses[j]
}
}
}

return stakeWeightedLoss, nil
return stakeWeightedLoss, mostDistantValues, nil
}

// GetConsensusScore calculates the proximity to consensus score for a reputer.
// T_im
func GetConsensusScore(reputerLosses, consensusLosses []alloraMath.Dec) (alloraMath.Dec, error) {
fTolerance := alloraMath.MustNewDecFromString("0.01")
func GetConsensusScore(reputerLosses, consensusLosses, mostDistantValues []alloraMath.Dec) (alloraMath.Dec, error) {
fTolerance := alloraMath.MustNewDecFromString("0.01") // TODO: Use module param
if len(reputerLosses) != len(consensusLosses) {
return alloraMath.ZeroDec(), types.ErrInvalidSliceLength
}
Expand Down Expand Up @@ -355,6 +393,10 @@ func GetConsensusScore(reputerLosses, consensusLosses []alloraMath.Dec) (alloraM

var distanceSquared alloraMath.Dec
for i, rLoss := range reputerLosses {
// Attribute most distant value if loss is NaN
if rLoss.IsNaN() {
rLoss = mostDistantValues[i]
}
rLossOverConsensusLoss, err := rLoss.Quo(consensusLosses[i])
if err != nil {
return alloraMath.ZeroDec(), err
Expand All @@ -363,7 +405,7 @@ func GetConsensusScore(reputerLosses, consensusLosses []alloraMath.Dec) (alloraM
if err != nil {
return alloraMath.ZeroDec(), err
}
log10RLossOverCLossSquared, err := log10RLossOverCLoss.Mul(log10RLossOverCLoss)
log10RLossOverCLossSquared, err := log10RLossOverCLoss.Mul(log10RLossOverCLoss) // == Pow(x,2)
if err != nil {
return alloraMath.ZeroDec(), err
}
Expand Down Expand Up @@ -418,8 +460,8 @@ func GetAllConsensusScores(
adjustedStakes = append(adjustedStakes, adjustedStake)
}

// Get consensus loss vector
consensus, err := GetStakeWeightedLossMatrix(adjustedStakes, allLosses)
// Get consensus loss vector and retrieve most distant values from
consensus, mostDistantValues, err := GetStakeWeightedLossMatrix(adjustedStakes, allLosses)
if err != nil {
return nil, err
}
Expand All @@ -428,7 +470,7 @@ func GetAllConsensusScores(
scores := make([]alloraMath.Dec, numReputers)
for i := int64(0); i < numReputers; i++ {
losses := allLosses[i]
scores[i], err = GetConsensusScore(losses, consensus)
scores[i], err = GetConsensusScore(losses, consensus, mostDistantValues)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -466,8 +508,7 @@ func GetAllReputersOutput(
var maxGradient alloraMath.Dec = alloraMath.OneDec()
finalScores := make([]alloraMath.Dec, numReputers)

for maxGradient.Cmp(maxGradientThreshold) == alloraMath.GreaterThan &&
i.Cmp(imax) == alloraMath.LessThan {
for maxGradient.Gt(maxGradientThreshold) && i.Lt(imax) {
i, err = i.Add(alloraMath.OneDec())
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -561,7 +602,7 @@ func GetAllReputersOutput(
if err != nil {
return nil, nil, err
}
if listenedStakeFraction.Cmp(minStakeFraction) == alloraMath.LessThan {
if listenedStakeFraction.Lt(minStakeFraction) {
for l := range coefficients {
coeffDiff, err := coefficients[l].Sub(oldCoefficients[l])
if err != nil {
Expand Down Expand Up @@ -644,7 +685,7 @@ func maxAbsDifference(a, b []alloraMath.Dec) (alloraMath.Dec, error) {
return alloraMath.Dec{}, err
}
diff := subtraction.Abs()
if diff.Cmp(maxDiff) == alloraMath.GreaterThan {
if diff.Gt(maxDiff) {
maxDiff = diff
}
}
Expand Down Expand Up @@ -1113,19 +1154,34 @@ func ExtractValues(bundle *types.ValueBundle) []alloraMath.Dec {
// Extract direct alloraMath.Dec values
values = append(values, bundle.CombinedValue, bundle.NaiveValue)

// Extract values from slices of WorkerAttributedValue
// Sort and Extract values from slices of ValueBundle
sort.Slice(bundle.InfererValues, func(i, j int) bool {
return bundle.InfererValues[i].Worker < bundle.InfererValues[j].Worker
})
for _, v := range bundle.InfererValues {
values = append(values, v.Value)
}
sort.Slice(bundle.ForecasterValues, func(i, j int) bool {
return bundle.ForecasterValues[i].Worker < bundle.ForecasterValues[j].Worker
})
for _, v := range bundle.ForecasterValues {
values = append(values, v.Value)
}
sort.Slice(bundle.OneOutInfererValues, func(i, j int) bool {
return bundle.OneOutInfererValues[i].Worker < bundle.OneOutInfererValues[j].Worker
})
for _, v := range bundle.OneOutInfererValues {
values = append(values, v.Value)
}
sort.Slice(bundle.OneOutForecasterValues, func(i, j int) bool {
return bundle.OneOutForecasterValues[i].Worker < bundle.OneOutForecasterValues[j].Worker
})
for _, v := range bundle.OneOutForecasterValues {
values = append(values, v.Value)
}
sort.Slice(bundle.OneInForecasterValues, func(i, j int) bool {
return bundle.OneInForecasterValues[i].Worker < bundle.OneInForecasterValues[j].Worker
})
for _, v := range bundle.OneInForecasterValues {
values = append(values, v.Value)
}
Expand Down
Loading

0 comments on commit 31df69f

Please sign in to comment.