概述
在编码过程中,尤其是在编写 dbo 层的数据库操作时,经常需要用到事务。很长一段时间里,我一直使用如下的常规写法:
// Create inserts article and increments its author's lottery_num inside a
// single transaction, attaching the article's tags along the way.
// Associations are created/updated automatically: a tag whose name already
// exists is reused, otherwise a new Tag (Name, UserId = author) is attached
// through Association("Tags").Append.
//
// article must be a non-nil *models.Article. It returns (true, nil) only when
// the transaction commits; on any error the transaction is rolled back and
// (false, err) is returned.
func (a *ArticleRepo) Create(article interface{}) (bool, error) {
	// Validate the input before opening a transaction: the unchecked type
	// assertion used to panic with the transaction still open.
	ac, ok := article.(*models.Article)
	if !ok || ac == nil {
		return false, errors.New("create article: expected a non-nil *models.Article")
	}

	tx := a.db.Begin()
	if err := tx.Error; err != nil {
		return false, err
	}
	// If anything below panics, roll back before the panic propagates so the
	// transaction (and its connection) is not left open.
	defer func() {
		if r := recover(); r != nil {
			tx.Rollback()
			panic(r)
		}
	}()

	if err := tx.Create(ac).Error; err != nil {
		tx.Rollback()
		return false, err
	}

	var tags []*models.Tag
	for _, t := range ac.Tags {
		tag := &models.Tag{}
		if err := tx.Where("name = ? ", t.Name).First(tag).Error; err != nil {
			// Only "record not found" means the tag is new; any other error is
			// a real failure and must abort the transaction.
			if !gorm.IsRecordNotFoundError(err) {
				tx.Rollback()
				return false, err
			}
			tag.Name = t.Name
			tag.UserId = ac.AuthorID
		}
		tags = append(tags, tag)
	}

	user := &models.User{}
	if err := tx.Where("id = ?", ac.AuthorID).First(user).Error; err != nil {
		tx.Rollback()
		if gorm.IsRecordNotFoundError(err) {
			// The author does not exist: ask the caller to log in.
			return false, errors.New("请登陆")
		}
		// Don't disguise database failures as a login problem.
		return false, err
	}
	if err := tx.Model(user).UpdateColumn("lottery_num",
		gorm.Expr("lottery_num + ?", 1)).Error; err != nil {
		tx.Rollback()
		return false, err
	}
	if err := tx.Model(ac).Association("Tags").Append(tags).Error; err != nil {
		tx.Rollback()
		return false, err
	}
	// A failed Commit already finishes the underlying database/sql transaction,
	// so a follow-up Rollback would only return ErrTxDone.
	if err := tx.Commit().Error; err != nil {
		return false, err
	}
	return true, nil
}
可以看到,这种写法需要在每个出错分支里手动调用 db.Rollback(),代码重复,而且很容易遗漏。
于是,在参考了较多常见写法之后,我改用“匿名函数”(回调)的方式:把开启事务、提交和回滚的公共流程提取出来,业务函数只负责主体逻辑。这样既省去了大量重复代码,也避免了忘记 rollback 或 commit。具体用法以 transaction_test.go 为准:
NewTransaction 接收上下文 ctx、数据源 db 和业务函数 fn:先开启事务并执行 fn;在 defer 中用 recover 捕获 panic 并将其转换为 error,对错误进行包装(wrap);最后若有错误则回滚,否则提交。
transaction.go:
package dbx
import (
"context"
"fmt"
"github.com/pkg/errors"
)
// NewTransaction begins a transaction on db, runs fn with it, and then either
// commits or rolls back.
//
// The transaction is committed only when fn returns normally with a nil
// error. It is rolled back when fn returns a non-nil error, panics, or leaves
// without returning at all (for example via runtime.Goexit, which t.FailNow
// uses in tests). A recovered panic is not re-raised. Instead it becomes the
// returned error, annotated with "recover". A rollback failure is appended to
// the error message while the original cause is kept for errors.Cause.
func NewTransaction(ctx context.Context, db Database, fn func(ctx context.Context, tx Transaction) error) (err error) {
	tx, err := db.Begin()
	if err != nil {
		return errors.Wrap(err, "begin")
	}

	// returned is set only after fn hands back control. If it is still false
	// in the deferred func and there is no panic value, fn exited abnormally.
	// In that case err is still nil, and committing would persist partial work.
	returned := false
	defer func() {
		if r := recover(); r != nil {
			var ok bool
			if err, ok = r.(error); !ok {
				err = fmt.Errorf("%v", r)
			}
			err = errors.WithMessage(err, "recover")
		} else if !returned {
			// runtime.Goexit, or panic(nil) before Go 1.21: recover yields nil.
			err = errors.New("transaction aborted: fn did not return")
		}

		if err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				err = errors.WithMessagef(err, "rollback %v", rbErr)
			}
			return
		}
		// errors.Wrap returns nil when Commit succeeds.
		err = errors.Wrap(tx.Commit(), "tx commit")
	}()

	err = fn(ctx, tx)
	returned = true
	return err
}
sqlconn.go:
package dbx
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
)
// Database is the connection abstraction used by NewTransaction: it can open
// a Transaction and check connectivity with Ping.
type Database interface {
	Ping() error
	Begin() (Transaction, error)
}

// Transaction is an open database transaction. Commit and Rollback finish it.
// The query helpers mirror the *sqlx.Tx methods of the same names, in plain
// and context-aware variants.
type Transaction interface {
	Commit() error
	Rollback() error

	Exec(query string, args ...interface{}) (sql.Result, error)
	Get(dest interface{}, query string, args ...interface{}) error
	Select(dest interface{}, query string, args ...interface{}) error

	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}
// DatabaseSQLX adapts *sqlx.DB to the Database interface. Ping is promoted
// from the embedded *sqlx.DB.
type DatabaseSQLX struct {
	*sqlx.DB
}

// Begin starts a transaction with Beginx and wraps it as a Transaction.
// On failure it returns a nil Transaction with the error, rather than a
// wrapper around a nil *sqlx.Tx that would panic when used.
func (db *DatabaseSQLX) Begin() (Transaction, error) {
	tx, err := db.Beginx()
	if err != nil {
		return nil, err
	}
	return &Tx{Tx: tx}, nil
}
// Tx adapts *sqlx.Tx to the Transaction interface.
//
// Every Transaction method, including Commit and Rollback, is promoted from
// the embedded *sqlx.Tx, so no hand-written forwarding methods are needed.
// t.Commit() and t.Rollback() keep working exactly as before.
type Tx struct {
	*sqlx.Tx
}
mysql.go:
package dbx
import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
)
// DefaultMySQLDSN is a fmt format for a MySQL data source name. Its five verbs
// are filled, in order, with user, password, host, port and database name.
// NOTE(review): charset=utf8 is MySQL's 3-byte utf8 (utf8mb3), which cannot
// store 4-byte characters such as emoji; consider utf8mb4 if that matters.
const DefaultMySQLDSN = "%s:%s@tcp(%s:%d)/%s?charset=utf8&parseTime=true"

// ConnectMySQL builds a DSN by applying format to data, passing user,
// password, host, port and database name in that order. It then opens a
// connection through sqlx with the "mysql" driver and configures the pool.
// On success it stores the handle in the package-level DB.
//
// The "mysql" driver is assumed to be registered elsewhere in the package
// (presumably via a blank driver import). TODO confirm.
//
// On any failure DB is left unchanged, the freshly opened pool is closed, and
// the error is returned with a stack trace.
func ConnectMySQL(format string, data DSN) (err error) {
	dsn := fmt.Sprintf(format, data.User, data.Password, data.Host, data.Port, data.DBName)
	db, err := sqlx.Connect("mysql", dsn)
	if err != nil {
		return errors.WithStack(err)
	}
	// Pool limits: at most 30 open connections, up to 5 kept idle.
	db.SetMaxIdleConns(5)
	db.SetMaxOpenConns(30)

	conn := &DatabaseSQLX{DB: db}
	if err = conn.Ping(); err != nil {
		// Don't publish, or leak, a pool that failed its health check.
		_ = db.Close()
		return errors.WithStack(err)
	}
	DB = conn
	return nil
}
db.go:
// DSN holds the connection parameters that ConnectMySQL formats into a data
// source name: network address first, then credentials and database name.
type DSN struct {
	Host string
	Port int

	User, Password, DBName string
}
transaction_test.go:
// TestNewTransaction checks that NewTransaction commits when fn succeeds, and
// rolls back without committing when fn returns an error or panics. It covers
// panics with both error and non-error values.
//
// It relies on MockDB, resetTrans and the rollback/commit flags defined
// elsewhere in the test package. They are presumably shared package-level
// state, so these subtests should not be made parallel.
func TestNewTransaction(t *testing.T) {
	ctx := context.Background()
	mockDB := &MockDB{}

	t.Run("success", func(t *testing.T) {
		defer resetTrans()
		err := NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error {
			return nil
		})
		assert.NoError(t, err)
		assert.False(t, rollback)
		assert.True(t, commit)
	})

	t.Run("return error", func(t *testing.T) {
		defer resetTrans()
		returnErr := errors.New("return error")
		err := NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error {
			return errors.WithStack(returnErr)
		})
		assert.True(t, rollback)
		assert.False(t, commit)
		assert.Equal(t, returnErr, errors.Cause(err))
	})

	t.Run("panic error", func(t *testing.T) {
		defer resetTrans()
		panicErr := errors.New("panic error")
		err := NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error {
			panic(panicErr)
		})
		assert.True(t, rollback)
		assert.False(t, commit)
		assert.Equal(t, panicErr, errors.Cause(err))
	})

	// A panic value that is not an error must still roll back and surface
	// as an error carrying the panic's text.
	t.Run("panic non-error value", func(t *testing.T) {
		defer resetTrans()
		err := NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error {
			panic("panic string")
		})
		assert.True(t, rollback)
		assert.False(t, commit)
		if assert.Error(t, err) {
			assert.Contains(t, err.Error(), "panic string")
		}
	})
}
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于