snowflake算法可以指定各域位数的改进版

liuyongshuai · 2018-02-01 17:52:29 · 2785 次点击 · 预计阅读时间 10 分钟 · 大约8小时之前 开始浏览    
这是一个创建于 2018-02-01 17:52:29 的文章,其中的信息可能已经有所发展或是发生改变。

snowFlake算法在生成ID时特别高效,可参考:https://segmentfault.com/a/1190000011282426

它可以保证:

  • 所有生成的id按时间趋势递增
  • 整个分布式系统内不会产生重复id(因为有datacenterId和workerId来做区分) 但在在某下场影下dataCenterId、workerId并不需要占那么多的位,或是机器没那么多。自己就写了一个各个域的位可以自定义设置的。

https://github.com/liuyongshuai/goutils/

/**
 * @author      Liu Yongshuai<liuyongshuai@hotmail.com>
 * @package     goSnowFlake
 * @date        2018-01-25 19:19
 */
package goSnowFlake

import (
    "sync"
    "fmt"
    "time"
)

/**
详见测试用例:go test -test.run TestNewIDGenerator
*/

//SnowFlake的结构体
type snowFlakeIdGenerator struct {
    workerId           int64 //当前的workerId
    workerIdAfterShift int64 //移位后的workerId,可直接跟时间戳、序号取位或操作
    lastMsTimestamp    int64 //上一次用的时间戳
    curSequence        int64 //当前的序号

    timeBitSize     uint8 //时间戳占的位数,默认为41位,最大不超过60位
    workerIdBitSize uint8 //workerId占的位数,默认10,最大不超过60位
    sequenceBitSize uint8 //序号占的位数,默认12,最大不超过60位

    lock       *sync.Mutex //同步用的
    isHaveInit bool        //是否已经初始化了

    maxWorkerId        int64 //workerId的最大值,初始化时计算出来的
    maxSequence        int64 //最后序列号最大值,初始化时计算出来的
    workerIdLeftShift  uint8 //生成的workerId只取最低的几位,这里要左移,给序列号腾位,初始化时计算出来的
    timestampLeftShift uint8 //生成的时间戳左移几位,给workId、序列号腾位,初始化时计算出来的
}

//实例化一个ID生成器
func NewIDGenerator() *snowFlakeIdGenerator {
    return &snowFlakeIdGenerator{
        workerId:           0,
        lastMsTimestamp:    0,
        curSequence:        0,
        timeBitSize:        41, //默认的时间戳占的位数
        workerIdBitSize:    10, //默认的workerId占的位数
        sequenceBitSize:    12, //默认的序号占的位数
        maxWorkerId:        0,  //最大的workerId,初始化时计算出来的
        maxSequence:        0,  //最大的序号值,初始化的时计算出来的
        workerIdLeftShift:  0,  //worker id左移位数
        timestampLeftShift: 0,
        lock:               new(sync.Mutex),
        isHaveInit:         false,
    }
}

//设置worker id
func (sfg *snowFlakeIdGenerator) SetWorkerId(w int64) *snowFlakeIdGenerator {
    sfg.lock.Lock()
    defer sfg.lock.Unlock()
    sfg.isHaveInit = false
    sfg.workerId = w
    return sfg
}

//设置时间戳占的位数
func (sfg *snowFlakeIdGenerator) SetTimeBitSize(n uint8) *snowFlakeIdGenerator {
    sfg.lock.Lock()
    defer sfg.lock.Unlock()
    sfg.isHaveInit = false
    sfg.timeBitSize = n
    return sfg
}

//设置worker id占的位数
func (sfg *snowFlakeIdGenerator) SetWorkerIdBitSize(n uint8) *snowFlakeIdGenerator {
    sfg.lock.Lock()
    defer sfg.lock.Unlock()
    sfg.isHaveInit = false
    sfg.workerIdBitSize = n
    return sfg
}

//设置序号占的位数
func (sfg *snowFlakeIdGenerator) SetSequenceBitSize(n uint8) *snowFlakeIdGenerator {
    sfg.lock.Lock()
    defer sfg.lock.Unlock()
    sfg.isHaveInit = false
    sfg.sequenceBitSize = n
    return sfg
}

//初始化操作
func (sfg *snowFlakeIdGenerator) Init() (*snowFlakeIdGenerator, error) {
    sfg.lock.Lock()
    defer sfg.lock.Unlock()

    //如果已经初始化了
    if sfg.isHaveInit {
        return sfg, nil
    }

    if sfg.sequenceBitSize < 1 || sfg.sequenceBitSize > 60 {
        return nil, fmt.Errorf("Init failed:\tinvalid sequence bit size, should (1,60)")
    }
    if sfg.timeBitSize < 1 || sfg.timeBitSize > 60 {
        return nil, fmt.Errorf("Init failed:\tinvalid time bit size, should (1,60)")
    }
    if sfg.workerIdBitSize < 1 || sfg.workerIdBitSize > 60 {
        return nil, fmt.Errorf("Init failed:\tinvalid worker id bit size, should (1,60)")
    }
    if sfg.workerIdBitSize+sfg.sequenceBitSize+sfg.timeBitSize != 63 {
        return nil, fmt.Errorf("Init failed:\tinvalid sum of all bit size, should eq 63")
    }

    //确定移位数
    sfg.workerIdLeftShift = sfg.sequenceBitSize
    sfg.timestampLeftShift = sfg.sequenceBitSize + sfg.workerIdBitSize

    //确定序列号及workerId最大值
    sfg.maxWorkerId = -1 ^ (-1 << sfg.workerIdBitSize)
    sfg.maxSequence = -1 ^ (-1 << sfg.sequenceBitSize)

    //移位之后的workerId,返回结果时可直接跟时间戳、序号取或操作即可
    sfg.workerIdAfterShift = sfg.workerId << sfg.workerIdLeftShift

    //判断当前的workerId是否合法
    if sfg.workerId > sfg.maxWorkerId {
        return nil, fmt.Errorf("Init failed:\tinvalid worker id, should not greater than %d", sfg.maxWorkerId)
    }

    //初始化完毕
    sfg.isHaveInit = true
    sfg.lastMsTimestamp = 0
    sfg.curSequence = 0
    return sfg, nil
}

//生成时间戳,根据bit size设置取高几位
//即,生成的时间戳先右移几位,再左移几位,就保留了最高的指定位数
func (sfg *snowFlakeIdGenerator) genTs() int64 {
    rawTs := time.Now().UnixNano()
    diff := 64 - sfg.timeBitSize
    ret := (rawTs >> diff) << diff
    return ret
}

//生成下一个时间戳,如果时间戳的位数较小,且序号用完时此处等待的时间会较长
func (sfg *snowFlakeIdGenerator) genNextTs(last int64) int64 {
    for {
        cur := sfg.genTs()
        if cur > last {
            return cur
        }
    }
}

//生成下一个ID
func (sfg *snowFlakeIdGenerator) NextId() (int64, error) {
    sfg.lock.Lock()
    defer sfg.lock.Unlock()

    //如果还没有初始化
    if !sfg.isHaveInit {
        return 0, fmt.Errorf("Gen NextId failed:\tplease execute Init() first")
    }

    //先判断当前的时间戳,如果比上一次的还小,说明出问题了
    curTs := sfg.genTs()
    if curTs < sfg.lastMsTimestamp {
        return 0, fmt.Errorf("Gen NextId failed:\tunknown error, the system clock occur some wrong")
    }

    //如果跟上次的时间戳相同,则增加序号
    if curTs == sfg.lastMsTimestamp {
        sfg.curSequence = (sfg.curSequence + 1) & sfg.maxSequence
        //序号又归0即用完了,重新生成时间戳
        if sfg.curSequence == 0 {
            curTs = sfg.genNextTs(sfg.lastMsTimestamp)
        }
    } else {
        //如果两个的时间戳不一样,则归0序号
        sfg.curSequence = 0
    }

    sfg.lastMsTimestamp = curTs

    //将处理好的各个位组装成一个int64型
    curTs = curTs | sfg.workerIdAfterShift | sfg.curSequence
    return curTs, nil
}

//解析生成的ID
func (sfg *snowFlakeIdGenerator) Parse(id int64) (int64, int64, int64, error) {
    //如果还没有初始化
    if !sfg.isHaveInit {
        return 0, 0, 0, fmt.Errorf("Parse failed:\tplease execute Init() first")
    }

    //先提取时间戳部分
    shift := sfg.sequenceBitSize + sfg.sequenceBitSize
    timestamp := (id & (-1 << shift)) >> shift

    //再提取workerId部分
    shift = sfg.sequenceBitSize
    workerId := (id & (sfg.maxWorkerId << shift)) >> shift

    //序号部分
    sequence := id & sfg.maxSequence

    //解析错误
    if workerId != sfg.workerId || workerId > sfg.maxWorkerId {
        fmt.Printf("workerBitSize=%d\tMaxWorkerId=%d\n", sfg.workerIdBitSize, sfg.maxWorkerId)
        return 0, 0, 0, fmt.Errorf("parse failed:invalid id, originWorkerId=%d\tparseWorkerId=%d\n",
            sfg.workerId, workerId)
    }
    if sequence < 0 || sequence > sfg.maxSequence {
        fmt.Printf("sequesnceBitSize=%d\tMaxSequence=%d\n", sfg.sequenceBitSize, sfg.maxSequence)
        return 0, 0, 0, fmt.Errorf("parse failed:invalid id, parseSequence=%d\n", sequence)
    }

    return timestamp, workerId, sequence, nil
}

测试代码

大约共连续生成了1亿三千多万个ID写到文件里,暂时没有发现重复的。

package goSnowFlake

import (
    "testing"
    "fmt"
    "time"
    "os"
)

func TestNewIDGenerator(t *testing.T) {
    b := "\t\t\t"
    b2 := "\t\t\t\t\t"
    d := "====================================="

    //第一个生成器
    gentor1, err := NewIDGenerator().SetWorkerId(100).Init()
    if err != nil {
        fmt.Println(err)
        t.Error(err)
    }
    //第二个生成器
    gentor2, err := NewIDGenerator().
        SetTimeBitSize(48).
        SetSequenceBitSize(10).
        SetWorkerIdBitSize(5).
        SetWorkerId(30).Init()
    if err != nil {
        fmt.Println(err)
        t.Error(err)
    }

    fmt.Printf("%s%s%s\n", d, b, d)
    fmt.Printf("workerId=%d lastTimestamp=%d %s workerId=%d lastTimestamp=%d\n",
        gentor1.workerId, gentor1.lastMsTimestamp, b,
        gentor2.workerId, gentor2.lastMsTimestamp)
    fmt.Printf("sequenceBitSize=%d timeBitSize=%d %s sequenceBitSize=%d timeBitSize=%d\n",
        gentor1.sequenceBitSize, gentor1.timeBitSize, b,
        gentor2.sequenceBitSize, gentor2.timeBitSize)
    fmt.Printf("workerBitSize=%d sequenceBitSize=%d %s workerBitSize=%d sequenceBitSize=%d\n",
        gentor1.workerIdBitSize, gentor1.sequenceBitSize, b,
        gentor2.workerIdBitSize, gentor2.sequenceBitSize)
    fmt.Printf("%s%s%s\n", d, b, d)

    var ids []int64
    for i := 0; i < 100; i++ {
        id1, err := gentor1.NextId()
        if err != nil {
            fmt.Println(err)
            return
        }
        id2, err := gentor2.NextId()
        if err != nil {
            fmt.Println(err)
            return
        }
        ids = append(ids, id2)
        fmt.Printf("%d%s%d\n", id1, b2, id2)
    }

    //解析ID
    for _, id := range ids {
        ts, workerId, seq, err := gentor2.Parse(id)
        fmt.Printf("id=%d\ttimestamp=%d\tworkerId=%d\tsequence=%d\terr=%v\n",
            id, ts, workerId, seq, err)
    }
}

//多线程测试
func TestSnowFlakeIdGenerator_MultiThread(t *testing.T) {
    f := "./snowflake.txt"
    //准备写入的文件
    fp, err := os.OpenFile(f, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0755)
    if err != nil {
        fmt.Println(err)
        t.Error(err)
    }

    //初始化ID生成器,采用默认参数
    gentor, err := NewIDGenerator().SetWorkerId(100).Init()
    if err != nil {
        fmt.Println(err)
        t.Error(err)
    }

    //启动10个线程,出错就报出来
    for i := 0; i < 10; i++ {
        go func() {
            for {
                gid, err := gentor.NextId()
                if err != nil {
                    panic(err)
                }
                n, err := fp.WriteString(fmt.Sprintf("%d\n", gid))
                if err != nil || n <= 0 {
                    panic(err)
                }
            }
        }()
    }
    time.Sleep(10 * time.Second)
    //time.Sleep(600 * time.Second)
}

有疑问加站长微信联系(非本文作者))

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

2785 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传