本文主要讲解context库实现,以及如何实现类似的gls(golang local storage)
相关代码文件:
src/context/context.go
src/runtime/runtime2.go
src/runtime/proc.go
src/runtime/stubs.go
参考文献:https://golang.org/doc/asm
测试代码:buptbill220/go-test
背景:先思考一个问题:如果要在一个请求处理过程中,存在多个goroutine处理该请求,一般会有共享当前请求相关的上下文数据的需求,如果让你来实现,会怎么做???
context作用
// Incoming requests to a server should create a Context, and outgoing
// calls to servers should accept a Context. The chain of function
// calls between them must propagate the Context, optionally replacing
// it with a derived Context created using WithCancel, WithDeadline,
// WithTimeout, or WithValue. When a Context is canceled, all
// Contexts derived from it are also canceled.
根据包头部注解,作用如下:
1:对于一个请求(rpc/http等io请求或其他服务),需要产生一个context用于管理各个调用链(goroutine)
2:context可用于存储该请求上下文数据以供各个链环节共享数据
context实现
上图解释了context和goroutine之间的关系,以及context之间的交互
使用特点如下:
1:对于一个请求来说,新生成一个context。如果要共享数据,可设置key,val存储共享数据,调用WithValue函数;如果要控制流程WithCancel,或者WithDeadline或WithTimeout
2:整个context链是树形结构,当前节点存储parent context节点
3:如果中间链节点需要取共享数据,优先从当前context查找,否则递归从parent节点查找
4:整个context链是propagate(类似js冒泡):当前节点被cancel,那么所有子context都必须cancel(自上向下冒泡)
从context类型上看,主要使用的由3种context:
1:共享数据类型的valueContext
2:带cancel类型的cancelContext(手动cancel)
3:带timer类型的timerContext(定时器cancel context)
各个原型定义分别如下:
context接口
type Context interface {
Deadline() (deadline time.Time, ok bool)
Done() <-chan struct{}
Value(key interface{}) interface{}
}
常用实例化context
1:数据共享Context
type valueCtx struct {
Context
key, val interface{}
}
func (c *valueCtx) Value(key interface{}) interface{} {
if c.key == key {
return c.val
}
return c.Context.Value(key)
}
2:支持propagate cancle的value context
// A cancelCtx can be canceled. When canceled, it also cancels any children
// that implement canceler.
type cancelCtx struct {
Context
mu sync.Mutex // protects following fields
done chan struct{} // created lazily, closed by first cancel call
children map[canceler]struct{} // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
}
3:支持超时的cancel context
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
}
定义上简单明了,没什么过多解释。
在这里贴下关键代码关于cancel和timer的主要实现,有助于理解细节
cancel的自顶向下冒泡实现
func (c *cancelCtx) cancel(removeFromParent bool, err error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
if c.done == nil {
c.done = closedchan
} else {
close(c.done)
}
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
child.cancel(false, err)
}
c.children = nil
c.mu.Unlock()
if removeFromParent {
removeChild(c.Context, c)
}
}
带timer的cancel实现
c := &timerCtx{
cancelCtx: newCancelCtx(parent),
deadline: d,
}
propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded) // deadline has already passed
return c, func() { c.cancel(true, Canceled) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(dur, func() {
c.cancel(true, DeadlineExceeded)
})
}
return c, func() { c.cancel(true, Canceled) }
我们看下cancel和timer的接口的异同点
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc)
func WithCancel(parent Context) (ctx Context, cancel CancelFunc)
2者接口都是返回一个context和一个cancel callback,使用代码如下
1:
ctx, cancel := context.WithCancel(ctx, time.Duration(time.Millisecond))
defer cancel()
2:
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(Duration(time.Millisecond)))
//defer cancel()
可以看到纯cacelContext必须手动调用返回的cancel callback;对于timerContext来说,这个不是必须的(timer超时过后会被动调用cancel)
context缺点
其实context作为共享数据来说,实现上并不好,
1:每次调用需要新建一个context
2:取数据时,会存在递归遍历,效率低下
3:多goroutine之间,如果共享必须把parent context作为参数向下传递,很麻烦?
回到最前面的问题,有没有更好的方式实现多个goroutine之间共享数据呢(同一请求)?
我们可以转换成:如果每个请求有唯一的id,每次取数直接用该id取对应上下文数据,从而达到共享数据目的?
那么如何对于一个请求来说如何得到它的唯一id呢?
答案当然是goroutine id
如何实现获取goroutine id呢?
如何获取goroutine id
这里我介绍3中比较简便的方法来实现
首先从源码里我们可以得到goroutine结构定义
type stack struct {
lo uintptr
hi uintptr
}
type g struct {
stack stack // offset known to runtime/cgo
stackguard0 uintptr // offset known to liblink
stackguard1 uintptr // offset known to liblink
_panic *_panic // innermost panic - offset known to liblink
_defer *_defer // innermost defer
m *m // current m; offset known to arm liblink
sched gobuf
syscallsp uintptr // if status==Gsyscall, syscallsp = sched.sp to use during gc
syscallpc uintptr // if status==Gsyscall, syscallpc = sched.pc to use during gc
stktopsp uintptr // expected sp at top of stack, to check in traceback
param unsafe.Pointer // passed parameter on wakeup
atomicstatus uint32
stackLock uint32 // sigprof/scang lock; TODO: fold in to atomicstatus
goid int64
......
......
}
第一种方法,直接该源码即可
1:修改源码获取
// getg returns the pointer to the current g.
// The compiler rewrites calls to this function into instructions
// that fetch the g directly (from TLS or from the dedicated register).
func getg() *g
我们只需要提供一个内核方法
func GetG() int64 {
return getg().goid
}
明显的缺点是:不可移植
2:汇编获取
代码放这:buptbill220/go-test
核心实现:
TEXT ·getg(SB), NOSPLIT, $0-32
MOVQ TLS, CX
MOVQ 0(TLS), BX
MOVQ (CX)(TLS), DX
MOVQ 0(CX)(TLS*1), AX
MOVQ AX, ret+0(FP)
MOVQ BX, ret+8(FP)
MOVQ CX, ret+16(FP)
MOVQ DX, ret+24(FP)
RET
主要原理是使用TLS伪寄存器,该寄存器存储当前goroutine g结构地址;
具体实现在call.go里有介绍
这里只贴图,图中guid即是每个goroutine id(从5开始是因为还有系统goroutine)
该方法高效简单,很实用
3:伪goroutine id实现
这里再介绍第三种方法:我们不一定要获取goroutine id,只需要保证一个唯一id。那么我们就可以针对每个请求生成一个伪goroutine id;
实现思想:
通常处理流程:主入口-->业务入口
优化处理流程:主入口-->(p0-->p1-->p2-->p3)-->业务入口,做法就是在主入口和业务入口之间插入4个同作用函数,但是入口地址不一样,p0,p1,p2,p3是同一个函数地址的不同排列组合;
具体流程如下:
1: 维护一个全局map,每一个请求进入主入口按序号递增生成id,并转换成p0,p1,p2,p3排列顺序进入业务入口,
2: 在业务代码中,一步一步向上获取p0,p1,p2,p3函数排列顺序,并转换成id
3: 业务执行完,退出主入口,回收id以便重复使用
如下图:
这里我们用到2个系统函数:
pc, _, _, ok := runtime.Caller(n)
fpc := runtime.FuncForPC(pc).Entry()
该函数作用是获取从当前函数向上第n个父函数调用入口地址。下面你将看到它的妙用
在业务函数里的任一步骤里,我们可以使用这2函数,最终肯定能获取p0,p1,p2,p3顺序,该顺序就可以组成一个排列组合,如何转成数字id?使用每个p0,p1,p2,p3函数地址做位运算即可。
核心代码如下:
package gls
import (
"reflect"
"runtime"
)
type flagFunc func(rem uint64, cb func())
var fs []flagFunc
func initFlagFuncs() {
fs = [256]flagFunc{
func(rem uint64, cb func()) { if rem == 0 { cb() } else { fs[rem & 0xff](rem >> 8, cb) } }, // 00
func(rem uint64, cb func()) { if rem == 0 { cb() } else { fs[rem & 0xff](rem >> 8, cb) } },
......// 总共256个
}
var startPc uintptr
var pcToN = make(map[uintptr]uint64, 256)
func SetGID(gid uint64, cb func()) {
if gid == 0 {
cb()
} else {
fs[gid&0xff](gid>>8, cb)
}
}
func init() {
initFlagFuncs()
for i := uint64(0); i < 256; i++ {
pc := reflect.ValueOf(fs[i]).Pointer()
pcToN[pc] = i
}
startPc = reflect.ValueOf(SetGID).Pointer()
}
func GetGID() uint64 {
var ret uint64 = 0
for i := 1; ; i++ {
pc, _, _, ok := runtime.Caller(i)
if ! ok {
break
}
fpc := runtime.FuncForPC(pc).Entry()
n, ok := pcToN[fpc]
if ok {
ret <<= 8
ret += n
}
if fpc == startPc {
break
}
}
return ret
}
参考实现:https://github.com/xiezhenye/gls,gls.go
有疑问加站长微信联系(非本文作者)