From e21b122b4e0ebb42ad98fe4ecf738bf380e487a3 Mon Sep 17 00:00:00 2001 From: Steve Coffman Date: Wed, 25 Dec 2024 14:01:53 -0500 Subject: [PATCH] Add replace rule function (#338) Signed-off-by: Steve Coffman --- validator/validator.go | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/validator/validator.go b/validator/validator.go index fb4db19..50b966b 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -17,12 +17,15 @@ type Rule struct { var specifiedRules []Rule -// AddRule adds rule to the rule set. -// f is called once each time `Validate` is executed. +// AddRule adds a rule to the rule set. +// ruleFunc is called once each time `Validate` is executed. func AddRule(name string, ruleFunc RuleFunc) { specifiedRules = append(specifiedRules, Rule{Name: name, RuleFunc: ruleFunc}) } +// RemoveRule removes an existing rule from the rule set +// if one of the same name exists. +// The rule set is global, so it is not safe for concurrent changes func RemoveRule(name string) { var result []Rule // nolint:prealloc // using initialized with len(rules) produces a race condition for _, r := range specifiedRules { @@ -34,6 +37,28 @@ func RemoveRule(name string) { specifiedRules = result } +// ReplaceRule replaces an existing rule from the rule set +// if one of the same name exists. +// If no match is found, it will add a new rule to the rule set. +// The rule set is global, so it is not safe for concurrent changes +func ReplaceRule(name string, ruleFunc RuleFunc) { + var found bool + var result []Rule // nolint:prealloc // using initialized with len(rules) produces a race condition + for _, r := range specifiedRules { + if r.Name == name { + found = true + result = append(result, Rule{Name: name, RuleFunc: ruleFunc}) + continue + } + result = append(result, r) + } + if !found { + specifiedRules = append(specifiedRules, Rule{Name: name, RuleFunc: ruleFunc}) + return + } + specifiedRules = result +} + func Validate(schema *Schema, doc *QueryDocument, rules ...Rule) gqlerror.List { if rules == nil { rules = specifiedRules