Skip to content

Commit

Permalink
feat:support same database has multiple tables with different shardin…
Browse files Browse the repository at this point in the history
…gkey rules
  • Loading branch information
liuxinjie committed Aug 9, 2024
1 parent 5a63663 commit 61167de
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
4 changes: 2 additions & 2 deletions examples/order.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ func main() {
)`)
}

middleware := sharding.Register(sharding.Config{
middleware := sharding.Register(map[any]sharding.Config{"orders": {
ShardingKey: "user_id",
NumberOfShards: 64,
PrimaryKeyGenerator: sharding.PKSnowflake,
}, "orders")
}})
db.Use(middleware)

// this record will insert to orders_02
Expand Down
14 changes: 6 additions & 8 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ type Sharding struct {
querys sync.Map
snowflakeNodes []*snowflake.Node

_config Config
_tables []any
_configs map[any]Config

mutex sync.RWMutex
}
Expand Down Expand Up @@ -103,24 +102,23 @@ type Config struct {
PrimaryKeyGeneratorFn func(tableIdx int64) int64
}

func Register(config Config, tables ...any) *Sharding {
func Register(configs map[any]Config) *Sharding {
return &Sharding{
_config: config,
_tables: tables,
_configs: configs,
}
}

func (s *Sharding) compile() error {
if s.configs == nil {
s.configs = make(map[string]Config)
}
for _, table := range s._tables {
for table, config := range s._configs {
if t, ok := table.(string); ok {
s.configs[t] = s._config
s.configs[t] = config
} else {
stmt := &gorm.Statement{DB: s.DB}
if err := stmt.Parse(table); err == nil {
s.configs[stmt.Table] = s._config
s.configs[stmt.Table] = config
} else {
return err
}
Expand Down
10 changes: 5 additions & 5 deletions sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func init() {
},
}

middleware = Register(shardingConfig, &Order{})
middlewareNoID = Register(shardingConfigNoID, &Order{})
middleware = Register(map[any]Config{&Order{}: shardingConfig})
middlewareNoID = Register(map[any]Config{&Order{}: shardingConfigNoID})

fmt.Println("Clean only tables ...")
dropTables()
Expand Down Expand Up @@ -391,7 +391,7 @@ func TestPKSnowflake(t *testing.T) {
})
}
shardingConfig.PrimaryKeyGenerator = PKSnowflake
middleware := Register(shardingConfig, &Order{})
middleware := Register(map[any]Config{&Order{}: shardingConfig})
db.Use(middleware)

node, _ := snowflake.NewNode(0)
Expand All @@ -412,7 +412,7 @@ func TestPKPGSequence(t *testing.T) {
DisableForeignKeyConstraintWhenMigrating: true,
})
shardingConfig.PrimaryKeyGenerator = PKPGSequence
middleware := Register(shardingConfig, &Order{})
middleware := Register(map[any]Config{&Order{}: shardingConfig})
db.Use(middleware)

db.Exec("SELECT setval('" + pgSeqName("orders") + "', 42)")
Expand All @@ -430,7 +430,7 @@ func TestPKMySQLSequence(t *testing.T) {
DisableForeignKeyConstraintWhenMigrating: true,
})
shardingConfig.PrimaryKeyGenerator = PKMySQLSequence
middleware := Register(shardingConfig, &Order{})
middleware := Register(map[any]Config{&Order{}: shardingConfig})
db.Use(middleware)

db.Exec("UPDATE `" + mySQLSeqName("orders") + "` SET id = 42")
Expand Down

0 comments on commit 61167de

Please sign in to comment.