在上一篇文章,我分享了自己在新增和更新的场景下,自己使用gorm的一些心得和扩展。本文,我将分享一些在查询的方面的心得。
首先,我把查询按照涉及到的表的数量分为:
- 单表查询
- 多表查询
按照查询范围又可以分为:
- 查询一个
- 范围查询
- 查询一组
- 有序查询
- 查询前几个
- 分页查询
在日常使用中,单表查询占据了多半的场景,把这部分的代码按照查询范围做一些封装,可以大大减少冗余的代码。
单表查询
于是,我仿照gorm API的风格,做了如下的封装:
ps:以下例子均以假定已定义user对象
查询一个
func (dw *DBExtension) GetOne(result interface{}, query interface{}, args ...interface{}) (found bool, err error) {
var (
tableNameAble TableNameAble
ok bool
)
if tableNameAble, ok = query.(TableNameAble); !ok {
if tableNameAble, ok = result.(TableNameAble); !ok {
return false, errors.New("neither the query nor result implement TableNameAble")
}
}
err = dw.Table(tableNameAble.TableName()).Where(query, args...).First(result).Error
if err == gorm.ErrRecordNotFound {
dw.logger.LogInfoc("mysql", fmt.Sprintf("record not found for query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
return false, nil
}
if err != nil {
dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, the query is %+v, args are %+v", tableNameAble.TableName(), query, args))
return false, err
}
return true, nil
}
复制代码
这段值得说明的就是对查询不到数据时的处理,gorm是报了gorm.ErrRecordNotFound的error, 我是对这个错误做了特殊处理,用found这个boolean值表述这个特殊状态。
调用代码如下:
condition := User{Id:1}
result := User{}
if found, err := dw.GetOne(&result, condition); !found {
//not found
if err != nil {
// has error
return err
}
}
复制代码
也可以这样写,更加灵活的指定的查询条件:
result := User{}
if found, err := dw.GetOne(&result, "id = ?", 1); !found {
//not found
if err != nil {
// has error
return err
}
}
复制代码
两种写法执行的语句都是:
select * from test.user where id = 1
复制代码
范围查询
针对四种范国查询,我做了如下封装:
func (dw *DBExtension) GetList(result interface{}, query interface{}, args ...interface{}) error {
return dw.getListCore(result, "", 0, 0, query, args)
}
func (dw *DBExtension) GetOrderedList(result interface{}, order string, query interface{}, args ...interface{}) error {
return dw.getListCore(result, order, 0, 0, query, args)
}
func (dw *DBExtension) GetFirstNRecords(result interface{}, order string, limit int, query interface{}, args ...interface{}) error {
return dw.getListCore(result, order, limit, 0, query, args)
}
func (dw *DBExtension) GetPageRangeList(result interface{}, order string, limit, offset int, query interface{}, args ...interface{}) error {
return dw.getListCore(result, order, limit, offset, query, args)
}
func (dw *DBExtension) getListCore(result interface{}, order string, limit, offset int, query interface{}, args []interface{}) error {
var (
tableNameAble TableNameAble
ok bool
)
if tableNameAble, ok = query.(TableNameAble); !ok {
// type Result []*Item{}
// result := &Result{}
resultType := reflect.TypeOf(result)
if resultType.Kind() != reflect.Ptr {
return errors.New("result is not a pointer")
}
sliceType := resultType.Elem()
if sliceType.Kind() != reflect.Slice {
return errors.New("result doesn't point to a slice")
}
// *Item
itemPtrType := sliceType.Elem()
// Item
itemType := itemPtrType.Elem()
elemValue := reflect.New(itemType)
elemValueType := reflect.TypeOf(elemValue)
tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()
if elemValueType.Implements(tableNameAbleType) {
return errors.New("neither the query nor result implement TableNameAble")
}
tableNameAble = elemValue.Interface().(TableNameAble)
}
db := dw.Table(tableNameAble.TableName()).Where(query, args...)
if len(order) != 0 {
db = db.Order(order)
}
if offset > 0 {
db = db.Offset(offset)
}
if limit > 0 {
db = db.Limit(limit)
}
if err := db.Find(result).Error; err != nil {
dw.logger.LogErrorc("mysql", err, fmt.Sprintf("failed to query %s, query is %+v, args are %+v, order is %s, limit is %d", tableNameAble.TableName(), query, args, order, limit))
return err
}
return nil
}
复制代码
为了减少冗余的代码,通用的逻辑写在getListCore函数里,里面用到了一些golang反射的知识。
但只要记得golang的反射和其它语言的反射最大的不同,是golang的反射是基本值而不是类型的,一切就好理解了。
其中的一个小技巧是如何判断一个类型是否实现了某个接口,用到了指向nil的指针。
elemValue := reflect.New(itemType)
elemValueType := reflect.TypeOf(elemValue)
tableNameAbleType := reflect.TypeOf((*TableNameAble)(nil)).Elem()
if elemValueType.Implements(tableNameAbleType) {
return errors.New("neither the query nor result implement TableNameAble")
}
复制代码
关于具体的使用,就不再一一举例子了,熟悉gorm api的同学可以一眼看出。
多表查询
关于多表查询,因为不同场景很难抽取出不同,也就没有再做封装,但是我的经验是优先多使用gorm的方法,而不是自己拼sql。你想要做的gorm都可以实现。
这里,我偷个懒,贴出自己在项目中写的最复杂的一段代码,供各位看官娱乐。
一个复杂的例子
这段代码是从埋点数据的中间表,为了用通用的代码实现不同展示场景下的查询,代码设计的比较灵活,其中涉及了关联多表的查询,按查询条件动态过滤和聚合,还有分页查询的逻辑。
func buildCommonStatisticQuery(tableName, startDate, endDate string) *gorm.DB {
query := models.DB().Table(tableName)
if startDate == endDate || endDate == "" {
query = query.Where("date = ?", startDate)
} else {
query = query.Where("date >= ? and date <= ?", startDate, endDate)
}
return query
}
func buildElementsStatisticQuery(startDate, endDate, elemId string, elemType int32) *gorm.DB {
query := buildCommonStatisticQuery("spotanalysis.element_statistics", startDate, endDate)
if elemId != "" && elemType != 0 {
query = query.Where("element_id = ? and element_type = ?", elemId, elemType)
}
return query
}
func CountElementsStatistics(count *int32, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string ) error {
query := buildElementsStatisticQuery(startDate, endDate, elemId, elemType)
query = whereInstAndApp(query, instId, appId)
if len(groupFields) != 0 {
query = query.Select(fmt.Sprintf("count(distinct(concat(%s)))", strings.Join(groupFields, ",")))
} else {
query = query.Select("count(id)")
}
query = query.Count(count)
return query.Error
}
func GetElementsStatistics(result interface{}, startDate, endDate, instId, appId, elemId string, elemType int32, groupFields []string, orderBy string, ascOrder bool, limit, offset int32) error {
query := buildElementsStatisticQuery(startDate, endDate, elemId, elemType)
if len(groupFields) != 0 {
groupBy := strings.Join(groupFields, "`,`")
groupBy = "`" + groupBy + "`"
query = query.Group(groupBy)
query = havingInstAndApp(query, instId, appId)
sumFields := strings.Join([]string{
"SUM(`element_statistics`.`mp_count`) AS `mp_count`",
"SUM(`element_statistics`.`h5_count`) AS `h5_count`",
"SUM(`element_statistics`.`total_count`) AS `total_count`",
"SUM(`element_statistics`.`collection_count`) AS `collection_count`",
"SUM(`element_statistics`.`mp_share_count`) AS `mp_share_count`",
"SUM(`element_statistics`.`h5_share_count`) AS `h5_share_count`",
"SUM(`element_statistics`.`poster_share_count`) AS `poster_share_count`",
"SUM(`element_statistics`.`total_share_count`) AS `total_share_count`",
}, ",")
query = query.Select(groupBy + "," + sumFields)
} else {
query = whereInstAndApp(query, instId, appId)
}
query = getPagedList(query, orderBy, ascOrder, limit, offset)
return query.Find(result).Error
}
func getPagedList(query *gorm.DB, orderBy string, ascOrder bool, limit , offset int32) *gorm.DB {
if orderBy != "" {
if ascOrder {
orderBy += " asc"
} else {
orderBy += " desc"
}
query = query.Order(orderBy)
}
if offset != 0 {
query = query.Offset(offset)
}
if limit != 0 {
query = query.Limit(limit)
}
return query
}
func whereInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
query = query.Where("inst_id = ?", instId)
if appId != "" {
query = query.Where("app_id = ?", appId)
}
return query
}
func havingInstAndApp(query *gorm.DB, instId string, appId string) *gorm.DB {
query = query.Having("inst_id = ?", instId)
if appId != "" {
query = query.Having("app_id = ?", appId)
}
return query
}
复制代码
感谢各位看官耐心看完,如果本文对你有用,请点个赞~~~
如果能到代码仓库:Github:Ksloveyuan/gorm-ex 给个✩star✩, 楼主就更加感谢了!
有疑问加站长微信联系(非本文作者)