Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: casbin/gorm-adapter
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: v3.23.0
Choose a base ref
...
head repository: casbin/gorm-adapter
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: master
Choose a head ref
  • 10 commits
  • 7 files changed
  • 9 contributors

Commits on Apr 2, 2024

  1. fix: initialize mutex before using (#233)

    MuZhou233 authored Apr 2, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    8e4fe6d View commit details

Commits on Apr 8, 2024

  1. feat: pass context down to gorm, remove old ContextAdapter (#234)

    * fix: pass context down to gorm
    
    * fix: delete context_adapter_test.go
    
    * fix: go mod tidy
    
    * fix: update README & delete ContextAdapter
    MuZhou233 authored Apr 8, 2024
    Copy the full SHA
    15ac848 View commit details

Commits on May 31, 2024

  1. feat: initialize transactionMu in NewAdapterByDBUseTableName (#237)

    constructor NewAdapterByDBUseTableName should initialize transactionMu,
    otherwise it will ~~panic~~ blocked due to that cas when calling Transaction
    
    Besides, a few other constructors also depend on NewAdapterByDBUseTableName,
    if transactionMu is not initialized in NewAdapterByDBUseTableName,
    it will blocked as well.
    
    Besides, why not consider use sync.Once to initialize transactionMu?
    yuikns authored May 31, 2024
    Copy the full SHA
    9ad4075 View commit details

Commits on Jul 19, 2024

  1. feat: Remove underscores from the getFullTableName method when they a…

    …re present (#241)
    
    * feat: remove the underscore from the getFullTableName method
    
    * fix: Remove underscores from the getFullTableName method when they are present
    chenxi2015 authored Jul 19, 2024
    Copy the full SHA
    6a0d216 View commit details

Commits on Aug 14, 2024

  1. feat: add OnConflict=DoNothing on create operations (#243)

    * add TestAddPolicy
    
    * add OnConflict=DoNothing on create
    
    * more db
    
    * fix test
    
    * on conflict clause seems not work in sqlserver
    longshine authored Aug 14, 2024
    Copy the full SHA
    0560ffa View commit details

Commits on Aug 19, 2024

  1. feat: add sqlite3 to error message (#245)

    MuZhou233 authored Aug 19, 2024
    Copy the full SHA
    3d3a3c7 View commit details

Commits on Nov 4, 2024

  1. feat: update dependencies to improve security (#250)

    ypli0629 authored Nov 4, 2024
    Copy the full SHA
    a7e4936 View commit details

Commits on Nov 13, 2024

  1. feat: upgrade dependencies for better security (#253)

    shrutsureja authored Nov 13, 2024
    Copy the full SHA
    aef8c1f View commit details
  2. feat: fix failure when calling SavePolicy within the Transaction meth…

    …od (#251)
    
    Co-authored-by: junfengxu <[email protected]>
    Hill1126 and junfengxu authored Nov 13, 2024
    Copy the full SHA
    16aa502 View commit details

Commits on Nov 14, 2024

  1. feat: upgrade casbin dependency to v2.100.0

    hsluoyz committed Nov 14, 2024
    Copy the full SHA
    87539c9 View commit details
Showing with 207 additions and 306 deletions.
  1. +2 −2 README.md
  2. +83 −23 adapter.go
  3. +86 −3 adapter_test.go
  4. +0 −85 context_adapter.go
  5. +0 −141 context_adapter_test.go
  6. +12 −14 go.mod
  7. +24 −38 go.sum
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -228,11 +228,11 @@ func TestGetAllowedRecordsForUser(t *testing.T) {
`gormadapter` supports adapter with context, the following is a timeout control implemented using context

```go
ca, _ := NewContextAdapter("mysql", "root:@tcp(127.0.0.1:3306)/", "casbin")
a, _ := gormadapter.NewAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/") // Your driver and data source.
// Limited time 300s
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Microsecond)
defer cancel()
err := ca.AddPolicyCtx(ctx, "p", "p", []string{"alice", "data1", "read"})
err := a.AddPolicyCtx(ctx, "p", "p", []string{"alice", "data1", "read"})
if err != nil {
panic(err)
}
106 changes: 83 additions & 23 deletions adapter.go
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ import (
"gorm.io/driver/postgres"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/plugin/dbresolver"
)
@@ -83,6 +84,7 @@ type Adapter struct {
db *gorm.DB
isFiltered bool
transactionMu *sync.Mutex
muInitialize sync.Once
}

// finalizer is the destructor for Adapter.
@@ -198,8 +200,9 @@ func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*
}

a := &Adapter{
tablePrefix: prefix,
tableName: tableName,
tablePrefix: prefix,
tableName: tableName,
transactionMu: &sync.Mutex{},
}

a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})
@@ -257,9 +260,10 @@ func NewFilteredAdapter(driverName string, dataSourceName string, params ...inte
// Casbin will not automatically call LoadPolicy() for a filtered adapter.
func NewFilteredAdapterByDB(db *gorm.DB, prefix string, tableName string) (*Adapter, error) {
adapter := &Adapter{
tablePrefix: prefix,
tableName: tableName,
isFiltered: true,
tablePrefix: prefix,
tableName: tableName,
isFiltered: true,
transactionMu: &sync.Mutex{},
}
adapter.db = db.Scopes(adapter.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})

@@ -310,7 +314,7 @@ func openDBConnection(driverName, dataSourceName string) (*gorm.DB, error) {
} else if driverName == "sqlite3" {
db, err = gorm.Open(sqlite.Open(dataSourceName), &gorm.Config{})
} else {
return nil, errors.New("Database dialect '" + driverName + "' is not supported. Supported databases are postgres, mysql and sqlserver")
return nil, errors.New("Database dialect '" + driverName + "' is not supported. Supported databases are postgres, mysql, sqlserver and sqlite3")
}
if err != nil {
return nil, err
@@ -388,6 +392,9 @@ func (a *Adapter) getTableInstance() *CasbinRule {

func (a *Adapter) getFullTableName() string {
if a.tablePrefix != "" {
if strings.HasSuffix(a.tablePrefix, "_") {
return a.tablePrefix + a.tableName
}
return a.tablePrefix + "_" + a.tableName
}
return a.tableName
@@ -476,8 +483,13 @@ func loadPolicyLine(line CasbinRule, model model.Model) error {

// LoadPolicy loads policy from database.
func (a *Adapter) LoadPolicy(model model.Model) error {
return a.LoadPolicyCtx(context.Background(), model)
}

// LoadPolicyCtx loads policy from database.
func (a *Adapter) LoadPolicyCtx(ctx context.Context, model model.Model) error {
var lines []CasbinRule
if err := a.db.Order("ID").Find(&lines).Error; err != nil {
if err := a.db.WithContext(ctx).Order("ID").Find(&lines).Error; err != nil {
return err
}
err := a.Preview(&lines, model)
@@ -594,8 +606,13 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {

// SavePolicy saves policy to database.
func (a *Adapter) SavePolicy(model model.Model) error {
return a.SavePolicyCtx(context.Background(), model)
}

// SavePolicyCtx saves policy to database.
func (a *Adapter) SavePolicyCtx(ctx context.Context, model model.Model) error {
var err error
tx := a.db.Clauses(dbresolver.Write).Begin()
tx := a.db.WithContext(ctx).Clauses(dbresolver.Write).Begin()

err = a.truncateTable()

@@ -610,7 +627,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
for _, rule := range ast.Policy {
lines = append(lines, a.savePolicyLine(ptype, rule))
if len(lines) > flushEvery {
if err := tx.Create(&lines).Error; err != nil {
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&lines).Error; err != nil {
tx.Rollback()
return err
}
@@ -623,7 +640,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
for _, rule := range ast.Policy {
lines = append(lines, a.savePolicyLine(ptype, rule))
if len(lines) > flushEvery {
if err := tx.Create(&lines).Error; err != nil {
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&lines).Error; err != nil {
tx.Rollback()
return err
}
@@ -632,7 +649,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}
}
if len(lines) > 0 {
if err := tx.Create(&lines).Error; err != nil {
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&lines).Error; err != nil {
tx.Rollback()
return err
}
@@ -644,15 +661,25 @@ func (a *Adapter) SavePolicy(model model.Model) error {

// AddPolicy adds a policy rule to the storage.
func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
return a.AddPolicyCtx(context.Background(), sec, ptype, rule)
}

// AddPolicyCtx adds a policy rule to the storage.
func (a *Adapter) AddPolicyCtx(ctx context.Context, sec string, ptype string, rule []string) error {
line := a.savePolicyLine(ptype, rule)
err := a.db.Create(&line).Error
err := a.db.WithContext(ctx).Clauses(clause.OnConflict{DoNothing: true}).Create(&line).Error
return err
}

// RemovePolicy removes a policy rule from the storage.
func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
return a.RemovePolicyCtx(context.Background(), sec, ptype, rule)
}

// RemovePolicyCtx removes a policy rule from the storage.
func (a *Adapter) RemovePolicyCtx(ctx context.Context, sec string, ptype string, rule []string) error {
line := a.savePolicyLine(ptype, rule)
err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html
err := a.rawDelete(ctx, a.db, line) //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html
return err
}

@@ -663,26 +690,34 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error
line := a.savePolicyLine(ptype, rule)
lines = append(lines, line)
}
return a.db.Create(&lines).Error
return a.db.Clauses(clause.OnConflict{DoNothing: true}).Create(&lines).Error
}

// Transaction perform a set of operations within a transaction
func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) error, opts ...*sql.TxOptions) error {
// ensure the transactionMu is initialized
if a.transactionMu == nil {
a.muInitialize.Do(func() {
if a.transactionMu == nil {
a.transactionMu = &sync.Mutex{}
}
})
}
// lock the transactionMu to ensure the transaction is thread-safe
a.transactionMu.Lock()
defer a.transactionMu.Unlock()
var err error
oriAdapter := a.db
// reload policy from database to sync with the transaction
defer func() {
e.SetAdapter(&Adapter{db: oriAdapter})
e.SetAdapter(a.Copy())
err = e.LoadPolicy()
if err != nil {
panic(err)
}
}()
copyDB := *a.db
tx := copyDB.Begin(opts...)
b := &Adapter{db: tx}
b := a.Copy()
// copy enforcer to set the new adapter with transaction tx
copyEnforcer := e
copyEnforcer.SetAdapter(b)
@@ -697,10 +732,15 @@ func (a *Adapter) Transaction(e casbin.IEnforcer, fc func(casbin.IEnforcer) erro

// RemovePolicies removes multiple policy rules from the storage.
func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
return a.RemovePoliciesCtx(context.Background(), sec, ptype, rules)
}

// RemovePoliciesCtx removes multiple policy rules from the storage.
func (a *Adapter) RemovePoliciesCtx(ctx context.Context, sec string, ptype string, rules [][]string) error {
return a.db.Transaction(func(tx *gorm.DB) error {
for _, rule := range rules {
line := a.savePolicyLine(ptype, rule)
if err := a.rawDelete(tx, line); err != nil { //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html
if err := a.rawDelete(ctx, tx, line); err != nil { //can't use db.Delete as we're not using primary key https://gorm.io/docs/update.html
}
}
return nil
@@ -709,12 +749,17 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err

// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
return a.RemoveFilteredPolicyCtx(context.Background(), sec, ptype, fieldIndex, fieldValues...)
}

// RemoveFilteredPolicyCtx removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicyCtx(ctx context.Context, sec string, ptype string, fieldIndex int, fieldValues ...string) error {
line := a.getTableInstance()

line.Ptype = ptype

if fieldIndex == -1 {
return a.rawDelete(a.db, *line)
return a.rawDelete(ctx, a.db, *line)
}

err := checkQueryField(fieldValues)
@@ -740,7 +785,7 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}
err = a.rawDelete(a.db, *line)
err = a.rawDelete(ctx, a.db, *line)
return err
}

@@ -754,7 +799,7 @@ func checkQueryField(fieldValues []string) error {
return errors.New("the query field cannot all be empty string (\"\"), please check")
}

func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
func (a *Adapter) rawDelete(ctx context.Context, db *gorm.DB, line CasbinRule) error {
queryArgs := []interface{}{line.Ptype}

queryStr := "ptype = ?"
@@ -783,7 +828,7 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
queryArgs = append(queryArgs, line.V5)
}
args := append([]interface{}{queryStr}, queryArgs...)
err := db.Delete(a.getTableInstance(), args...).Error
err := db.WithContext(ctx).Delete(a.getTableInstance(), args...).Error
return err
}

@@ -885,7 +930,7 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
return nil, err
}
for i := range newP {
if err := tx.Create(&newP[i]).Error; err != nil {
if err := tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&newP[i]).Error; err != nil {
tx.Rollback()
return nil, err
}
@@ -900,6 +945,21 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
return oldPolicies, tx.Commit().Error
}

func (a *Adapter) Copy() *Adapter {
oriAdapter := a.db
return &Adapter{
db: oriAdapter,
transactionMu: a.transactionMu,
driverName: a.driverName,
dataSourceName: a.dataSourceName,
databaseName: a.databaseName,
tablePrefix: a.tablePrefix,
tableName: a.tableName,
dbSpecified: a.dbSpecified,
isFiltered: a.isFiltered,
}
}

// Preview Pre-checking to avoid causing partial load success and partial failure deep
func (a *Adapter) Preview(rules *[]CasbinRule, model model.Model) error {
j := 0
Loading