Skip to content

Commit

Permalink
Merge pull request #19 from xmgtony/develop
Browse files Browse the repository at this point in the history
✨ 添加多service方法事务支持
  • Loading branch information
xmgtony authored Jun 29, 2023
2 parents a831df3 + 6becb7f commit dc74e7c
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 66 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

### 更新日志

*2023-06-29:增加事务支持,多个service方法可以放在一个事务里执行,使用方式参考 internal/service/tx_demo.go 文件。*

*2023-03-04:调整依赖注入wire.go实现,wire解决复杂依赖较为困难,每次对代码有破坏性更改,改为使用传参解决。*

*2022-05-27:添加新的演示功能,用户记账,用于演示脚手架多个router,handler,service的使用,详情见使用文档,升级viper版本,解决依赖安全问题。*
Expand Down
4 changes: 2 additions & 2 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package main

import (
"apiserver-gin/internal/middleware"
"apiserver-gin/internal/repo/mysql"
"apiserver-gin/pkg/config"
"apiserver-gin/pkg/db"
"apiserver-gin/pkg/log"
"apiserver-gin/pkg/version"
"apiserver-gin/server"
Expand All @@ -19,7 +19,7 @@ func main() {
// 加载配置文件
c := config.Load(appOpt.ConfigFilePath)
log.InitLogger(&c.LogConfig, c.AppName) // 日志
ds := db.NewDefaultMysql(c.DBConfig) // 创建数据库链接,使用默认的实现方式
ds := mysql.NewDefaultMysql(c.DBConfig) // 创建数据库链接,使用默认的实现方式
// 创建HTTPServer
srv := server.NewHttpServer(config.GlobalConfig)
srv.RegisterOnShutdown(func() {
Expand Down
3 changes: 1 addition & 2 deletions cmd/wire_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions internal/repo/mysql/account_bill.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ func NewAccountBillRepo(_ds db.IDataSource) *accountBillRepo {
}

func (ab *accountBillRepo) Save(ctx context.Context, bill *model.AccountBill) error {
return ab.ds.Master().Create(bill).Error
return ab.ds.Master(ctx).Create(bill).Error
}

func (ab *accountBillRepo) SelectListByUserId(ctx context.Context, userId int64) ([]model.AccountBill, error) {
var accountBills []model.AccountBill
err := ab.ds.Master().Where("user_id = ?", userId).Find(&accountBills).Error
err := ab.ds.Master(ctx).Where("user_id = ?", userId).Find(&accountBills).Error
if err != nil {
return nil, err
}
Expand Down
74 changes: 74 additions & 0 deletions internal/repo/mysql/mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// author: xmgtony
// date: 2023-06-29 14:47
// version:

package mysql

import (
"apiserver-gin/pkg/config"
"apiserver-gin/pkg/db"
"context"
"gorm.io/gorm"
)

// var _ IDataSource = new(*defaultMysqlDataSource) 也可
var _ db.IDataSource = (*defaultMysqlDataSource)(nil)

// defaultMysqlDataSource 默认mysql数据源实现
type defaultMysqlDataSource struct {
master *gorm.DB // 定义私有属性,用来持有主库链接,防止每次创建,创建后直接返回该变量。
slave *gorm.DB // 同上,从库链接
}

func (d *defaultMysqlDataSource) Master(ctx context.Context) *gorm.DB {
// 事物, 根据事物的key取出tx
tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB)
if ok {
return tx
}
if d.master == nil {
panic("The [master] connection is nil, Please initialize it first.")
}
return d.master
}

func (d *defaultMysqlDataSource) Slave(ctx context.Context) *gorm.DB {
tx, ok := ctx.Value(contextTxKey{}).(*gorm.DB)
if ok {
return tx
}
if d.slave == nil {
panic("The [slave] connection is nil, Please initialize it first.")
}
return d.slave
}

func (d *defaultMysqlDataSource) Close() {
// 关闭主库链接
if d.master != nil {
m, err := d.master.DB()
if err != nil {
_ = m.Close()
}
}
// 关闭从库链接
if d.slave != nil {
s, err := d.slave.DB()
if err != nil {
_ = s.Close()
}
}
}

func NewDefaultMysql(c config.DBConfig) *defaultMysqlDataSource {
return &defaultMysqlDataSource{
master: db.GetMysqlConn(
c.Username,
c.Password,
c.Host,
c.Port,
c.Dbname,
c.MaximumPoolSize,
c.MaximumIdleSize),
}
}
3 changes: 3 additions & 0 deletions internal/repo/mysql/provider_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ package mysql

import (
"apiserver-gin/internal/repo"
"apiserver-gin/pkg/db"
"github.com/google/wire"
)

var ProviderSet = wire.NewSet(
NewTransaction,
wire.Bind(new(db.Transaction), new(*transaction)),
NewUserRepo,
wire.Bind(new(repo.UserRepo), new(*userRepo)),
NewAccountBillRepo,
Expand Down
30 changes: 30 additions & 0 deletions internal/repo/mysql/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Created on 2023/3/15.
// @author tony
// email [email protected]
// description 事物控制接口

package mysql

import (
"apiserver-gin/pkg/db"
"context"
"gorm.io/gorm"
)

type contextTxKey struct{}

// 事物默认实现
type transaction struct {
ds db.IDataSource
}

func NewTransaction(_ds db.IDataSource) *transaction {
return &transaction{ds: _ds}
}

func (t *transaction) Execute(ctx context.Context, fn func(ctx context.Context) error) error {
return t.ds.Master(ctx).Transaction(func(tx *gorm.DB) error {
withValue := context.WithValue(ctx, contextTxKey{}, tx)
return fn(withValue)
})
}
6 changes: 3 additions & 3 deletions internal/repo/mysql/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ func NewUserRepo(_ds db.IDataSource) *userRepo {

func (ur *userRepo) GetUserByName(ctx context.Context, name string) (*model.User, error) {
user := &model.User{}
err := ur.ds.Master().Where("name = ?", name).Find(user).Error
err := ur.ds.Master(ctx).Where("name = ?", name).Find(user).Error
return user, err
}

func (ur *userRepo) GetUserById(ctx context.Context, uid int64) (*model.User, error) {
user := &model.User{}
err := ur.ds.Master().Where("id = ?", uid).Find(user).Error
err := ur.ds.Master(ctx).Where("id = ?", uid).Find(user).Error
return user, err
}

func (ur *userRepo) GetUserByMobile(ctx context.Context, mobile string) (*model.User, error) {
user := &model.User{}
err := ur.ds.Master().
err := ur.ds.Master(ctx).
Where("mobile = ?", mobile).
Where("enabled_status = 1").
First(user).Error
Expand Down
2 changes: 2 additions & 0 deletions internal/service/provider_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ var ProviderSet = wire.NewSet(
wire.Bind(new(UserService), new(*userService)),
NewAccountBillService,
wire.Bind(new(AccountBillService), new(*accountBillService)),
NewTxDemoService,
wire.Bind(new(TxDemoService), new(*txDemoService)),
)
49 changes: 49 additions & 0 deletions internal/service/tx_demo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// author: xmgtony
// date: 2023-06-29 15:00
// version: 事务操作演示

package service

import (
"apiserver-gin/pkg/db"
"context"
)

// TxDemoService txDemo服务接口
type TxDemoService interface {
SaveWithTx(ctx context.Context)
}

// txDemoService 默认实现
type txDemoService struct {
userService UserService
billService AccountBillService
tx db.Transaction
}

func NewTxDemoService(us UserService, bs AccountBillService, tx db.Transaction) *txDemoService {
return &txDemoService{
userService: us,
billService: bs,
tx: tx,
}
}

func (tds *txDemoService) SaveWithTx(ctx context.Context) {
err := tds.tx.Execute(ctx, func(context context.Context) error {
// TODO 这里只是举例,实际请根据业务执行多个service操作
// 操作1
// tds.userService.Save(context, user)
// 操作2
// tds.billService.Save(context, bill)
//if (条件1) {
// 返回err则回滚事务
// return err
//}
return nil
})
if err != nil {
// 处理error
return
}
}
62 changes: 5 additions & 57 deletions pkg/db/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,75 +6,23 @@
package db

import (
"apiserver-gin/pkg/config"
"context"
"fmt"
"time"

"gorm.io/driver/mysql"
"gorm.io/gorm"
)

// var _ IDataSource = new(*defaultMysqlDataSource) 也可
var _ IDataSource = (*defaultMysqlDataSource)(nil)

// IDataSource 定义数据库数据源接口,按照业务需求可以返回主库链接Master和从库链接Slave
type IDataSource interface {
Master() *gorm.DB
Slave() *gorm.DB
Master(ctx context.Context) *gorm.DB
Slave(ctx context.Context) *gorm.DB
Close()
}

// defaultMysqlDataSource 默认mysql数据源实现
type defaultMysqlDataSource struct {
master *gorm.DB // 定义私有属性,用来持有主库链接,防止每次创建,创建后直接返回该变量。
slave *gorm.DB // 同上,从库链接
}

func (d *defaultMysqlDataSource) Master() *gorm.DB {
if d.master == nil {
panic("The [master] connection is nil, Please initialize it first.")
}
return d.master
}

func (d *defaultMysqlDataSource) Slave() *gorm.DB {
if d.master == nil {
panic("The [slave] connection is nil, Please initialize it first.")
}
return d.slave
}

func (d *defaultMysqlDataSource) Close() {
// 关闭主库链接
if d.master != nil {
m, err := d.master.DB()
if err != nil {
_ = m.Close()
}
}
// 关闭从库链接
if d.slave != nil {
s, err := d.slave.DB()
if err != nil {
_ = s.Close()
}
}
}

func NewDefaultMysql(c config.DBConfig) *defaultMysqlDataSource {
return &defaultMysqlDataSource{
master: connect(
c.Username,
c.Password,
c.Host,
c.Port,
c.Dbname,
c.MaximumPoolSize,
c.MaximumIdleSize),
}
}

func connect(user, password, host, port, dbname string, maxPoolSize, maxIdle int) *gorm.DB {
// GetMysqlConn 创建Mysql链接
func GetMysqlConn(user, password, host, port, dbname string, maxPoolSize, maxIdle int) *gorm.DB {
dsn := fmt.Sprintf("%s:%s@(%s:%s)/%s?charset=utf8&parseTime=True&loc=Local",
user,
password,
Expand Down
13 changes: 13 additions & 0 deletions pkg/db/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// author: xmgtony
// date: 2023-06-29 14:38
// version: 事务接口

package db

import "context"

// Transaction 事物接口
type Transaction interface {
// Execute 执行一个事务方法,func为一个需要保证事务完整性的业务方法
Execute(ctx context.Context, fn func(ctx context.Context) error) error
}

0 comments on commit dc74e7c

Please sign in to comment.