Skip to content

Commit 538bab6

Browse files
authoredFeb 26, 2025
privilege: handle username with "@" correctly for role(RBAC) related code (pingcap#59739)
close pingcap#59552
1 parent f66e8b1 commit 538bab6

File tree

5 files changed

+60
-28
lines changed

5 files changed

+60
-28
lines changed
 

‎pkg/privilege/privileges/cache.go

+33-21
Original file line numberDiff line numberDiff line change
@@ -249,18 +249,21 @@ type defaultRoleRecord struct {
249249

250250
// roleGraphEdgesTable is used to cache relationship between and role.
251251
type roleGraphEdgesTable struct {
252-
roleList map[string]*auth.RoleIdentity
252+
roleList map[auth.RoleIdentity]*auth.RoleIdentity
253253
}
254254

255255
// Find method is used to find role from table
256256
func (g roleGraphEdgesTable) Find(user, host string) bool {
257257
if host == "" {
258258
host = "%"
259259
}
260-
key := user + "@" + host
261260
if g.roleList == nil {
262261
return false
263262
}
263+
key := auth.RoleIdentity{
264+
Username: user,
265+
Hostname: host,
266+
}
264267
_, ok := g.roleList[key]
265268
return ok
266269
}
@@ -362,7 +365,7 @@ type MySQLPrivilege struct {
362365

363366
globalPriv bTree[itemGlobalPriv]
364367
dynamicPriv bTree[itemDynamicPriv]
365-
roleGraph map[string]roleGraphEdgesTable
368+
roleGraph map[auth.RoleIdentity]roleGraphEdgesTable
366369
}
367370

368371
func newMySQLPrivilege() *MySQLPrivilege {
@@ -400,7 +403,7 @@ func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.R
400403
if _, ok := visited[role.String()]; !ok {
401404
visited[role.String()] = true
402405
ret = append(ret, role)
403-
key := role.Username + "@" + role.Hostname
406+
key := *role
404407
if edgeTable, ok := p.roleGraph[key]; ok {
405408
for _, v := range edgeTable.roleList {
406409
if _, ok := visited[v.String()]; !ok {
@@ -419,7 +422,10 @@ func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdent
419422
rec := p.matchUser(user, host)
420423
r := p.matchUser(role.Username, role.Hostname)
421424
if rec != nil && r != nil {
422-
key := rec.User + "@" + rec.Host
425+
key := auth.RoleIdentity{
426+
Username: rec.User,
427+
Hostname: rec.Host,
428+
}
423429
return p.roleGraph[key].Find(role.Username, role.Hostname)
424430
}
425431
return false
@@ -497,16 +503,12 @@ func (p *MySQLPrivilege) LoadAll(ctx sqlexec.RestrictedSQLExecutor) error {
497503
return nil
498504
}
499505

500-
func findUserAndAllRoles(all map[string]struct{}, roleGraph map[string]roleGraphEdgesTable) {
506+
func findUserAndAllRoles(all map[string]struct{}, roleGraph map[auth.RoleIdentity]roleGraphEdgesTable) {
501507
for {
502508
before := len(all)
503509

504510
for userHost, value := range roleGraph {
505-
user, _, found := strings.Cut(userHost, "@")
506-
if !found {
507-
// this should never happen
508-
continue
509-
}
511+
user := userHost.Username
510512
if _, ok := all[user]; ok {
511513
// If a user is in map, all its role should also added
512514
for _, role := range value.roleList {
@@ -528,7 +530,7 @@ func (p *MySQLPrivilege) loadSomeUsers(ctx sqlexec.RestrictedSQLExecutor, userLi
528530
logutil.BgLogger().Warn("loadSomeUsers called with a long user list", zap.Int("len", len(userList)))
529531
}
530532
// Load the full role edge table first.
531-
p.roleGraph = make(map[string]roleGraphEdgesTable)
533+
p.roleGraph = make(map[auth.RoleIdentity]roleGraphEdgesTable)
532534
err := p.loadTable(ctx, sqlLoadRoleGraph, p.decodeRoleEdgesTable)
533535
if err != nil {
534536
return nil, errors.Trace(err)
@@ -680,7 +682,7 @@ func noSuchTable(err error) bool {
680682

681683
// LoadRoleGraph loads the mysql.role_edges table from database.
682684
func (p *MySQLPrivilege) LoadRoleGraph(ctx sqlexec.RestrictedSQLExecutor) error {
683-
p.roleGraph = make(map[string]roleGraphEdgesTable)
685+
p.roleGraph = make(map[auth.RoleIdentity]roleGraphEdgesTable)
684686
err := p.loadTable(ctx, sqlLoadRoleGraph, p.decodeRoleEdgesTable)
685687
if err != nil {
686688
return errors.Trace(err)
@@ -1192,11 +1194,17 @@ func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*resolve.Resul
11921194
toUser = row.GetString(i)
11931195
}
11941196
}
1195-
fromKey := fromUser + "@" + fromHost
1196-
toKey := toUser + "@" + toHost
1197+
fromKey := auth.RoleIdentity{
1198+
Username: fromUser,
1199+
Hostname: fromHost,
1200+
}
1201+
toKey := auth.RoleIdentity{
1202+
Username: toUser,
1203+
Hostname: toHost,
1204+
}
11971205
roleGraph, ok := p.roleGraph[toKey]
11981206
if !ok {
1199-
roleGraph = roleGraphEdgesTable{roleList: make(map[string]*auth.RoleIdentity)}
1207+
roleGraph = roleGraphEdgesTable{roleList: make(map[auth.RoleIdentity]*auth.RoleIdentity)}
12001208
p.roleGraph[toKey] = roleGraph
12011209
}
12021210
roleGraph.roleList[fromKey] = &auth.RoleIdentity{Username: fromUser, Hostname: fromHost}
@@ -1759,15 +1767,16 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r
17591767
slices.Sort(gs[sortFromIdx:])
17601768

17611769
// Show role grants.
1762-
graphKey := user + "@" + host
1770+
graphKey := auth.RoleIdentity{
1771+
Username: user,
1772+
Hostname: host,
1773+
}
17631774
edgeTable, ok := p.roleGraph[graphKey]
17641775
g = ""
17651776
if ok {
17661777
sortedRes := make([]string, 0, 10)
17671778
for k := range edgeTable.roleList {
1768-
role := strings.Split(k, "@")
1769-
roleName, roleHost := role[0], role[1]
1770-
tmp := fmt.Sprintf("'%s'@'%s'", roleName, roleHost)
1779+
tmp := fmt.Sprintf("'%s'@'%s'", k.Username, k.Hostname)
17711780
sortedRes = append(sortedRes, tmp)
17721781
}
17731782
slices.Sort(sortedRes)
@@ -1994,7 +2003,10 @@ func (p *MySQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity
19942003
}
19952004

19962005
func (p *MySQLPrivilege) getAllRoles(user, host string) []*auth.RoleIdentity {
1997-
key := user + "@" + host
2006+
key := auth.RoleIdentity{
2007+
Username: user,
2008+
Hostname: host,
2009+
}
19982010
edgeTable, ok := p.roleGraph[key]
19992011
ret := make([]*auth.RoleIdentity, 0, len(edgeTable.roleList))
20002012
if ok {

‎pkg/privilege/privileges/cache_test.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -315,13 +315,13 @@ func TestLoadRoleGraph(t *testing.T) {
315315
p = privileges.NewMySQLPrivilege()
316316
require.NoError(t, p.LoadRoleGraph(se.GetRestrictedSQLExecutor()))
317317
graph := p.RoleGraph()
318-
require.True(t, graph["root@%"].Find("r_2", "%"))
319-
require.True(t, graph["root@%"].Find("r_4", "%"))
320-
require.True(t, graph["user2@%"].Find("r_1", "%"))
321-
require.True(t, graph["user1@%"].Find("r_3", "%"))
322-
_, ok := graph["illedal"]
318+
require.True(t, graph[auth.RoleIdentity{Username: "root", Hostname: "%"}].Find("r_2", "%"))
319+
require.True(t, graph[auth.RoleIdentity{Username: "root", Hostname: "%"}].Find("r_4", "%"))
320+
require.True(t, graph[auth.RoleIdentity{Username: "user2", Hostname: "%"}].Find("r_1", "%"))
321+
require.True(t, graph[auth.RoleIdentity{Username: "user1", Hostname: "%"}].Find("r_3", "%"))
322+
_, ok := graph[auth.RoleIdentity{Username: "illedal"}]
323323
require.False(t, ok)
324-
require.False(t, graph["root@%"].Find("r_1", "%"))
324+
require.False(t, graph[auth.RoleIdentity{Username: "root", Hostname: "%"}].Find("r_1", "%"))
325325
}
326326

327327
func TestRoleGraphBFS(t *testing.T) {

‎pkg/privilege/privileges/tidb_auth_token_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
jwsRepo "github.com/lestrrat-go/jwx/v2/jws"
3131
jwtRepo "github.com/lestrrat-go/jwx/v2/jwt"
3232
"github.com/lestrrat-go/jwx/v2/jwt/openid"
33+
"github.com/pingcap/tidb/pkg/parser/auth"
3334
"github.com/pingcap/tidb/pkg/util/hack"
3435
"github.com/stretchr/testify/require"
3536
)
@@ -482,7 +483,7 @@ func (p *MySQLPrivilege) GlobalPriv(user string) []globalPrivRecord {
482483
return ret.data
483484
}
484485

485-
func (p *MySQLPrivilege) RoleGraph() map[string]roleGraphEdgesTable {
486+
func (p *MySQLPrivilege) RoleGraph() map[auth.RoleIdentity]roleGraphEdgesTable {
486487
return p.roleGraph
487488
}
488489

‎tests/integrationtest/r/privilege/privileges.result

+10
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,13 @@ update privilege__privileges.tt1 inner join t_f
696696
set t_f.fullname=t_f.fullname
697697
where tt1.id=t_f.id;
698698
Error 1288 (HY000): The target table t_f of the UPDATE is not updatable
699+
drop user if exists u1;
700+
create user u1;
701+
create role 'aa@bb';
702+
grant 'aa@bb' to u1;
703+
show grants for u1;
704+
Grants for u1@%
705+
GRANT USAGE ON *.* TO 'u1'@'%'
706+
GRANT 'aa@bb'@'%' TO 'u1'@'%'
707+
drop user u1;
708+
drop role 'aa@bb';

‎tests/integrationtest/t/privilege/privileges.test

+9
Original file line numberDiff line numberDiff line change
@@ -947,3 +947,12 @@ with t_f as (
947947

948948
disconnect u53490;
949949
connection default;
950+
951+
# TestIssue59552
952+
drop user if exists u1;
953+
create user u1;
954+
create role 'aa@bb';
955+
grant 'aa@bb' to u1;
956+
show grants for u1;
957+
drop user u1;
958+
drop role 'aa@bb';

0 commit comments

Comments
 (0)
Please sign in to comment.