概述
在编码过程中,特别是在写 dbo 层的数据库操作时,我们经常会遇到需要使用事务的地方。在很长一段时间里,我都是按下面这种常规写法来处理的:
//在新建和更新的时候注意自动创建关联和更新 func (a *ArticleRepo) Create(article interface{}) (bool, error) { db := a.db.Begin() var tags []*models.Tag user := &models.User{} ac := article.(*models.Article) if err := db.Create(&ac).Error; err != nil { db.Rollback() return false, err } for _, t := range ac.Tags { tag := &models.Tag{} if err := db.Where("name = ? ", t.Name).First(&tag).Error; err != nil { tag.Name = t.Name tag.UserId = ac.AuthorID } tags = append(tags, tag) } if err := db.Where("id = ?", ac.AuthorID).First(&user).Error; err != nil { db.Rollback() return false, errors.New("请登陆") } if err := db.Model(&user).UpdateColumn("lottery_num", gorm.Expr("lottery_num + ?", 1)).Error; err != nil { db.Rollback() return false, err } if err := db.Model(&ac).Association("Tags"). Append(tags).Error; err != nil { db.Rollback() return false, err } if err := db.Commit().Error; err != nil { db.Rollback() return false, err } return true, nil }
我发现,在这样的处理过程中,几乎每个出错分支都要写一遍 db.Rollback(),既繁琐又容易遗漏。
于是,在见识过比较多的常用写法之后,我改用“匿名函数”的写法:把事务的主要流程提取出来,调用方只需完成主体业务,这样既省下了大量重复代码,也避免了忘记 rollback 和 commit。目前的用法以 transaction_test.go 中为准:
首先,创建事务的函数接收上下文 ctx、数据源 db 和业务方法 fn;然后在 defer 中 recover 可能出现的 panic,并将其包装(wrap)成错误,再根据是否出错决定回滚还是提交。
transaction.go:
package dbx

import (
	"context"
	"fmt"

	"github.com/pkg/errors"
)

// NewTransaction begins a transaction on db, runs fn with it, and finishes
// the transaction according to the outcome.
//
//   - If Begin fails, the error is returned wrapped with "begin" and fn is
//     never called.
//   - If fn returns an error or panics, the transaction is rolled back. A
//     panic is converted into an error annotated with "recover"; a rollback
//     failure is appended to the message of the error being returned.
//   - Otherwise the transaction is committed, and a commit failure is
//     returned wrapped with "tx commit".
//
// ctx is only handed through to fn.
func NewTransaction(ctx context.Context, db Database, fn func(ctx context.Context, tx Transaction) error) (err error) {
	tx, beginErr := db.Begin()
	if beginErr != nil {
		return errors.Wrap(beginErr, "begin")
	}

	// The deferred closure owns the named result: it sees whatever fn
	// returned (or panicked with) and decides between rollback and commit.
	defer func() {
		if recovered := recover(); recovered != nil {
			switch cause := recovered.(type) {
			case error:
				err = cause
			default:
				err = fmt.Errorf("%v", recovered)
			}
			err = errors.WithMessage(err, "recover")
		}

		if err == nil {
			// errors.Wrap yields nil for a nil error, so a clean commit
			// leaves the result nil.
			err = errors.Wrap(tx.Commit(), "tx commit")
			return
		}

		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			err = errors.WithMessagef(err, "rollback %v", rollbackErr)
		}
	}()

	err = fn(ctx, tx)
	return err
}
sqlconn.go:
package dbx

import (
	"context"
	"database/sql"

	"github.com/jmoiron/sqlx"
)

type (
	// Database is the minimal handle needed to start transactions and to
	// check connectivity.
	Database interface {
		Begin() (Transaction, error)
		Ping() error
	}

	// Transaction is the set of operations available inside a transaction:
	// finishing it (Commit/Rollback) and running queries, with or without a
	// context.
	Transaction interface {
		Commit() error
		Rollback() error
		Get(dst interface{}, query string, args ...interface{}) error
		Select(dst interface{}, query string, args ...interface{}) error
		Exec(query string, args ...interface{}) (sql.Result, error)
		ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
		GetContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error
		SelectContext(ctx context.Context, dst interface{}, query string, args ...interface{}) error
	}
)

// Compile-time checks that the sqlx-backed types satisfy the interfaces.
var (
	_ Database    = (*DatabaseSQLX)(nil)
	_ Transaction = (*Tx)(nil)
)

// DatabaseSQLX implements Database on top of an *sqlx.DB.
type DatabaseSQLX struct {
	*sqlx.DB
}

// Begin starts a transaction with Beginx and wraps it in a Tx. On failure it
// returns a nil Transaction together with the error, so callers never receive
// a Tx that wraps a nil *sqlx.Tx.
func (db *DatabaseSQLX) Begin() (Transaction, error) {
	tx, err := db.Beginx()
	if err != nil {
		return nil, err
	}
	return &Tx{Tx: tx}, nil
}

// Tx implements Transaction on top of an *sqlx.Tx.
type Tx struct {
	*sqlx.Tx
}

// Commit commits the underlying sqlx transaction.
func (t *Tx) Commit() error {
	return t.Tx.Commit()
}

// Rollback rolls back the underlying sqlx transaction.
func (t *Tx) Rollback() error {
	return t.Tx.Rollback()
}
mysql.go
package dbx

import (
	"fmt"

	"github.com/jmoiron/sqlx"
	"github.com/pkg/errors"
)

// DefaultMySQLDSN is a DSN format string for ConnectMySQL. Its verbs are
// filled, in order, with user, password, host, port and database name.
const DefaultMySQLDSN = "%s:%s@tcp(%s:%d)/%s?charset=utf8&parseTime=true"

// ConnectMySQL formats a DSN from format and data, connects with the "mysql"
// driver, limits the pool to 5 idle and 30 open connections, stores the
// handle in the package-level DB, and finally pings it.
//
// Any connection or ping error is returned with a stack trace attached.
func ConnectMySQL(format string, data DSN) (err error) {
	pool, err := sqlx.Connect(
		"mysql",
		fmt.Sprintf(format, data.User, data.Password, data.Host, data.Port, data.DBName),
	)
	if err != nil {
		return errors.WithStack(err)
	}

	pool.SetMaxIdleConns(5)
	pool.SetMaxOpenConns(30)

	// Publish the handle first; the ping below goes through the published DB.
	DB = &DatabaseSQLX{DB: pool}
	err = DB.Ping()
	return errors.WithStack(err)
}
db.go
// DSN holds the individual parts of a database connection string. The values
// are substituted into a DSN format string when connecting.
type DSN struct {
	// Host is the database server address.
	Host string
	// Port is the database server port.
	Port int
	// User and Password are the login credentials; DBName is the database
	// (schema) to select.
	User, Password, DBName string
}
transaction_test.go
func TestNewTransaction(t *testing.T) { ctx := context.Background() mockDB := &MockDB{} t.Run("success", func(t *testing.T) { defer resetTrans() assert.Equal(t, nil, NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error { return nil })) assert.Equal(t, false, rollback) assert.Equal(t, true, commit) }) t.Run("return error", func(t *testing.T) { defer resetTrans() returnErr := errors.New("return error") err := NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error { return errors.WithStack(returnErr) }) assert.Equal(t, true, rollback) assert.Equal(t, false, commit) assert.Equal(t, returnErr, errors.Cause(err)) }) t.Run("panic error", func(t *testing.T) { defer resetTrans() panicErr := errors.New("panic error") err := NewTransaction(ctx, mockDB, func(ctx context.Context, tx Transaction) error { panic(panicErr) }) assert.Equal(t, true, rollback) assert.Equal(t, false, commit) assert.Equal(t, panicErr, errors.Cause(err)) }) }
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于