@@ -249,18 +249,21 @@ type defaultRoleRecord struct {
249
249
250
250
// roleGraphEdgesTable is used to cache relationship between and role.
251
251
type roleGraphEdgesTable struct {
252
- roleList map [string ]* auth.RoleIdentity
252
+ roleList map [auth. RoleIdentity ]* auth.RoleIdentity
253
253
}
254
254
255
255
// Find method is used to find role from table
256
256
func (g roleGraphEdgesTable ) Find (user , host string ) bool {
257
257
if host == "" {
258
258
host = "%"
259
259
}
260
- key := user + "@" + host
261
260
if g .roleList == nil {
262
261
return false
263
262
}
263
+ key := auth.RoleIdentity {
264
+ Username : user ,
265
+ Hostname : host ,
266
+ }
264
267
_ , ok := g .roleList [key ]
265
268
return ok
266
269
}
@@ -362,7 +365,7 @@ type MySQLPrivilege struct {
362
365
363
366
globalPriv bTree [itemGlobalPriv ]
364
367
dynamicPriv bTree [itemDynamicPriv ]
365
- roleGraph map [string ]roleGraphEdgesTable
368
+ roleGraph map [auth. RoleIdentity ]roleGraphEdgesTable
366
369
}
367
370
368
371
func newMySQLPrivilege () * MySQLPrivilege {
@@ -400,7 +403,7 @@ func (p *MySQLPrivilege) FindAllRole(activeRoles []*auth.RoleIdentity) []*auth.R
400
403
if _ , ok := visited [role .String ()]; ! ok {
401
404
visited [role .String ()] = true
402
405
ret = append (ret , role )
403
- key := role . Username + "@" + role . Hostname
406
+ key := * role
404
407
if edgeTable , ok := p .roleGraph [key ]; ok {
405
408
for _ , v := range edgeTable .roleList {
406
409
if _ , ok := visited [v .String ()]; ! ok {
@@ -419,7 +422,10 @@ func (p *MySQLPrivilege) FindRole(user string, host string, role *auth.RoleIdent
419
422
rec := p .matchUser (user , host )
420
423
r := p .matchUser (role .Username , role .Hostname )
421
424
if rec != nil && r != nil {
422
- key := rec .User + "@" + rec .Host
425
+ key := auth.RoleIdentity {
426
+ Username : rec .User ,
427
+ Hostname : rec .Host ,
428
+ }
423
429
return p .roleGraph [key ].Find (role .Username , role .Hostname )
424
430
}
425
431
return false
@@ -497,16 +503,12 @@ func (p *MySQLPrivilege) LoadAll(ctx sqlexec.RestrictedSQLExecutor) error {
497
503
return nil
498
504
}
499
505
500
- func findUserAndAllRoles (all map [string ]struct {}, roleGraph map [string ]roleGraphEdgesTable ) {
506
+ func findUserAndAllRoles (all map [string ]struct {}, roleGraph map [auth. RoleIdentity ]roleGraphEdgesTable ) {
501
507
for {
502
508
before := len (all )
503
509
504
510
for userHost , value := range roleGraph {
505
- user , _ , found := strings .Cut (userHost , "@" )
506
- if ! found {
507
- // this should never happen
508
- continue
509
- }
511
+ user := userHost .Username
510
512
if _ , ok := all [user ]; ok {
511
513
// If a user is in map, all its role should also added
512
514
for _ , role := range value .roleList {
@@ -528,7 +530,7 @@ func (p *MySQLPrivilege) loadSomeUsers(ctx sqlexec.RestrictedSQLExecutor, userLi
528
530
logutil .BgLogger ().Warn ("loadSomeUsers called with a long user list" , zap .Int ("len" , len (userList )))
529
531
}
530
532
// Load the full role edge table first.
531
- p .roleGraph = make (map [string ]roleGraphEdgesTable )
533
+ p .roleGraph = make (map [auth. RoleIdentity ]roleGraphEdgesTable )
532
534
err := p .loadTable (ctx , sqlLoadRoleGraph , p .decodeRoleEdgesTable )
533
535
if err != nil {
534
536
return nil , errors .Trace (err )
@@ -680,7 +682,7 @@ func noSuchTable(err error) bool {
680
682
681
683
// LoadRoleGraph loads the mysql.role_edges table from database.
682
684
func (p * MySQLPrivilege ) LoadRoleGraph (ctx sqlexec.RestrictedSQLExecutor ) error {
683
- p .roleGraph = make (map [string ]roleGraphEdgesTable )
685
+ p .roleGraph = make (map [auth. RoleIdentity ]roleGraphEdgesTable )
684
686
err := p .loadTable (ctx , sqlLoadRoleGraph , p .decodeRoleEdgesTable )
685
687
if err != nil {
686
688
return errors .Trace (err )
@@ -1192,11 +1194,17 @@ func (p *MySQLPrivilege) decodeRoleEdgesTable(row chunk.Row, fs []*resolve.Resul
1192
1194
toUser = row .GetString (i )
1193
1195
}
1194
1196
}
1195
- fromKey := fromUser + "@" + fromHost
1196
- toKey := toUser + "@" + toHost
1197
+ fromKey := auth.RoleIdentity {
1198
+ Username : fromUser ,
1199
+ Hostname : fromHost ,
1200
+ }
1201
+ toKey := auth.RoleIdentity {
1202
+ Username : toUser ,
1203
+ Hostname : toHost ,
1204
+ }
1197
1205
roleGraph , ok := p .roleGraph [toKey ]
1198
1206
if ! ok {
1199
- roleGraph = roleGraphEdgesTable {roleList : make (map [string ]* auth.RoleIdentity )}
1207
+ roleGraph = roleGraphEdgesTable {roleList : make (map [auth. RoleIdentity ]* auth.RoleIdentity )}
1200
1208
p .roleGraph [toKey ] = roleGraph
1201
1209
}
1202
1210
roleGraph .roleList [fromKey ] = & auth.RoleIdentity {Username : fromUser , Hostname : fromHost }
@@ -1759,15 +1767,16 @@ func (p *MySQLPrivilege) showGrants(ctx sessionctx.Context, user, host string, r
1759
1767
slices .Sort (gs [sortFromIdx :])
1760
1768
1761
1769
// Show role grants.
1762
- graphKey := user + "@" + host
1770
+ graphKey := auth.RoleIdentity {
1771
+ Username : user ,
1772
+ Hostname : host ,
1773
+ }
1763
1774
edgeTable , ok := p .roleGraph [graphKey ]
1764
1775
g = ""
1765
1776
if ok {
1766
1777
sortedRes := make ([]string , 0 , 10 )
1767
1778
for k := range edgeTable .roleList {
1768
- role := strings .Split (k , "@" )
1769
- roleName , roleHost := role [0 ], role [1 ]
1770
- tmp := fmt .Sprintf ("'%s'@'%s'" , roleName , roleHost )
1779
+ tmp := fmt .Sprintf ("'%s'@'%s'" , k .Username , k .Hostname )
1771
1780
sortedRes = append (sortedRes , tmp )
1772
1781
}
1773
1782
slices .Sort (sortedRes )
@@ -1994,7 +2003,10 @@ func (p *MySQLPrivilege) getDefaultRoles(user, host string) []*auth.RoleIdentity
1994
2003
}
1995
2004
1996
2005
func (p * MySQLPrivilege ) getAllRoles (user , host string ) []* auth.RoleIdentity {
1997
- key := user + "@" + host
2006
+ key := auth.RoleIdentity {
2007
+ Username : user ,
2008
+ Hostname : host ,
2009
+ }
1998
2010
edgeTable , ok := p .roleGraph [key ]
1999
2011
ret := make ([]* auth.RoleIdentity , 0 , len (edgeTable .roleList ))
2000
2012
if ok {
0 commit comments