美文网首页
go源码走读-gorm

go源码走读-gorm

作者: 温岭夹糕 | 来源:发表于2024-07-27 21:53 被阅读0次

目录

链接地址

1.gorm使用实例

我们常使用懒加载和惰加载结合完成单例模式

import (
    "sync"

    "gorm.io/driver/mysql"
    "gorm.io/gorm"
)

var (
    db     *gorm.DB
    doOnce sync.Once
    dsn    string = "username:psd@(ip:port)/database?database?timeout=5000ms&readTimeout=5000ms&writeTimeout=5000ms&charset=utf8mb4&parseTime=true&loc=Local"
)

func GetDB() *gorm.DB {
    var err error
    doOnce.Do(func() {
        db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
        if err != nil {
            panic(err)
        }
        // db.Table()
    })
    return db
}

本节接下来不真实连接数据库,而是用sql-mock来mock数据

2.核心DB类

即gorm.DB
首先需要理解会话这个概念,gorm.DB是该库定义的数据库类,所有执行的数据的操作都与这个类有关,以链式调用方式展开(如条件查询)

// Get first matched record
db.Where("name = ?", "jinzhu").First(&user)

每当链式调用后,就会新生成新的DB对象,该对象中存储了一些当前请求特有的状态信息,我们把这种对象叫做“会话”,即记录了执行过程上下文的对象

// gorm 中定义的数据库类
// 所有 orm 的思想
type DB struct {
    // 配置
    *Config
    // 错误
    Error        error
    // 影响的行数
    RowsAffected int64
    // 会话状态信息
    Statement    *Statement
    // 克隆次数
    clone        int
}
  • Statement:一次会话的状态信息,比如请求和响应信息
  • clone: 会话被克隆的次数. 倘若 clone = 1,代表是始祖 DB 实例;倘若 clone > 1,代表是从始祖 DB 克隆出来的会话
  • Error:一次会话执行过程中遇到的错误,一个信息里可能包含多个错误
func (db *DB) AddError(err error) error {
    if err != nil {
        // ...
        if db.Error == nil {
            db.Error = err
        } else {
            db.Error = fmt.Errorf("%v; %w", db.Error, err)
        }
    }
    return db.Error
}

我们看到使用fmt.Errorf和%w配合实现error wrapping(错误的拼接)

2.1 DB的克隆

func (db *DB) getInstance() *DB {
    if db.clone > 0 {
        tx := &DB{Config: db.Config, Error: db.Error}


        // 倘若是首次对 db 进行 clone,则需要构造出一个新的 statement 实例
        if db.clone == 1 {
            // clone with new statement
            tx.Statement = &Statement{
                DB:       tx,
                ConnPool: db.Statement.ConnPool,
                Context:  db.Statement.Context,
                Clauses:  map[string]clause.Clause{},
                Vars:     make([]interface{}, 0, 8),
            }
        // 倘若已经 db clone 过了,则还需要 clone 原先的 statement
        } else {
            // with clone statement
            tx.Statement = db.Statement.clone()
            tx.Statement.DB = tx
        }


        return tx
    }


    return db
}

主要通过对clone字段来判断:

  • clone=1,就克隆出一个新的会话
  • clone>2,就从始祖DB上克隆

2.2Statement

// Statement statement
type Statement struct {
    // 数据库实例
    *DB
    // ...
    // 表名
    Table                string
    // 操作的 po 模型
    Model                interface{}
    // ...
    // 处理结果反序列化到此处
    Dest                 interface{}
    // ...
    // 各种条件语句
    Clauses              map[string]clause.Clause
    
    // ...
    // 是否启用 distinct 模式
    Distinct             bool
    // select 语句
    Selects              []string // selected columns
    // omit 语句
    Omits                []string // omit columns
    // join 
    Joins                []join
    
    // ...
    // 连接池,通常情况下是 database/sql 库下的 *DB  类型.  在 prepare 模式为 gorm.PreparedStmtDB
    ConnPool             ConnPool
    // 操作表的概要信息
    Schema               *schema.Schema
    // 上下文,请求生命周期控制管理
    Context              context.Context
    // 在未查找到数据记录时,是否抛出 recordNotFound 错误
    RaiseErrorOnNotFound bool
    // ...
    // 执行的 sql,调用 state.Build 方法后,会将 sql 各部分文本依次追加到其中. 具体可见 2.5 小节
    SQL                  strings.Builder
    // 存储的变量
    Vars                 []interface{}
    // ...
}

因为要记录上下文,所以字段成员比较多,慢慢来看

2.3 po 模型

orm的思想就是将数据库表映射成一个数据模型(类/结构体),我们将该模型称为po模型(persist object 持久化数据模型),下面就是一个数据表的po类

type Reward struct {
    gorm.Model
    Amount sql.NullInt64 `gorm:"column:amount"`
    Type   string `gorm:"not null"`
    UserID int64  `gorm:"not null"`
}


func (r Reward) TableName() string {
    return "reward"
}

如果po模型声明了TableName方法,则隐式实现了gorm.Tabler接口

type Tabler interface {
    TableName() string
}

解析表时先尝试转为该接口,失败则直接用po模型的结构体名(会经过一定规则转化)当表名

那么问题来了,gorm.Statement.Model也是po模型,那这两个有啥区别?我们先暂时埋下一个疑问

2.4查询流程

如下测试代码

func TestQuery(t *testing.T) {
    db, mock, err := sqlmock.New()
    require.NoError(t, err)
    defer db.Close()
    gdb, err := gorm.Open(mysql.New(mysql.Config{
        SkipInitializeWithVersion: true,
        Conn:                      db,
    }), &gorm.Config{})
    require.NoError(t, err)
    rows := sqlmock.NewRows([]string{"id"}).AddRow(2)
    mock.ExpectQuery("SELECT *").WillReturnRows(rows)
    type Name struct {
        Id string
    }
    var name Name
    res := gdb.First(&name)
    t.Log(res.Error)
    t.Log(res.RowsAffected)
    t.Log(name.Id)
}

我们查看first源码

func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB){
    tx = db.Limit(1).Order(xxxx)
    tx.Statement.RaiseErrorOnNotFound = true
    tx.Statement.Dest = dest
    return tx.callbacks.Query().Execute(tx)
}
  • 先是进行limit和order的链式调用(因为first只返回一条)
  • 然后设置会话属性,最后调用callbacks执行查询方法

2.4.1 执行器processor

callbacks是db的内嵌字段config的成员

type Config struct{
    callbacks  *callbacks
}

type callbacks struct {
    processors map[string]*processor
}

type processor struct {
    db        *DB
    Clauses   []string
    fns       []func(*DB)
    callbacks []*callback
}

它的唯一成员就是gorm框架执行curd操作逻辑时用到的执行器processor,针对curd操作的处理函数会以list的形式聚合在对应类型的processor的fns字段中

    // 对应存储了 crud 等各类操作对应的执行器 processor
    // query -> query processor
    // create -> create processor
    // update -> update processor
    // delete -> delete processor

也就是说调用callbacks.Query查询方法实际就是执行query processor的fns函数成员

func (cs *callbacks) Query() *processor {
    return cs.processors["query"]
}

各类 processor 的初始化是通过 initializeCallbacks 方法完成,该方法是在gorm.Open中执行的

func initializeCallbacks(db *DB) *callbacks {
    return &callbacks{
        processors: map[string]*processor{
            "create": {db: db},
            "query":  {db: db},
            "update": {db: db},
            "delete": {db: db},
            "row":    {db: db},
            "raw":    {db: db},
        },
    }
}

再来细看processor具体结构

type processor struct {
    // 从属的 DB 实例
    db        *DB
    // 拼接 sql 时的关键字顺序. 比如 query 类,固定为 SELECT,FROM,WHERE,GROUP BY, ORDER BY, LIMIT, FOR
    Clauses   []string
    // 对应于 crud 类型的执行函数链
    fns       []func(*DB)
    callbacks []*callback
}

对应的Execute方法

func (p *processor) Execute(db *DB) *DB {
    // call scopes
    var (
        // ...
        stmt = db.Statement
        // ...
    )


    if len(stmt.BuildClauses) == 0 {
        // 根据 crud 类型,对 buildClauses 进行复制,用于后续的 sql 拼接
        stmt.BuildClauses = p.Clauses
        // ...
    }
    
    // ...
    // dest 和 model 相互赋值
    if stmt.Model == nil {
        stmt.Model = stmt.Dest
    } else if stmt.Dest == nil {
        stmt.Dest = stmt.Model
    }


    // 解析 model,获取对应表的 schema 信息
    if stmt.Model != nil {
        // ...
    }


    // 处理 dest 信息,将其添加到 stmt 当中
    if stmt.Dest != nil {
        // ...
    }


    // 执行一系列的 callback 函数,其中最核心的 create/query/update/delete 操作都被包含在其中了. 还包括了一系列前、后处理函数,具体可见第 3 章
    for _, f := range p.fns {
        f(db)
    }


    //...
    return db
}
  1. 获取会话
  2. 从会话中获取条件语句用于后续拼接sql语句
  3. 解析po模型到model中
  4. 处理dest信息,这里指我们传入fist的参数,最终会修改结果赋值
  5. 执行callback函数

我们注意这里的第三步解析po模型到model中

stmt.Parse(stmt.Model)

实际上就是再将po模型具体化,比如提取出表名,字段名,对应的数据库类型信息等保存到model中

2.4.2 条件 clause

一条执行 sql 中,各个部分都属于一个 clause,比如一条 SELECT * FROM reward WHERE id < 10 ORDER by id 的 SQL,其中就包含了 SELECT、FROM、WHERE 和 ORDER 四个 clause.
当使用方通过链式操作克隆 DB时,对应追加的状态信息就会生成一个新的 clause,追加到 statement 对应的 clauses 集合当中. 当请求实际执行时,会取出 clauses 集合,拼接生成完整的 sql 用于执行.
clause本身是一个抽象的interface

// Interface clause interface
type Interface interface {
    // clause 名称
    Name() string
    // 生成对应的 sql 部分
    Build(Builder)
    // 和同类 clause 合并
    MergeClause(*Clause)
}

不同的 clause 有不同的实现类,我们以 SELECT 为例进行展示:

type Select struct {
    // 使用使用 distinct 模式
    Distinct   bool
    // 是否 select 查询指定的列,如 select id,name
    Columns    []Column
    Expression Expression
}

sql语句的拼接是通过调用statement.Build方法实现的,入参对应的是 crud 中某一类 processor 的 BuildClauses.

func (stmt *Statement) Build(clauses ...string) {
    var firstClauseWritten bool


    for _, name := range clauses {
        if c, ok := stmt.Clauses[name]; ok {
            if firstClauseWritten {
                stmt.WriteByte(' ')
            }


            firstClauseWritten = true
            if b, ok := stmt.DB.ClauseBuilders[name]; ok {
                b(c, stmt)
            } else {
                c.Build(stmt)
            }
        }
    }
}

2.4.3 小结

综上我们小结first查询方法:

  1. 通过limit和order追加clause,追加条件过程如下 :
  • 调用getInstance方法克隆出会话
  • 调用addClause方法将条件追加到statement的clauses map中
  1. 设置statement.dest
  2. 获取query类型的执行器processor,调用execute方法执行其中的fns函数链

2.5 Query方法

既然已经知道了fns方法是最终调用方法,那么它是由谁注册的?
自然是一开始的驱动

db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})

我这里就是mysql驱动注册的,看open方法,我们以查询为例搜索注册的查询函数

func Open(dsn string) gorm.Dialector {
    dsnConf, _ := mysql.ParseDSN(dsn)
    return &Dialector{Config: &Config{DSN: dsn, DSNConfig: dsnConf}}
}

func (dialector Dialector) Initialize(db *gorm.DB) (err error){

    // ...完成 crud 类操作 callback 函数的注册
    callbacks.RegisterDefaultCallbacks(db, callbackConfig)
}

func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
    queryCallback := db.Callback().Query()
    queryCallback.Register("gorm:query", Query)
    queryCallback.Register("gorm:preload", Preload)
    queryCallback.Register("gorm:after_query", AfterQuery)
    queryCallback.Clauses = config.QueryClauses
}

func (p *processor) Register(name string, fn func(*DB)) error {
    return (&callback{processor: p}).Register(name, fn)
}

func (c *callback) Register(name string, fn func(*DB)) error {
    c.name = name
    c.handler = fn
    c.processor.callbacks = append(c.processor.callbacks, c)
    return c.processor.compile()
}

顺藤摸瓜,我们找到了query函数

func Query(db *gorm.DB) {
    if db.Error == nil {
        BuildQuerySQL(db)

        if !db.DryRun && db.Error == nil {
            rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
            if err != nil {
                db.AddError(err)
                return
            }
            defer func() {
                db.AddError(rows.Close())
            }()
            gorm.Scan(rows, db, 0)
        }
    }
}
  • 先根据clauses组装sql
  • 完成sql查询类的执行,返回查询到的行数据rows
  • 将结果反序列化到dest中

2.5.1 连接池connPool

connPool 字段,其含义是连接池,和数据库的交互操作都需要依赖它才得以执行. connPool 本身是个 interface,定义如下:

type ConnPool interface {
    PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
    ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
    QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
    QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

connPool 根据是否启用了 prepare 预处理模式,存在不同的实现类版本:

  • 在普通模式下,connPool 的实现类为 database/sql 库下的 DB 类
  • prepare 模式下,connPool 实现类型为 gorm 中定义的 PreparedStmtDB 类

prepare是什么?直观点将就是缓存,mysql5.8之后都直接抛弃了,了解就好

// prepare 模式下的 connPool 实现类. 
type PreparedStmtDB struct {
    // 各 stmt 实例. 其中 key 为 sql 模板,stmt 是对封 database/sql 中 *Stmt 的封装 
    Stmts       map[string]*Stmt
    // ...
    Mux         *sync.RWMutex
    // 内置的 ConnPool 字段通常为 database/sql 中的 *DB
    ConnPool
}

Stmt 类是 gorm 框架对 database/sql 标准库下 Stmt 类的简单封装,两者区别并不大:

type Stmt struct {
    // database/sql 标准库下的 statement
    *sql.Stmt
    // 是否处于事务
    Transaction bool
    // 标识当前 stmt 是否已初始化完成
    prepared    chan struct{}
    prepareErr  error
}

举一反三,剩下的更新/删除/插入也是如此原理

2.6 事务

db.Transaction

func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
    panicked := true


    if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
        // ...
    } else {
        // 开启事务
        tx := db.Begin(opts...)
        if tx.Error != nil {
            return tx.Error
        }


        defer func() {
            // 倘若发生错误或者 panic,则进行 rollback 回滚
            if panicked || err != nil {
                tx.Rollback()
            }
        }()


        // 执行事务内的逻辑
        if err = fc(tx); err == nil {
            panicked = false
            // 指定成功会进行 commit 操作
            return tx.Commit().Error
        }
    }


    panicked = false
    return
}
  • 调用Begin方法启动事务,克隆出一个带事务属性的会话tx
  • 以tx为参数传入闭包函数执行,根据成功与否执行提交或回滚

2.7 了解Prepare

在 PreparedStmtDB.prepare 方法中,会通过加锁 double check 的方式,创建或复用 sql 模板对应的 stmt. 创建 stmt 的操作通过调用 conn.PrepareContext 方法完成.(通常此处的 conn 为 database/sql 库下的 sql.DB)

PreparedStmtDB.prepare 方法核心流程梳理如下:

• 加读锁,然后以 sql 模板为 key,尝试从 db.Stmts map 中获取 stmt 复用

• 倘若 stmt 不存在,则加写锁 double check

• 调用 conn.PrepareContext(...) 方法,创建新的 stmt,并存放到 map 中供后续复用

完整的代码和对应的注释展示如下:

func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
    db.Mux.RLock()
    // 以 sql 模板为 key,优先复用已有的 stmt 
    if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
        db.Mux.RUnlock()
        // 并发场景下,只允许有一个 goroutine 完成 stmt 的初始化操作
        <-stmt.prepared
        if stmt.prepareErr != nil {
            return Stmt{}, stmt.prepareErr
        }


        return *stmt, nil
    }
    db.Mux.RUnlock()


    // 加锁 double check,确认未完成 stmt 初始化则执行初始化操作
    db.Mux.Lock()
    // double check
    if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
        db.Mux.Unlock()
        // wait for other goroutines prepared
        <-stmt.prepared
        if stmt.prepareErr != nil {
            return Stmt{}, stmt.prepareErr
        }


        return *stmt, nil
    }


    // 创建 stmt 实例,并添加到 stmts map 中
    cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
    db.Stmts[query] = &cacheStmt
    // 此时可以提前解锁是因为还通过 channel 保证了其他使用者会阻塞等待初始化操作完成
    db.Mux.Unlock()


    // 所有工作执行完之后会关闭 channel,唤醒其他阻塞等待使用 stmt 的 goroutine
    defer close(cacheStmt.prepared)


    // 调用 *sql.DB 的 prepareContext 方法,创建真正的 stmt
    stmt, err := conn.PrepareContext(ctx, query)
    if err != nil {
        cacheStmt.prepareErr = err
        db.Mux.Lock()
        delete(db.Stmts, query)
        db.Mux.Unlock()
        return Stmt{}, err
    }


    db.Mux.Lock()
    cacheStmt.Stmt = stmt
    db.PreparedSQL = append(db.PreparedSQL, query)
    db.Mux.Unlock()


    return cacheStmt,nil
}

在 prepare 模式下,查询操作通过 PreparedStmtDB.QueryContext(...) 方法实现. 首先通过 PreparedStmtDB.prepare(...) 方法尝试复用 stmt,然后调用 stmt.QueryContext(...) 执行查询操作.同理执行流程也大差不多

相关文章

网友评论

      本文标题:go源码走读-gorm

      本文链接:https://www.haomeiwen.com/subject/jetrhjtx.html