diff --git a/README.md b/README.md index e049d34..224edc3 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ ### 更新日志 +*2023-06-29:增加事务支持,多个service方法可以放在一个事务里执行,使用方式参考 internal/service/tx_demo.go 文件。* + *2023-03-04:调整依赖注入wire.go实现,wire解决复杂依赖较为困难,每次对代码有破坏性更改,改为使用传参解决。* *2022-05-27:添加新的演示功能,用户记账,用于演示脚手架多个router,handler,service的使用,详情见使用文档,升级viper版本,解决依赖安全问题。* diff --git a/cmd/main.go b/cmd/main.go index 317d174..299f328 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,8 +2,8 @@ package main import ( "apiserver-gin/internal/middleware" + "apiserver-gin/internal/repo/mysql" "apiserver-gin/pkg/config" - "apiserver-gin/pkg/db" "apiserver-gin/pkg/log" "apiserver-gin/pkg/version" "apiserver-gin/server" @@ -19,7 +19,7 @@ func main() { // 加载配置文件 c := config.Load(appOpt.ConfigFilePath) log.InitLogger(&c.LogConfig, c.AppName) // 日志 - ds := db.NewDefaultMysql(c.DBConfig) // 创建数据库链接,使用默认的实现方式 + ds := mysql.NewDefaultMysql(c.DBConfig) // 创建数据库链接,使用默认的实现方式 // 创建HTTPServer srv := server.NewHttpServer(config.GlobalConfig) srv.RegisterOnShutdown(func() { diff --git a/cmd/wire_gen.go b/cmd/wire_gen.go index 0dca0c3..213d613 100644 --- a/cmd/wire_gen.go +++ b/cmd/wire_gen.go @@ -1,6 +1,6 @@ // Code generated by Wire. DO NOT EDIT. -//go:generate go run github.com/google/wire/cmd/wire +//go:generate wire //go:build !wireinject // +build !wireinject @@ -20,7 +20,6 @@ import ( // Injectors from wire.go: -// initRouter 初始化router func initRouter(ds db.IDataSource) server.Router { userRepo := mysql.NewUserRepo(ds) userService := service.NewUserService(userRepo) diff --git a/internal/repo/mysql/account_bill.go b/internal/repo/mysql/account_bill.go index 8301ba2..080ee17 100644 --- a/internal/repo/mysql/account_bill.go +++ b/internal/repo/mysql/account_bill.go @@ -25,12 +25,12 @@ func NewAccountBillRepo(_ds db.IDataSource) *accountBillRepo { } func (ab *accountBillRepo) Save(ctx context.Context, bill *model.AccountBill) error { - return ab.ds.Master().Create(bill).Error + return ab.ds.Master(ctx).Create(bill).Error } func (ab *accountBillRepo) SelectListByUserId(ctx context.Context, userId int64) ([]model.AccountBill, error) { var accountBills []model.AccountBill - err := ab.ds.Master().Where("user_id = ?", userId).Find(&accountBills).Error + err := ab.ds.Master(ctx).Where("user_id = ?", userId).Find(&accountBills).Error if err != nil { return nil, err } diff --git a/internal/repo/mysql/mysql.go b/internal/repo/mysql/mysql.go new file mode 100644 index 0000000..01c180e --- /dev/null +++ b/internal/repo/mysql/mysql.go @@ -0,0 +1,74 @@ +// author: xmgtony +// date: 2023-06-29 14:47 +// version: + +package mysql + +import ( + "apiserver-gin/pkg/config" + "apiserver-gin/pkg/db" + "context" + "gorm.io/gorm" +) + +// var _ IDataSource = new(*defaultMysqlDataSource) 也可 +var _ db.IDataSource = (*defaultMysqlDataSource)(nil) + +// defaultMysqlDataSource 默认mysql数据源实现 +type defaultMysqlDataSource struct { + master *gorm.DB // 定义私有属性,用来持有主库链接,防止每次创建,创建后直接返回该变量。 + slave *gorm.DB // 同上,从库链接 +} + +func (d *defaultMysqlDataSource) Master(ctx context.Context) *gorm.DB { + // 事物, 根据事物的key取出tx + tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB) + if ok { + return tx + } + if d.master == nil { + panic("The [master] connection is nil, Please initialize it first.") + } + return d.master +} + +func (d *defaultMysqlDataSource) Slave(ctx context.Context) *gorm.DB { + tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB) + if ok { + return tx + } + if d.slave == nil { + panic("The [slave] connection is nil, Please initialize it first.") + } + return d.slave +} + +func (d *defaultMysqlDataSource) Close() { + // 关闭主库链接 + if d.master != nil { + m, err := d.master.DB() + if err != nil { + _ = m.Close() + } + } + // 关闭从库链接 + if d.slave != nil { + s, err := d.slave.DB() + if err != nil { + _ = s.Close() + } + } +} + +func NewDefaultMysql(c config.DBConfig) *defaultMysqlDataSource { + return &defaultMysqlDataSource{ + master: db.GetMysqlConn( + c.Username, + c.Password, + c.Host, + c.Port, + c.Dbname, + c.MaximumPoolSize, + c.MaximumIdleSize), + } +} diff --git a/internal/repo/mysql/provider_set.go b/internal/repo/mysql/provider_set.go index 3a9ad4e..7388cb3 100644 --- a/internal/repo/mysql/provider_set.go +++ b/internal/repo/mysql/provider_set.go @@ -7,10 +7,13 @@ package mysql import ( "apiserver-gin/internal/repo" + "apiserver-gin/pkg/db" "github.com/google/wire" ) var ProviderSet = wire.NewSet( + NewTransaction, + wire.Bind(new(db.Transaction), new(*transaction)), NewUserRepo, wire.Bind(new(repo.UserRepo), new(*userRepo)), NewAccountBillRepo, diff --git a/internal/repo/mysql/tx.go b/internal/repo/mysql/tx.go new file mode 100644 index 0000000..91c5d03 --- /dev/null +++ b/internal/repo/mysql/tx.go @@ -0,0 +1,30 @@ +// Created on 2023/3/15. +// @author tony +// email xmgtony@gmail.com +// description 事物控制接口 + +package mysql + +import ( + "apiserver-gin/pkg/db" + "context" + "gorm.io/gorm" +) + +type contextTxKey struct{} + +// 事物默认实现 +type transaction struct { + ds db.IDataSource +} + +func NewTransaction(_ds db.IDataSource) *transaction { + return &transaction{ds: _ds} +} + +func (t *transaction) Execute(ctx context.Context, fn func(ctx context.Context) error) error { + return t.ds.Master(ctx).Transaction(func(tx *gorm.DB) error { + withValue := context.WithValue(ctx, contextTxKey{}, tx) + return fn(withValue) + }) +} diff --git a/internal/repo/mysql/user.go b/internal/repo/mysql/user.go index b503edd..7d974ff 100644 --- a/internal/repo/mysql/user.go +++ b/internal/repo/mysql/user.go @@ -21,19 +21,19 @@ func NewUserRepo(_ds db.IDataSource) *userRepo { func (ur *userRepo) GetUserByName(ctx context.Context, name string) (*model.User, error) { user := &model.User{} - err := ur.ds.Master().Where("name = ?", name).Find(user).Error + err := ur.ds.Master(ctx).Where("name = ?", name).Find(user).Error return user, err } func (ur *userRepo) GetUserById(ctx context.Context, uid int64) (*model.User, error) { user := &model.User{} - err := ur.ds.Master().Where("id = ?", uid).Find(user).Error + err := ur.ds.Master(ctx).Where("id = ?", uid).Find(user).Error return user, err } func (ur *userRepo) GetUserByMobile(ctx context.Context, mobile string) (*model.User, error) { user := &model.User{} - err := ur.ds.Master(). + err := ur.ds.Master(ctx). Where("mobile = ?", mobile). Where("enabled_status = 1"). First(user).Error diff --git a/internal/service/provider_set.go b/internal/service/provider_set.go index f5f68f8..315ecdc 100644 --- a/internal/service/provider_set.go +++ b/internal/service/provider_set.go @@ -12,4 +12,6 @@ var ProviderSet = wire.NewSet( wire.Bind(new(UserService), new(*userService)), NewAccountBillService, wire.Bind(new(AccountBillService), new(*accountBillService)), + NewTxDemoService, + wire.Bind(new(TxDemoService), new(*txDemoService)), ) diff --git a/internal/service/tx_demo.go b/internal/service/tx_demo.go new file mode 100644 index 0000000..1431c61 --- /dev/null +++ b/internal/service/tx_demo.go @@ -0,0 +1,49 @@ +// author: xmgtony +// date: 2023-06-29 15:00 +// version: 事务操作演示 + +package service + +import ( + "apiserver-gin/pkg/db" + "context" +) + +// TxDemoService txDemo服务接口 +type TxDemoService interface { + SaveWithTx(ctx context.Context) +} + +// txDemoService 默认实现 +type txDemoService struct { + userService UserService + billService AccountBillService + tx db.Transaction +} + +func NewTxDemoService(us UserService, bs AccountBillService, tx db.Transaction) *txDemoService { + return &txDemoService{ + userService: us, + billService: bs, + tx: tx, + } +} + +func (tds *txDemoService) SaveWithTx(ctx context.Context) { + err := tds.tx.Execute(ctx, func(context context.Context) error { + // TODO 这里只是举例,实际请根据业务执行多个service操作 + // 操作1 + // tds.userService.Save(context, user) + // 操作2 + // tds.billService.Save(context, bill) + //if (条件1) { + // 返回err则回滚事务 + // return err + //} + return nil + }) + if err != nil { + // 处理error + return + } +} diff --git a/pkg/db/mysql.go b/pkg/db/mysql.go index 2646372..ce312a0 100644 --- a/pkg/db/mysql.go +++ b/pkg/db/mysql.go @@ -6,7 +6,7 @@ package db import ( - "apiserver-gin/pkg/config" + "context" "fmt" "time" @@ -14,67 +14,15 @@ import ( "gorm.io/gorm" ) -// var _ IDataSource = new(*defaultMysqlDataSource) 也可 -var _ IDataSource = (*defaultMysqlDataSource)(nil) - // IDataSource 定义数据库数据源接口,按照业务需求可以返回主库链接Master和从库链接Slave type IDataSource interface { - Master() *gorm.DB - Slave() *gorm.DB + Master(ctx context.Context) *gorm.DB + Slave(ctx context.Context) *gorm.DB Close() } -// defaultMysqlDataSource 默认mysql数据源实现 -type defaultMysqlDataSource struct { - master *gorm.DB // 定义私有属性,用来持有主库链接,防止每次创建,创建后直接返回该变量。 - slave *gorm.DB // 同上,从库链接 -} - -func (d *defaultMysqlDataSource) Master() *gorm.DB { - if d.master == nil { - panic("The [master] connection is nil, Please initialize it first.") - } - return d.master -} - -func (d *defaultMysqlDataSource) Slave() *gorm.DB { - if d.master == nil { - panic("The [slave] connection is nil, Please initialize it first.") - } - return d.slave -} - -func (d *defaultMysqlDataSource) Close() { - // 关闭主库链接 - if d.master != nil { - m, err := d.master.DB() - if err != nil { - _ = m.Close() - } - } - // 关闭从库链接 - if d.slave != nil { - s, err := d.slave.DB() - if err != nil { - _ = s.Close() - } - } -} - -func NewDefaultMysql(c config.DBConfig) *defaultMysqlDataSource { - return &defaultMysqlDataSource{ - master: connect( - c.Username, - c.Password, - c.Host, - c.Port, - c.Dbname, - c.MaximumPoolSize, - c.MaximumIdleSize), - } -} - -func connect(user, password, host, port, dbname string, maxPoolSize, maxIdle int) *gorm.DB { +// GetMysqlConn 创建Mysql链接 +func GetMysqlConn(user, password, host, port, dbname string, maxPoolSize, maxIdle int) *gorm.DB { dsn := fmt.Sprintf("%s:%s@(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local", user, password, diff --git a/pkg/db/tx.go b/pkg/db/tx.go new file mode 100644 index 0000000..1db6467 --- /dev/null +++ b/pkg/db/tx.go @@ -0,0 +1,13 @@ +// author: xmgtony +// date: 2023-06-29 14:38 +// version: 事务接口 + +package db + +import "context" + +// Transaction 事物接口 +type Transaction interface { + // Execute 执行一个事务方法,func为一个需要保证事务完整性的业务方法 + Execute(ctx context.Context, fn func(ctx context.Context) error) error +}