看到golang标准库sync package WaitGroup 类型, 本以为是golang 版本的 barrier 对象实现,看到文档给出的使用示例:
var wg sync.WaitGroup var urls = []string{ "http://www.golang.org/", "http://www.google.com/", "http://www.somestupidname.com/", } for _, url := range urls { // Increment the WaitGroup counter. wg.Add(1) // Launch a goroutine to fetch the URL. go func(url string) { // Decrement the counter when the goroutine completes. defer wg.Done() // Fetch the URL. http.Get(url) }(url) } // Wait for all HTTP fetches to complete. wg.Wait()
可以看出WaitGroup 类型主要用于某个goroutine(调用Wait() 方法的那个), 等待个数不定goroutine(内部调用Done() 方法),
Add 方法对内部计数,添加或减少,Done方法其实是Add(-1);
与pthread_barrier_t 有着语义上的差别,pthread_barrier_wait() 的调用者之间互相等待,就好比5名队员(线程)参加跨栏比赛,使用 pthread_barrier_init 初始化最后一个参数为5, 五个队员都是好基友, 定了规矩, 不管谁先到栏杆, 都要等队友,直到最后一名队员跨过栏时,然后同一起步点再次出发。下面时使用pthread_barrier_t 的简单示例 5个线程,每个线程拥有一个私有数组,及增量数字:
#define _GNU_SOURCE #include <pthread.h> #include <stdio.h> #include <string.h> #include <stdlib.h> #define NTHR 5 #define NARR 6 #define INLOOPS 1000 #define OUTLOOPS 10 #define err_abort(code,text) do { \ char errbuf[128] = {0}; \ fprintf (stderr, "%s at \"%s\":%d: %s\n", \ (text), __FILE__, __LINE__, strerror_r(code,errbuf,128)); \ abort (); \ } while (0) typedef struct thrArg { pthread_t tid; int incr; int arr[NARR]; }thrArg; pthread_barrier_t barrier; thrArg thrs[NTHR]; void *thrFunc (void *arg) { thrArg *self = (thrArg*)arg; int j, i, k, status; for (i = 0; i < OUTLOOPS; i++) { status = pthread_barrier_wait (&barrier); if (status > 0) err_abort (status, "wait on barrier"); //每个线程迭代 INLOOPS 次,对自己的内部数组arr 成员加上 自己的增量值 for (j = 0; j < INLOOPS; j++) for (k = 0; k < NARR; k++) self->arr[k] += self->incr; //先执行完迭代的线程在此等待,直到最后一个到达 status = pthread_barrier_wait (&barrier); if (status > 0) err_abort (status, "wait on barrier"); //最后一个到达的线程,把所有线程的内部增量加1 //此时其他先到的线程阻塞在第一次wait调用处,所以最后一个到达的线程 //可以排他性地访问所有线程的内部状态,if 语句执行完后,跳到第一次wait处, //其他阻塞在第一次wait处的线程,得到释放,大家一块使用新的增量做计算 if (status == PTHREAD_BARRIER_SERIAL_THREAD ) { int i; for (i = 0; i < NTHR; i++) thrs[i].incr += 1; } } return NULL; } int main (int arg, char *argv[]) { int i, j; int status; pthread_barrier_init (&barrier, NULL, NTHR); for (i = 0; i < NTHR; i++) { thrs[i].incr = i; for (j = 0; j < NARR; j++) thrs[i].arr[j] = j + 1; status = pthread_create (&thrs[i].tid, NULL, thrFunc, (void*)&thrs[i]); if (status != 0) err_abort (status, "create thread"); } for (i = 0; i < NTHR; i++) { status = pthread_join (thrs[i].tid, NULL); if (status != 0) err_abort (status, "join thread"); printf ("%02d: (%d) ", i, thrs[i].incr); for (j = 0; j < NARR; j++) printf ("%010u ", thrs[i].arr[j]); printf ("\n"); } pthread_barrier_destroy (&barrier); return 0; }
怎么用golang 来表达上述c 代码,需要实现pthread_barrier_t 等价语义的的 barrier 对象,可以使用golang 已有的mutex, cond
对象实现 barrier:
package main import ( "fmt" "sync" ) type Barrier struct{ lock sync.Mutex cond sync.Cond threshold int //总的等待个数 count int //还剩多少没有到达barrier,即没有完成wait调用个数 cycle bool //用于重初始化下一个wait 周期, } func NewBarrier(n int) *Barrier{ b := &Barrier{threshold: n, count: n} b.cond.L = &b.lock return b } //last == true ,说明最有一个到达 func (b *Barrier)Wait()(last bool){ b.lock.Lock() defer b.lock.Unlock() cycle := b.cycle b.count-- //最后一个到达负责,重初始化count 计数,cycle 变量翻转, if b.count == 0 { b.cycle = !b.cycle b.count = b.threshold b.cond.Broadcast() last = true }else{ for cycle == b.cycle { b.cond.Wait() } } return } type thrArg struct{ incr int arr [narr]int } var ( thrs [nthr]thrArg wg sync.WaitGroup barrier = NewBarrier(nthr) ) const ( outloops = 10 inloops = 1000 nthr = 5 narr = 6 ) func thrFunc(arg *thrArg){ defer wg.Done() for i := 0; i < outloops; i++{ barrier.Wait() for j := 0; j < inloops; j++{ for k:= 0; k < narr; k++{ arg.arr[k] += arg.incr } } if barrier.Wait() { for i := 0; i < nthr; i++{ thrs[i].incr += 1 } } } } func main(){ for i:= 0; i < nthr; i++{ thrs[i].incr = i for j := 0; j < narr; j++{ thrs[i].arr[j] = j + 1 } wg.Add(1) go thrFunc(&thrs[i]) } wg.Wait() //所有goroutine完成,main goroutine,检查最后的结果 for i := 0; i < nthr; i++{ fmt.Printf("%02d: (%d) ", i, thrs[i].incr) for j := 0; j < narr; j++{ fmt.Printf ("%010d ", thrs[i].arr[j]); } fmt.Println() } }
有疑问加站长微信联系(非本文作者)