下面直接给出基于 MongoDB 官方 Go 驱动（go.mongodb.org/mongo-driver）的封装代码：
// +build !test
package mongo
import (
	"context"
	"fmt"
	"runtime/debug"
	"sync"
	"time"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)
// BaseEntity is implemented by every model stored through MongoClient:
// Create assigns the entity's id via SetId, and Modify reads it back via GetId.
type BaseEntity interface {
	// SetId stores id on the entity.
	SetId(id string)
	// GetId returns the entity's current id.
	GetId() string
}
// PageFilter selects documents for Count and List and controls how List
// orders and pages them.
type PageFilter struct {
	// SortBy names the field List sorts on and SortMode is the direction
	// (MongoDB uses 1 for ascending and -1 for descending).
	SortBy   string
	SortMode int8

	// Limit and Skip, when non-nil, cap and offset the selected documents.
	Limit, Skip *int64

	// Filter holds field/value conditions a document must match.
	Filter map[string]interface{}

	// RegexFiler maps field names to regular-expression patterns those
	// fields must match. The name is a misspelling of "RegexFilter" and is
	// kept as-is so existing callers keep compiling.
	RegexFiler map[string]string
}
// MongoClient bundles a driver client with the context that its methods
// hand to every database call.
type MongoClient struct {
	Client *mongo.Client   // underlying driver client
	Ctx    context.Context // context passed to each driver call made by the methods in this file
}

// Mongo is the package-wide client created by init; it stays nil when
// mongo.Connect returns an error.
var Mongo *MongoClient
// init connects the package-level Mongo client to the URI found under the
// "mongo::uri" configuration key. If mongo.Connect returns an error, Mongo is
// left nil.
func init() {
	// once is redundant inside init (the runtime runs each init function a
	// single time) but harmless.
	var once sync.Once
	once.Do(func() {
		// The timeout bounds only the Connect call. cancel fires when this
		// closure returns, so ctx must NOT be stored for later operations:
		// every call made with it afterwards would fail as cancelled.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		client, err := mongo.Connect(ctx, options.Client().ApplyURI(Conf.String("mongo::uri")))
		if err != nil {
			// NOTE(review): the connect error is not reported anywhere; Mongo
			// stays nil and callers should check Mongo != nil before use.
			return
		}
		Mongo = &MongoClient{
			Client: client,
			// Long-lived, never-cancelled context for all later operations.
			Ctx: context.Background(),
		}
	})
}
func (m *MongoClient) Create(collection string, e BaseEntity) (error, string) {
var err error
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
debug.PrintStack()
}
}
}()
collections := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
e.SetId(UUID())
if cid, err := collections.InsertOne(m.Ctx, e); err == nil {
return nil, cid.InsertedID.(primitive.ObjectID).Hex()
}
return err, ""
}
// Get fetches the document whose _id is the ObjectID encoded by id (the hex
// string returned by Create) and decodes it into a BaseEntity.
//
// Errors are returned for:
//   - an id that is not a valid ObjectID hex string;
//   - a failed lookup or decode (including mongo.ErrNoDocuments when nothing
//     matches);
//   - a recovered panic.
//
// On error the entity is nil.
//
// NOTE(review): the decode target is a nil BaseEntity interface, so the bson
// registry has no concrete type to build unless a decoder for BaseEntity is
// registered somewhere; verify this works for your models.
func (m *MongoClient) Get(collection, id string) (err error, e BaseEntity) {
	defer func() {
		if r := recover(); r != nil {
			if perr, ok := r.(error); ok {
				err = perr
			} else {
				debug.PrintStack()
				err = fmt.Errorf("mongo: Get panicked: %v", r)
			}
			e = nil
		}
	}()
	// Reject malformed ids instead of silently querying for the zero ObjectID.
	objID, err := primitive.ObjectIDFromHex(id)
	if err != nil {
		return err, nil
	}
	coll := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
	var doc BaseEntity
	if err = coll.FindOne(m.Ctx, bson.M{"_id": objID}).Decode(&doc); err != nil {
		return err, nil
	}
	return nil, doc
}
// GetOne finds the first document in collection whose "Id" field equals id
// and returns it decoded into a generic interface{} value.
//
// Errors are returned for a failed lookup or decode (including
// mongo.ErrNoDocuments when nothing matches) and for a recovered panic. On
// error the result is nil.
//
// NOTE(review): without bson struct tags the driver stores a Go field named
// Id under the lowercase key "id", so the "Id" key below may never match;
// confirm against how the entities are tagged.
func (m *MongoClient) GetOne(collection, id string) (err error, e interface{}) {
	defer func() {
		if r := recover(); r != nil {
			if perr, ok := r.(error); ok {
				err = perr
			} else {
				debug.PrintStack()
				err = fmt.Errorf("mongo: GetOne panicked: %v", r)
			}
			e = nil
		}
	}()
	coll := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
	var doc interface{}
	if err = coll.FindOne(m.Ctx, bson.M{"Id": id}).Decode(&doc); err != nil {
		return err, nil
	}
	return nil, doc
}
func (m *MongoClient) Count(collection string, filter PageFilter) (err error, c int64) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
debug.PrintStack()
}
}
}()
if filter.RegexFiler != nil {
for k, v := range filter.RegexFiler {
filter.Filter[k] = primitive.Regex{Pattern: v, Options: ""}
}
}
collections := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
collections.CountDocuments(m.Ctx, filter.Filter, &options.CountOptions{Skip: filter.Skip, Limit: filter.Limit})
return
}
// List returns the documents in collection that match filter, each decoded
// into a generic interface{} value.
//
// The query is built from filter.Filter plus filter.RegexFiler, with regex
// entries overriding same-named Filter keys. filter.Limit and filter.Skip page
// the results. When filter.SortBy is non-empty, results are sorted on it with
// filter.SortMode as the direction.
//
// The caller's Filter map is not modified. On error the slice is nil.
func (m *MongoClient) List(collection string, filter PageFilter) (err error, e []interface{}) {
	defer func() {
		if r := recover(); r != nil {
			if perr, ok := r.(error); ok {
				err = perr
			} else {
				debug.PrintStack()
				err = fmt.Errorf("mongo: List panicked: %v", r)
			}
			e = nil
		}
	}()
	// Build a fresh query map (see Count): avoids nil-map writes and caller
	// mutation, and sends only the conditions, not the whole PageFilter.
	query := make(bson.M, len(filter.Filter)+len(filter.RegexFiler))
	for k, v := range filter.Filter {
		query[k] = v
	}
	for k, v := range filter.RegexFiler {
		query[k] = primitive.Regex{Pattern: v, Options: ""}
	}
	opts := &options.FindOptions{Limit: filter.Limit, Skip: filter.Skip}
	// An empty field name is not a valid sort key, so only sort when asked.
	if filter.SortBy != "" {
		opts.Sort = bson.M{filter.SortBy: filter.SortMode}
	}
	coll := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
	cur, err := coll.Find(m.Ctx, query, opts)
	if err != nil {
		// cur is nil here; closing it would panic.
		return err, nil
	}
	// The Close error is ignored: by then results are fully read or an error
	// is already being returned.
	defer cur.Close(m.Ctx)
	var docs []interface{}
	for cur.Next(m.Ctx) {
		var doc interface{}
		if err = cur.Decode(&doc); err != nil {
			return err, nil
		}
		docs = append(docs, doc)
	}
	// Next returning false can mean an error rather than exhaustion.
	if err = cur.Err(); err != nil {
		return err, nil
	}
	return nil, docs
}
func (m *MongoClient) Delete(collection, id string) (error, bool) {
var err error
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
debug.PrintStack()
}
}
}()
collections := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
objID, _ := primitive.ObjectIDFromHex(id)
result, err := collections.DeleteOne(m.Ctx, bson.M{"_id": objID})
return err, result.DeletedCount == 1
// result, err := collections.DeleteMany(ctx, bson.M{"phone": primitive.Regex{Pattern: "456", Options: ""}})
}
func (m *MongoClient) Modify(collection string, e BaseEntity) (error, bool) {
var err error
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
debug.PrintStack()
}
}
}()
collections := m.Client.Database(Conf.String("mongo::db")).Collection(collection)
// collections.UpdateOne
// collections.UpdateMany
objID, _ := primitive.ObjectIDFromHex(e.GetId())
result, err := collections.ReplaceOne(m.Ctx, bson.M{"_id": objID}, e)
return err, result.ModifiedCount == 1
}
以一个简单的 model 为例：model 需要先实现 BaseEntity 接口（GetId 和 SetId 两个方法）。下面定义一个名为 Test 的 model：
// Test is a sample model. Only *Test satisfies BaseEntity, because SetId
// needs a pointer receiver to modify the stored Id.
type Test struct {
	Id            string
	Name, Creator string
	CreateAt      int64
}

// GetId returns the model's Id field.
func (t Test) GetId() string { return t.Id }

// SetId overwrites the model's Id field with id.
func (t *Test) SetId(id string) { t.Id = id }
调用上面定义好的 client 进行数据库操作。下面以创建一条记录为例（SetId 使用指针接收者，所以只有 *Test 实现了 BaseEntity，调用时需要传入 &t）：
var t Test
Mongo.Create("LicenseReview", &t)
如有疑问，可加站长微信联系（站长并非本文作者）。