如何快速的随机从 map 中返回一个值

胡大海 · · 14273 次点击 · · 开始浏览    
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

前言

本文翻译自 lukechampine.com/hackmap.htm… go 的 map 源码解析都会引用的一片文章。

第一部分:问题

从一个切片中随机的获取一个值是非常简单的,可以使用map.Int(n),这样就可以从[0,n)中随机的返回一个值,从而可以从切片中随机的返回一个元素

func randSliceValue(xs []string) string {
	return xs[rand.Intn(len(xs))]
}
复制代码

这个方法是非常好的,因为耗费的时间和内存都是O(1)的。但是呢,对于 map 而言,没有简单并且等效的方式来做到随机从 map 中获取一个值。我们获取 map 中的数据有两种方式,取值(e.g. m["foo"])或者range。所以,如何根据这两种方式和一个随机的值来在 map 中随机的返回一个 key 呢?(注意,既然可以随机的获取一个 key,那么根据此 key 获得的 value 也相当于是随机获得的了)

一种方式是我们可以把 map 展开,然后复用切片中随机获取一个值的方法

func randMapKey(m map[string]int) string {
	mapKeys = make([]string, 0, len(m)) // pre-allocate exact size
	for key := range m {
		mapKeys = append(mapKeys, key)
	}
	return mapKeys[rand.Intn(len(mapKeys))]
}
复制代码

很容易就可以看出来这个确实可以随机的返回 key,但是这种简单的代价却是以性能为代价:时间和空间复杂度均是O(n)

一个表现稍微好的一个方式是使用rangerange会访问每对 key/value 一次,并且每次使用range的顺序并不会一致(译者注:在调用 mapiterinit 会取一个随机值来决定访问的起始位置)。我们可以把随机取得的值作为一个计数器,在每次遍历完一对键值对之后进行减一,最后在值为 0 的时候返回此时遍历的键即可:

func randMapKey(m map[string]int) string {
	r := rand.Intn(len(m))
	for k := range m {
		if r == 0 {
			return k
		}
		r--
	}
	panic("unreachable")
}
复制代码

这个操作的时间复杂度为O(n),空间复杂度为O(1),在大多数情况下都是可以接受的。但是我们会遇到一个新的问题,我们无法完成一个通用的函数来进行随机取值的操作!

第二部分:深入理解

unsafe可以允许我们转换 go 的类型。也就是说我们可以把一个变量看做类型 X,即使它的类型原来是 Y。另外比较友好的情况是这种操作不仅仅可以用来处理 go 的内置类型(比如说 strings、slices、maps),也可以用来处理用户自定义的类型。需要注意的是这是一种双向的操作,我们不仅仅可以把[]byte转化成其他类型,也可以把其他类型转化为[]byte类型。

我们可以使用这种技巧来操作对象底层的内存,就如我在 一些包 中做的一样。但是呢,我并不赞同这样做。今天我们只会使用一个方向。具体点就是我们将把map转化为局部复制为runtime definition.(不要关闭这个页面,后续会用到)。然后我们就可以直接的获取内部的数据,并且有希望实现一个快速的获取 map 中随机一个 key 的方法。

查看map.go文件,我们可以看到一种类型为hiter的结构体,并且有方法mapiterinitmapinternext。这就是当你执行range来遍历一个 map 进行的操作。hiter是一个迭代器,mapinitinit用来初始化这个迭代器,mapiternext用于寻找下个迭代器迭代的位置。现在我有一个计划:

  1. 把 map 转化为 hmap
  2. 使用方法new(hiter)来创建一个迭代器,并且使用mapiterinit来初始化
  3. 产生一个处于[0, len(m)]范围内的随机数字 n
  4. 执行 mapiternext方法 n 次
  5. 在执行到 n 为 0 的位置,返回当前的 key

这种做法的一个优势就是我们不需要理解方法mapiterinitmapiternext的细节。我们仅仅在hashmap.go中复制相关代码即可。(实际上编译器会警告缺少 runtime 相关的方法,如 atomic.Or8,我们把这些删除即可)。之后,我们仅仅需要完成关于转变变量类型和从interface{}中变换回来的代码就好了。

// runtime representation of an interface{}
type emptyInterface struct {
	typ unsafe.Pointer
	val unsafe.Pointer
}

func mapTypeAndValue(m interface{}) (*maptype, *hmap) {
	ei := (*emptyInterface)(unsafe.Pointer(&m))
	return (*maptype)(ei.typ), (*hmap)(ei.val)
}

func iterKey(it *hiter) interface{} {
	ei := emptyInterface{
		typ: unsafe.Pointer(it.t.key), // it.t is the maptype
		val: it.key,
	}
	return *(*interface{})(unsafe.Pointer(&ei))
}
复制代码

现在,我们终于可以实现一个通用的randMapKey方法,并且时间复杂度为O(n)、空间复杂度为O(1)

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	// initialize iterator
	it := new(hiter)
	mapiterinit(t, h, it)
	// advance iterator a random number of times
	r := rand.Intn(h.count) // h.count == len(m)
	for i := 0; i < r; i++ {
		mapiternext(it)
	}
	// return current iterator key as an interface{}
	return iterKey(it)
}
复制代码

这个是可以运行成功的,但是也伴随着相应的代价。一个原因就是使用了unsafe包,正如其名,是不安全的。

第三部分:可以的

实际上我们是可以花费常数时间(虽然这个常数时间可以有点大)复杂度来获取 map 的随机 key 的。但是呢,为了达到常数时间复杂度的地步,我们需要更深的理解 go 中的 map 是如何运作的。

hmap实际上就是一个 bucket 数组。bucket 的数量是 2 的 h.B 次方。由于随着 map 的增长,bucket 的数量也是变化的,所以 bucket 数组在hmap中是用作unsafe.Pointer来记载的。每个 bucket 有 8 个 cell (键值对),其中有些 cell 可能没有有效的数据。并且有一个tophash数组,用于存储每个空间部分的 key 的 hash 或者此键值对的状态。bucket 的结构体类型可能会让人疑惑,整体如下:

const bucketCnt = 8 // number of cells per bucket

type bmap struct {
	tophash [bucketCnt]uint8
	// Followed by bucketCnt keys and then bucketCnt values.
	// NOTE: packing all the keys together and then all the values together makes the
	// code a bit more complicated than alternating key/value/key/value/... but it allows
	// us to eliminate padding which would be needed for, e.g., map[int64]int8.
	// Followed by an overflow pointer.
}
复制代码

正如如你看见的,里面并没有键值对的空间。这是因为键值对的类型并不知道,编译器不知道 bmap 的大小。如果 map 的类型是 map[string]int,那么你可以认为bmap会和下方代码一样:

type bmap struct {
	tophash  [bucketCnt]uint8
	keys     [bucketCnt]string
	values   [bucketCnt]int
	overflow *bmap
}
复制代码

要插入一个 key/value ,我们首先需要计算的是 key 的 hash, 其类型是uintptr。(每个类型在运行时都有一个对应的hash函数)。然后我们需要决定使用哪个 bucket 。由于 bucket 的数量永远都是 2的幂,所以我们可以使用key的经过hash函数计算后的二进制数后 h.B 位来决定使用哪个bucket:

// h is an hmap, t is a maptype
hash := t.key.alg.hash(key, uintptr(h.hash0))
bucketIndex := hash & (uintptr(1) << h.B - 1)
复制代码

在选定了 bucket 之后,我们需要找到一个可用的 cell。我们遍历每个 cell,检查 tophash数组中的每个值。tophash数组中存储的是存储于此 bucket 中每个键值对中键的hash运行结果的二进制数的前8位,在 tophash中元素值为 0 的时候,表示此 cell 是空的,没有存储任何值。当发现空的 cell 的时候,我们可以用此 cell 来存键值对,并且修改tophash数组对应位置的值。下面就是简化后的算法:

// calculate tophash
top := uint8(hash >> (unsafe.Sizeof(hash)*8 - 8))

// seek to offset of bucketIndex in h.buckets
b := (*bmap)(unsafe.Pointer(uintptr(h.buckets) + bucketIndex*uintptr(t.bucketsize)))

// iterate through the cells of b. If a tophash matches top, it means we've
// already inserted a value with this key, so overwrite it. Otherwise, store
// the key/value in the first empty cell.
for i := 0; i < bucketCnt; i++ {
	if b.tophash[i] == top {
		// overwrite the existing value
		// [ code omitted ]
		return
	} else if b.tophash[i] == 0 {
		// insert the new key/value
		// [ code omitted ]
		b.tophash[i] = top
		h.count++
		return
	}
}
复制代码

正如你期待的那样,为了获得 map 中的值,我们仅仅是重复了此过程。但是这个过程是返回 cell 中的数据而不是往 cell 中插入或者覆盖。

好了,现在我们可以改进我们的randMapKey函数了。回顾从切片中随机返回一个元素的方法,仅仅获得一个随机的索引即可。如果你仔细瞅瞅,你就会发现 h.buckets就是一个切片。是一个连续的 key/value 数组。最大的区别是一些 cell 中是空的。所以我们需要获得一个随机的索引,并且要避免这些空的 cell。

一个简单的方法就是当遍历到了空的 cell 的时候略过去就好了,直到找到一个非空的 cell。这个就是mapinterinit做的。但是这种方法有一个严重的缺陷,可以想想如下这种状况,一个 bucket 只有两个有效的 cell :

[foo] [bar] [---] [---] [---] [---] [---] [---]
复制代码

如果我们随机的选择 [0, 8) 中的一个索引,会发生什么?如果选择索引 0, 那么我们会获得foo,如果选择索引1, 我们会获得bar。但是如果我们选择其他索引,我们会接下来的往下寻找,直到折回到索引 0, 然后获取到元素foo。也就是说即使我们使用一个随机的索引,取得元素foo的概率也会是获得元素bar的 7 倍。这个对于随机的获取一个元素的目的肯定是不可行的。

幸运的是,我们有一个替代的方案:随机的获取一个索引,如果那个 cell 是空的,我们再次随机的获取一个索引即可。平均下来,我们没获取到一个非空的 cell 需要的次数为 k/n ,其中 k 是 cell 的个数,n 是非空的 cell 的个数。当然,算法必须保证随机值是均匀分布的,这样每个索引的取值都是等可能的。

让我们使用这种算法来实现randMapKey函数。我们需要一个函数,可以从 cell 中获取到 key 的值。我也会创建一个add函数,是的算法更为可读:

func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
	return unsafe.Pointer(uintptr(p) + x)
}

func cellKey(t *maptype, b *bmap, i int) interface{} {
	// dataOffset is where the cell data begins in a bmap.
	const dataOffset = unsafe.Offsetof(struct {
		tophash [bucketCnt]uint8
		cells   int64
	}{}.cells)

	k := add(unsafe.Pointer(b), dataOffset+uintptr(i)*uintptr(t.keysize))
	if t.indirectkey {
		// if the map's key type is too big, a pointer will be stored in
		// the map instead of the actual data. In that case, we need to
		// dereference the pointer.
		k = *(*unsafe.Pointer)(k)
	}

	ei := emptyInterface{
		typ: unsafe.Pointer(t.key),
		val: k,
	}
	return *(*interface{})(unsafe.Pointer(&ei))
}

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	numBuckets := 1 << h.B

	// loop until we hit a valid cell
	for {
		// pick random indices
		bucketIndex := rand.Intn(numBuckets)
		cellIndex := rand.Intn(bucketCnt)

		// lookup cell
		b := (*bmap)(add(h.buckets, uintptr(bucketIndex)*uintptr(t.bucketsize)))
		if b.tophash[cellIndex] == 0 {
			// cell is empty; try again
			continue
		}
		return cellKey(t, b, cellIndex)
	}
}
复制代码

这个是可以正确的运作的 我们就该到此结束吗?

第四部分:不

其实上述的代码是有问题的。例子的代码中只有整数 [0, 8),这个是没有冲突的。如果在 map 中不同的 key 的 hash 值是相同的呢?答案:使用额外的 bucket。如果你再看看bmap 的定义你会发现,内部含有一个overflow的指针类型来指向另一个bmap。在发生冲突的时候,会分配一个新的 bucket,挂在之前的 bmap的后面。在查询的过程中,我们首先在 bucket 中的每个 cell 找到一个匹配的,如果没有匹配的,就会通过 overflow指针,在下一个 bucket 中查找,如此反复循环。

这个是如何影响randMapKey函数的呢?实际上影响不大。我们仅需增加一个维度即可。在之前,我们是从一个随机的 bucket 中随机选择一个 cell。现在,我们需要在一个 bucket 的链表中随机选择一个 bucket,然后在 bucket 中随机选择一个 cell。map 的可视化图像如下:

           bucket0   bucket1   bucket2   bucket3
overflow0 [|||||||] [|||||||] [|||||||] [|||||||]
overflow1 [|||||||]           [|||||||]
overflow2                     [|||||||]
复制代码

之前,我们仅仅从第一行的四个 bucket 中进行选择。现在我们需要在 12 个 bucket 中进行选择,即使一些 bucket 并不存在。操作和之前的遇到空的 cell 相同,如果一个 bucket 并不存在,我们从新选择一个随机值再来一次就好了。

有一点比较烦人,就是事先我们并不知道链接起来的 bucket 有多少个。hmap没有一个参数表示链接起来的 bucket 最多有多少个可以用于我们相乘。我们需要自己计算这个值,而这个方法是遍历每个 bucket 然后顺着 overflow指针不断搜索直到overflow指针是一个nil。这个比较耗费时间,但是我们没有其他方法了,代码如下:

func (b *bmap) overflow(t *maptype) *bmap {
	offset := uintptr(t.bucketsize)-unsafe.Sizeof(uintptr(0))
	return *(**bmap)(add(unsafe.Pointer(b), offset))
}

func maxOverflow(t *maptype, h *hmap) int {
	numBuckets := uintptr(1 << h.B)
	max := 0
	for i := uintptr(0); i < numBuckets; i++ {
		over := 0
		b := (*bmap)(add(h.buckets, i*uintptr(t.bucketsize)))
		for b = b.overflow(t); b != nil; over++ {
			b = b.overflow(t)
		}
		if over > max {
			max = over
		}
	}
	return max
}
复制代码

现在randMapKey函数如下:

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	numBuckets := 1 << h.B
	numOver := maxOverflow(t, h) + 1 // add 1 to account for "base" bucket

	// loop until we hit a valid cell
loop:
	for {
		// pick random indices
		bucketIndex := rand.Intn(numBuckets)
		overIndex := rand.Intn(numOver)
		cellIndex := rand.Intn(bucketCnt)

		// seek to index in h.buckets
		b := (*bmap)(add(h.buckets, uintptr(bucketIndex)*uintptr(t.bucketsize)))

		// seek to index in overflow chain
		for i := 0; i < overIndex; i++ {
			b = b.overflow(t)
			if b == nil {
				// invalid bucket; try again
				continue loop
			}
		}

		// lookup cell
		if b.tophash[cellIndex] == 0 {
			// cell is empty; try again
			continue loop
		}
		return cellKey(t, b, cellIndex)
	}
}
复制代码

我们可以确定 对于含有 overflow 的 bucket 是可以运行的. 我在担心是否还有第五部分的内容

第五部分:当然还有第五部分

(这是最后一部分,我保证)

go 中的 map 是优化的非常好的,其中有一个优化是"incremental copying"。简单来讲就是在一个 map 满了,当你尝试插入一个新的元素的时候,go 会立刻分配一个新的 bucket 数组(数组长度是之前的两倍) 来存储新的 key/value。但是呢,他不会把旧的 bucket 复制到新的 bucket;而是每当你插入或者删除元素的时候,此 bucket(以及后续的通过 overflow 连接的 bucket) 会被复制(evacuated)到新的数组中去。在所有的 bucket 都移动完毕之后, h.oldbuckets会被置为nil

我确定你看到了问题所在:直到现在,我们取的 cell 的值都是来自于h.buckets。为了覆盖所有的可能,我们当然也需要检查h.oldbuckets。有三点需要我们做出改变:

  1. 当选择一个了一个 bucket 的时候,我们检查下对应的 oldbucket 是否复制过了。如果没有,我们从 oldbucket 中选择值。
  2. 当我们在 oldbucket 中选择了一个值的时候,我们需要确定此值最终需要分配到哪个 bucket 中去。如果正好是之前选择的 bucket 的时候我们返回此值即可。如果不是那么就需要从头再来了(这个为了避免重合,因为 oldbucket 迁移到会到两个可能的位置)
  3. maxOverflow需要返回h.bucketsh.oldbuckets中最长的 overflow

幸运的是,这个并不难实现,首先,修改maxOverflow函数:

func maxOverflow(t *maptype, h *hmap) int {
	numBuckets := uintptr(1 << h.B)
	max := 0
	for i := uintptr(0); i < numBuckets; i++ {
		over := 0
		b := (*bmap)(add(h.buckets, i*uintptr(t.bucketsize)))
		for b = b.overflow(t); b != nil; over++ {
			b = b.overflow(t)
		}
		if over > max {
			max = over
		}
	}

	// check oldbuckets too, if it exists
	if h.oldbuckets != nil {
		for i := uintptr(0); i < numBuckets/2; i++ {
			var over int
			b := (*bmap)(add(h.oldbuckets, i*uintptr(t.bucketsize)))
			if evacuated(b) {
				// we already counted this bucket in the first loop
				continue
			}
			for b = b.overflow(t); b != nil; over++ {
				b = b.overflow(t)
			}
			if over > max {
				max = over
			}
		}
	}
	return max
}
复制代码

然后呢,创建最新版本的randMapKey函数。当我们检查未迁移的 oldbucket 的时候,设置一个表示告诉我们检测 cell 的迁移位置:

func randMapKey(m interface{}) interface{} {
	// get map internals
	t, h := mapTypeAndValue(m)
	numBuckets := uintptr(1 << h.B)
	numOver := maxOverflow(t, h) + 1 // add 1 to account for "base" bucket

	// loop until we hit a valid cell
loop:
	for {
		// pick a random index
		bucketIndex := uintptr(rand.Intn(int(numBuckets)))
		overIndex := rand.Intn(numOver)
		cellIndex := rand.Intn(bucketCnt)

		// seek to index in h.buckets
		b := (*bmap)(add(h.buckets, bucketIndex*uintptr(t.bucketsize)))

		// if the oldbucket hasn't been evacuated, then we need to use that
		// pointer instead.
		usingOldBucket := false
		if h.oldbuckets != nil {
			numOldBuckets := numBuckets / 2
			oldBucketIndex := bucketIndex & (numOldBuckets - 1)
			oldB := (*bmap)(add(h.oldbuckets, oldBucketIndex*uintptr(t.bucketsize)))
			if !evacuated(oldB) {
				b = oldB
				usingOldBucket = true
			}
		}

		// seek to index in overflow chain
		for i := 0; i < overIndex; i++ {
			b = b.overflow(t)
			if b == nil {
				// invalid bucket; try again
				continue loop
			}
		}

		// lookup cell
		if b.tophash[cellIndex] == 0 {
			// cell is empty; try again
			continue loop
		}

		// grab key and dereference if necessary (same as cellKey)
		k := add(unsafe.Pointer(b), dataOffset+uintptr(cellIndex)*uintptr(t.keysize))
		if t.indirectkey {
			k = *(*unsafe.Pointer)(k)
		}

		// if this is an old bucket, we need to check whether this key is destined
		// for the new bucket. Otherwise, we will have a 2x bias towards oldbucket
		// values, since two different bucket selections can result in the same
		// oldbucket.
		if usingOldBucket {
			hash := t.key.alg.hash(k, uintptr(h.hash0))
			if hash&(numBuckets-1) != bucketIndex {
				// this key is destined for a different bucket
				continue loop
			}
		}

		// pack key into interface{} (same as cellKey)
		ei := emptyInterface{
			typ: unsafe.Pointer(t.key),
			val: k,
		}
		return *(*interface{})(unsafe.Pointer(&ei))
	}
}
复制代码

感觉并不坏,全部情况都考虑到了!如果你没看到最初的代码地址,那么看可以看这个完整的代码


有疑问加站长微信联系(非本文作者)

本文来自:掘金

感谢作者:胡大海

查看原文:如何快速的随机从 map 中返回一个值

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

14273 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传