Skip to content

Commit 31d2a4e

Browse files
committed
Update ResultSet's Union and Intersect methods to favor explicit results
1 parent bf581b0 commit 31d2a4e

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

pkg/authz/query/resultset.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,9 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
9999
}
100100

101101
for iter := other.List(); iter != nil; iter = iter.Next() {
102-
isImplicit := iter.IsImplicit
103-
if resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
104-
isImplicit = isImplicit && resultSet.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
102+
if !resultSet.Has(iter.ObjectType, iter.ObjectId, iter.Relation) || !iter.IsImplicit {
103+
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
105104
}
106-
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
107105
}
108106

109107
return resultSet
@@ -112,10 +110,13 @@ func (rs *ResultSet) Union(other *ResultSet) *ResultSet {
112110
func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
113111
resultSet := NewResultSet()
114112
for iter := rs.List(); iter != nil; iter = iter.Next() {
115-
isImplicit := iter.IsImplicit
116113
if other.Has(iter.ObjectType, iter.ObjectId, iter.Relation) {
117-
isImplicit = isImplicit || other.Get(iter.ObjectType, iter.ObjectId, iter.Relation).IsImplicit
118-
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, isImplicit)
114+
otherRes := other.Get(iter.ObjectType, iter.ObjectId, iter.Relation)
115+
if !otherRes.IsImplicit {
116+
resultSet.Add(otherRes.ObjectType, otherRes.ObjectId, otherRes.Relation, otherRes.Warrant, otherRes.IsImplicit)
117+
} else {
118+
resultSet.Add(iter.ObjectType, iter.ObjectId, iter.Relation, iter.Warrant, iter.IsImplicit)
119+
}
119120
}
120121
}
121122

@@ -125,7 +126,11 @@ func (rs *ResultSet) Intersect(other *ResultSet) *ResultSet {
125126
func (rs *ResultSet) String() string {
126127
var strs []string
127128
for iter := rs.List(); iter != nil; iter = iter.Next() {
128-
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
129+
if iter.IsImplicit {
130+
strs = append(strs, fmt.Sprintf("%s => %s [implicit]", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
131+
} else {
132+
strs = append(strs, fmt.Sprintf("%s => %s", key(iter.ObjectType, iter.ObjectId, iter.Relation), iter.Warrant.String()))
133+
}
129134
}
130135

131136
return strings.Join(strs, ", ")

0 commit comments

Comments
 (0)