Golang自定义基于gin框架的Session中间件

FredricZhu · · 1053 次点击 · · 开始浏览    
这是一篇较早发布的文章(原页面缺失创建日期),其中的信息可能已经有所发展或发生改变。

工程结构如下


(图片缺失:工程结构截图。从下文代码可知,工程包含 main.go 以及 gsession 包中的 session.go、memory.go 和 redis.go。)

其原理是利用 cookie 保存 sessionID,再通过 sessionID 获取每个用户对应的 Session。
main.go 测试代码:

package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/zhuge20100104/gin_session/gsession"
)

// main starts a demo gin server whose /incr endpoint counts how many times
// the current client has called it, keeping the counter in a Redis-backed
// session.
func main() {
    r := gin.Default()
    mgrObj, err := gsession.CreateSessionMgr(gsession.Redis, "localhost:6379")
    if err != nil {
        // log.Fatalf exits the process, so nothing after it would run.
        log.Fatalf("Create manager obj failed, err: %v\n", err)
    }
    sm := gsession.SessionMiddleware(mgrObj, gsession.Options{
        Path:     "/",
        Domain:   "127.0.0.1",
        MaxAge:   120,
        Secure:   false,
        HttpOnly: true,
    })
    r.Use(sm)
    r.GET("/incr", func(c *gin.Context) {
        // The middleware stores the session under SessionContextName.
        session := c.MustGet(gsession.SessionContextName).(gsession.Session)
        fmt.Printf("%#v\n", session)
        count := 0
        v, err := session.Get("count")
        if err != nil {
            log.Printf("get count from session failed, err: %v\n", err)
        } else if n, ok := v.(int); ok {
            count = n + 1
        } else {
            // Reset instead of panicking when "count" holds a value of some
            // other type (a bare v.(int) assertion would crash the handler).
            log.Printf("unexpected type %T for count in session, resetting\n", v)
        }
        session.Set("count", count)
        session.Save()
        c.String(http.StatusOK, "count:%v", count)
    })
    if err := r.Run(); err != nil {
        log.Fatalf("run server failed, err: %v\n", err)
    }
}

session.go

package gsession

import (
    "fmt"
    "log"

    "github.com/gin-gonic/gin"
)

// SessionMgrType names a session storage backend; CreateSessionMgr accepts
// Memory or Redis.
type SessionMgrType string

const (
    // SessionCookieName is the name of the cookie that carries the session ID.
    SessionCookieName = "session_id"
    // SessionContextName is the gin.Context key under which the middleware
    // stores the request's Session.
    SessionContextName                = "session"
    // Memory selects the in-process session backend.
    Memory             SessionMgrType = "memory"
    // Redis selects the Redis-backed session backend.
    Redis              SessionMgrType = "redis"
)

// Session is a per-user key/value store identified by a session ID.
type Session interface {
    // ID returns the session's identifier.
    ID() string
    // Load fills the session data from the backing store (Redis for the
    // Redis implementation; a no-op for the memory implementation).
    Load() error
    // Get returns the value stored under the key, or an error if the key is absent.
    Get(string) (interface{}, error)
    // Set stores the value under the key.
    Set(string, interface{})
    // Del removes the key and its value.
    Del(string)
    // Save persists the data to the backing store (Redis for the Redis
    // implementation; a no-op for the memory implementation).
    Save()
    // SetExpired sets the expiry, in seconds, that the Redis implementation
    // applies when saving; the memory implementation stores it but does not use it.
    SetExpired(int)
}

// SessionMgr creates, looks up and discards Sessions for one storage backend.
type SessionMgr interface {
    // Init prepares the backend. The Redis manager connects to addr, with
    // options being an optional password and DB index; the memory manager
    // ignores its arguments.
    Init(addr string, options ...string) error
    // GetSession returns the existing session for a session ID, or an error
    // if it cannot be found or loaded.
    GetSession(string) (Session, error)
    // CreateSession returns a new, empty session with a freshly generated ID.
    CreateSession() Session
    // Clear drops the manager's in-process entry for the given session ID.
    // For the memory manager this discards the session entirely.
    Clear(string)
}

// Options holds the cookie attributes the middleware uses when it issues the
// session cookie.
type Options struct {
    Path   string
    Domain string
    // MaxAge is the lifetime of the session-ID cookie; the middleware also
    // passes it to Session.SetExpired.
    // MaxAge=0 means no 'Max-Age' attribute specified.
    // MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
    // MaxAge>0 means Max-Age attribute present and given in seconds.
    MaxAge   int
    Secure   bool
    HttpOnly bool
}

// CreateSessionMgr builds a session manager for the named backend and
// initialises it with addr and the backend-specific options.
//
// For Redis, options are an optional password followed by an optional DB
// index; the memory backend ignores addr and options.
//
// It returns an error for an unknown backend name or when Init fails. In
// both cases the returned manager is nil.
func CreateSessionMgr(name SessionMgrType, addr string, options ...string) (SessionMgr, error) {
    var sm SessionMgr
    switch name {
    case Memory:
        sm = NewMemSessionMgr()
    case Redis:
        sm = NewRedisSessionMgr()
    default:
        return nil, fmt.Errorf("unsupported session manager type %q", name)
    }
    if err := sm.Init(addr, options...); err != nil {
        return nil, fmt.Errorf("initialising %s session manager: %w", name, err)
    }
    return sm, nil
}

// SessionMiddleware returns a gin handler that attaches a Session to every
// request.
//
// The session ID is read from the SessionCookieName cookie. When the cookie
// is missing, or sm cannot return a session for that ID, a fresh session is
// created. The session's expiry is set to options.MaxAge and the session is
// stored in the gin context under SessionContextName. The cookie is then
// (re)issued with the attributes from options. sm.Clear is deferred for the
// session ID, so it runs after the downstream handlers.
//
// NOTE(review): MemSessionMgr.Clear deletes the session from its map, so with
// the memory backend every request starts with a new, empty session. The
// Redis manager's Clear only drops its in-process map entry.
func SessionMiddleware(sm SessionMgr, options Options) gin.HandlerFunc {
    // lookupSession resolves the request's session and the ID to put in the
    // cookie, falling back to a newly created session on any lookup failure.
    lookupSession := func(c *gin.Context) (Session, string) {
        cookieID, cookieErr := c.Cookie(SessionCookieName)
        if cookieErr != nil {
            log.Printf("get session_id from cookie failed, err:%v\n", cookieErr)
            fresh := sm.CreateSession()
            return fresh, fresh.ID()
        }
        log.Printf("SessionId: %v\n", cookieID)
        existing, getErr := sm.GetSession(cookieID)
        if getErr != nil {
            log.Printf("Get session by %s failed, err: %v\n", cookieID, getErr)
            fresh := sm.CreateSession()
            return fresh, fresh.ID()
        }
        return existing, cookieID
    }

    return func(c *gin.Context) {
        sess, id := lookupSession(c)
        sess.SetExpired(options.MaxAge)
        c.Set(SessionContextName, sess)
        c.SetCookie(SessionCookieName, id, options.MaxAge, options.Path, options.Domain, options.Secure, options.HttpOnly)
        defer sm.Clear(id)
        c.Next()
    }
}

memory.go

package gsession

import (
    "fmt"
    "sync"

    uuid "github.com/satori/go.uuid"
)

// memSession is the in-memory Session implementation. Its data map is
// guarded by rwLock; Load and Save do nothing because nothing is persisted.
type memSession struct {
    // id is the session identifier.
    id string
    // data holds the session's key/value pairs.
    data map[string]interface{}
    // expired records the value passed to SetExpired; this type never reads it.
    expired int
    // rwLock guards data.
    rwLock sync.RWMutex
}

// NewMemSession returns an empty in-memory session with the given id.
func NewMemSession(id string) *memSession {
    s := new(memSession)
    s.id = id
    s.data = make(map[string]interface{}, 8)
    return s
}

// ID returns the session identifier.
func (m *memSession) ID() string { return m.id }

// Load is a no-op for in-memory sessions and always succeeds.
func (m *memSession) Load() error { return nil }

// Get returns the value stored under key, or an error when key is absent.
func (m *memSession) Get(key string) (interface{}, error) {
    m.rwLock.RLock()
    defer m.rwLock.RUnlock()
    if v, found := m.data[key]; found {
        return v, nil
    }
    return nil, fmt.Errorf("Invalid key")
}

// Set stores value under key.
func (m *memSession) Set(key string, value interface{}) {
    m.rwLock.Lock()
    defer m.rwLock.Unlock()
    m.data[key] = value
}

// Del removes key from the session.
func (m *memSession) Del(key string) {
    m.rwLock.Lock()
    defer m.rwLock.Unlock()
    delete(m.data, key)
}

// Save is a no-op for in-memory sessions.
func (m *memSession) Save() {}

// SetExpired records the expiry value; the in-memory session does not use it.
func (m *memSession) SetExpired(expired int) { m.expired = expired }

// MemSessionMgr keeps sessions in an in-process map keyed by session ID.
// Every access to the map is guarded by rwLock.
type MemSessionMgr struct {
    session map[string]Session
    rwLock  sync.RWMutex
}

// NewMemSessionMgr returns an empty MemSessionMgr.
func NewMemSessionMgr() *MemSessionMgr {
    return &MemSessionMgr{
        session: make(map[string]Session, 1024),
    }
}

// Init is a no-op for the memory backend; addr and options are ignored.
func (m *MemSessionMgr) Init(addr string, options ...string) error {
    return nil
}

// GetSession returns the session stored under sessionID, or an error when
// no such session is registered.
func (m *MemSessionMgr) GetSession(sessionID string) (Session, error) {
    m.rwLock.RLock()
    defer m.rwLock.RUnlock()
    sd, ok := m.session[sessionID]
    if !ok {
        return nil, fmt.Errorf("Invalid session id")
    }
    return sd, nil
}

// CreateSession creates a session with a random UUID v4 ID, registers it in
// the manager and returns it.
func (m *MemSessionMgr) CreateSession() Session {
    sd := NewMemSession(uuid.NewV4().String())
    // The map is shared with GetSession/Clear, which may run concurrently for
    // other requests, so the write must hold the write lock.
    m.rwLock.Lock()
    defer m.rwLock.Unlock()
    m.session[sd.ID()] = sd
    return sd
}

// Clear removes the session with the given ID from the manager.
func (m *MemSessionMgr) Clear(sessionID string) {
    m.rwLock.Lock()
    defer m.rwLock.Unlock()
    delete(m.session, sessionID)
}

redis.go

package gsession

import (
    "bytes"
    "encoding/gob"
    "fmt"
    "log"
    "strconv"
    "sync"
    "time"

    "github.com/go-redis/redis"
    uuid "github.com/satori/go.uuid"
)

// redisSession is a Session whose values are held in memory and persisted
// to Redis as a gob-encoded map stored under the session ID.
type redisSession struct {
    // id is the session ID; it is also the Redis key.
    id string
    // data is the in-memory copy of the session values.
    data map[string]interface{}
    // modifyFlag is true when data changed since the last successful Save.
    modifyFlag bool
    // expired is the TTL, in seconds, applied to the Redis key on Save.
    expired int
    // rwLock guards data and modifyFlag.
    rwLock sync.RWMutex
    client *redis.Client
}

// NewRedisSession returns an empty session with the given ID that reads
// from and writes to Redis through client.
func NewRedisSession(id string, client *redis.Client) (session Session) {
    session = &redisSession{
        id:     id,
        data:   make(map[string]interface{}, 8),
        client: client,
    }
    return
}

// ID returns the session ID.
func (r *redisSession) ID() string {
    return r.id
}

// Load reads the value stored in Redis under the session ID and gob-decodes
// it into data. It returns the Redis error (go-redis reports a missing key
// as an error too) or the decode error.
func (r *redisSession) Load() (err error) {
    data, err := r.client.Get(r.id).Bytes()
    if err != nil {
        log.Printf("get session data from redis by %s failed, err: %v\n", r.id, err)
        return
    }

    dec := gob.NewDecoder(bytes.NewBuffer(data))
    err = dec.Decode(&r.data)
    if err != nil {
        log.Printf("gob decode session data failed, err: %v\n", err)
        return
    }
    return
}

// Get returns the value stored under key, or an error when key is absent.
func (r *redisSession) Get(key string) (value interface{}, err error) {
    r.rwLock.RLock()
    defer r.rwLock.RUnlock()
    value, ok := r.data[key]
    if !ok {
        err = fmt.Errorf("invalid key")
        return
    }
    return
}

// Set stores value under key and marks the session as modified.
func (r *redisSession) Set(key string, value interface{}) {
    r.rwLock.Lock()
    defer r.rwLock.Unlock()
    r.data[key] = value
    r.modifyFlag = true
}

// Del removes key and marks the session as modified.
func (r *redisSession) Del(key string) {
    r.rwLock.Lock()
    defer r.rwLock.Unlock()
    delete(r.data, key)
    r.modifyFlag = true
}

// SetExpired sets the TTL, in seconds, used by the next Save.
func (r *redisSession) SetExpired(expired int) {
    r.expired = expired
}

// Save writes data to Redis if it changed since the last successful Save.
//
// The Session interface gives Save no error result, so failures are logged.
// modifyFlag stays set on failure so that a later Save can retry.
func (r *redisSession) Save() {
    r.rwLock.Lock()
    defer r.rwLock.Unlock()
    if !r.modifyFlag {
        return
    }
    buf := new(bytes.Buffer)
    if err := gob.NewEncoder(buf).Encode(r.data); err != nil {
        // Do not log.Fatalf here: a value gob cannot encode (for example an
        // unregistered concrete type) must not terminate the whole server.
        log.Printf("gob encode r.data failed, err: %v\n", err)
        return
    }

    if err := r.client.Set(r.id, buf.Bytes(), time.Second*time.Duration(r.expired)).Err(); err != nil {
        log.Printf("set session data to redis by %s failed, err: %v\n", r.id, err)
        return
    }
    log.Printf("set data %v to redis.\n", buf.Bytes())
    r.modifyFlag = false
}

// redisSessionMgr creates and loads Redis-backed sessions. It also records
// the sessions it hands out in an in-process map guarded by rwLock.
// GetSession always reloads from Redis rather than reading this map.
type redisSessionMgr struct {
    session map[string]Session
    rwLock  sync.RWMutex
    client  *redis.Client
}

// NewRedisSessionMgr returns a manager without a Redis client; Init must be
// called before use.
func NewRedisSessionMgr() *redisSessionMgr {
    return &redisSessionMgr{
        session: make(map[string]Session, 1024),
    }
}

// Init creates the Redis client for addr and verifies the connection with
// PING.
//
// options are optional: options[0] is the password and options[1] the
// numeric DB index. Any further options are ignored. A non-numeric DB index
// is reported as an error instead of terminating the process.
func (r *redisSessionMgr) Init(addr string, options ...string) error {
    var (
        password string
        db       int
    )
    if len(options) >= 1 {
        password = options[0]
    }
    if len(options) >= 2 {
        n, err := strconv.Atoi(options[1])
        if err != nil {
            return fmt.Errorf("invalid redis DB param %q: %w", options[1], err)
        }
        db = n
    }

    r.client = redis.NewClient(&redis.Options{
        Addr:     addr,
        Password: password,
        DB:       db,
    })

    if err := r.client.Ping().Err(); err != nil {
        return fmt.Errorf("ping redis at %s: %w", addr, err)
    }
    return nil
}

// GetSession loads the session stored in Redis under sessionID and records
// it in the manager. It returns a nil session and the error when loading fails.
func (r *redisSessionMgr) GetSession(sessionID string) (Session, error) {
    sd := NewRedisSession(sessionID, r.client)
    if err := sd.Load(); err != nil {
        return nil, err
    }

    // Writing the map needs the write lock; a read lock would allow
    // concurrent writers.
    r.rwLock.Lock()
    defer r.rwLock.Unlock()
    r.session[sessionID] = sd
    return sd, nil
}

// CreateSession creates an empty session with a random UUID v4 ID, records
// it in the manager and returns it.
func (r *redisSessionMgr) CreateSession() Session {
    sd := NewRedisSession(uuid.NewV4().String(), r.client)
    r.rwLock.Lock()
    defer r.rwLock.Unlock()
    r.session[sd.ID()] = sd
    return sd
}

// Clear removes the manager's in-process entry for sessionID; the data
// stored in Redis is left untouched.
func (r *redisSessionMgr) Clear(sessionID string) {
    r.rwLock.Lock()
    defer r.rwLock.Unlock()
    delete(r.session, sessionID)
}

程序输出如下:


(图片缺失:程序运行输出截图。)

有疑问加站长微信联系(非本文作者)

本文来自:简书

感谢作者:FredricZhu

查看原文:Golang自定义基于gin框架的Session中间件

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

1053 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传