V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
The Go Programming Language
http://golang.org/
Go Playground
Go Projects
Revel Web Framework
beego
hetiansu5
V2EX  ›  Go

Go 之 WaitGroup 底层实现

  •  
  •   hetiansu5 · 18 天前 · 1110 次点击

    WaitGroup

    WaitGroup 用于等待一组线程的结束,父线程调用 Add 来增加等待的线程数,被等待的线程在结束后调用 Done 来将等待线程数减 1,父线程通过调用 Wait 阻塞等待所有结束(计数器清零)后进行唤醒。

    源码位置

    WaitGroup 的源码在 SDK 包的路径为src/sync/waitgroup.go

    数据结构

    type WaitGroup struct {
    	noCopy noCopy
    	state1 [3]uint32
    }
    

    1.noCopy noCopy

    noCopy 这个主要用来限制不能进行 copy,这里是为了避免 copy 后的 waitGroup 并发使用后,可能会与原 waitGroup 出现异常而 panic 。

    2.state1 [3]unit32

    数组的三个元素(非顺序):

    • counter 通过 Add()设置的子 goroutine 的数量,即被等待线程计数
    • waiter 通过 Wait()陷入阻塞的等待者计数
    • semap 信号量,用于唤醒阻塞 waiter

    这里需要注意一下 couter 、waiter 、semap 并不是顺序存储的,64bit 操作系统的原子操作需要保证 64bit 的内存对齐,在设计上我们需要保证 couter 和 waiter 的操作原子性。如果数组的首元素地址能被 8 整除,则 counter 和 waiter 刚好可以在同一块原子操作的 64bit 内存上,所以取数组前两个元素分别表示 couter 和 waiter ;如果不能被 8 整除(根据内存对齐的原理,地址必然是 4 的倍数),则取数组后两个。

    // 根据内存对齐方式的不同,返回 statep(couter 占用高 32bit 和 waiter 占用低 32bit)和 semap 的地址
    func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
    		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    	} else {
    		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    	}
    }
    

    alignment.png

    公共方法

    func (wg *WaitGroup) Add(delta int) //增加 waitGroup 子 goruntine 计数值
    func (wg *WaitGroup) Done() //当子 goruntine 完成后,将计数器-1
    func (wg *WaitGroup) Wait() //调用此方法的 goruntine,阻塞等待计数值为 0
    

    以下方法去除了 race 竞争检查的源代码。

    Add

    操作 counter 计数值加减。

    • 当 counter 增加时,直接 return
    • 当 counter 减少时, 判断条件:counter > 0 || waiter == 0
      • true 时,直接 return
      • false (等待线程都完成且有等待者)时,statep 复位为 0,通过 semap 信号量唤醒所有等待者
    func (wg *WaitGroup) Add(delta int) {
    	//从数组中拿到 stetep ( counter+waiter 的组合)和 semap 信号量的内存地址
    	statep, semap := wg.state()
    	//stetep 原子加操作,高位 32bit 是 counter,实际 counter+1
    	state := atomic.AddUint64(statep, uint64(delta)<<32)
    	//state 的高位 32bit,表示 couter 的计数值
    	v := int32(state >> 32)
    	//state 的低位 32bit,表示 waiter 的等待者数量
    	w := uint32(state)
    	// couter 不能小于 0
    	if v < 0 {
    		panic("sync: negative WaitGroup counter")
    	}
    	// 需要避免错误操作:Add 和 Wait 并发操作,否则会 panic
    	if w != 0 && delta > 0 && v == int32(delta) {
    		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    	}
    	// 如果还有等待线程未完成或者并没有等待者,直接 return
    	if v > 0 || w == 0 {
    		return
    	}
    	// 需要避免错误操作:Add 和 Wait 并发操作,否则会 panic
    	if *statep != state {
    		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    	}
    	// 将 statep 复位为 0 ( counter 和 waiter 都置为 0 )
    	*statep = 0
    	// 有多少个等待者就往 semap 循环发信号量(其实就是 semap+1 ),Wait 等待有一个调用	// runtime_Semacquire(semap)就是在等待这个信号量
    	for ; w != 0; w-- {
    		runtime_Semrelease(semap, false, 0)
    	}
    }
    

    Done

    被等待线程完成后调用 Done,将 counter 计数-1,表示线程结束

    func (wg *WaitGroup) Done() {
    	wg.Add(-1)
    }
    

    Wait

    主线程循环对 waiter 原子操作+1 直到成功后,然后阻塞等待 semap 信号量而被唤醒,最后 return

    func (wg *WaitGroup) Wait() {
    	// 从数组中拿到 stetep ( counter+waiter 的组合)和 semap 信号量的内存地址
    	statep, semap := wg.state()
    	for {
    		//从内存总线中加载最新的 statep 值
    		state := atomic.LoadUint64(statep)
    		//state 的高位 32bit,表示 couter 的计数值
    		v := int32(state >> 32)
    		//state 的低位 32bit,表示 waiter 的等待者数量
    		w := uint32(state)
    		//如果 couter 为 0,表示当前已经没有在运行的等待线程了
    		if v == 0 {
    			return
    		}
    		// CAS 操作 statep+1,低位属于 waiter,即 waiter+1
    		if atomic.CompareAndSwapUint64(statep, state, state+1) {
    			// CAS 操作成功后,阻塞等待 semap 信号为非零,竞争到会将 semap-1,并唤醒线程
    			runtime_Semacquire(semap)
    			if *statep != 0 {
    				panic("sync: WaitGroup is reused before previous Wait has returned")
    			}
    			return
    		}
    		// CAS 操作失败了,重新进入循环
    	}
    }
    
    4 条回复    2021-04-07 13:32:21 +08:00
    makdon
        1
    makdon   18 天前   ❤️ 2
    拉到最后竟然没有公众号 /博客 /培训班 /招聘
    raaaaaar
        2
    raaaaaar   18 天前 via Android
    最近学了操作系统,发现就是个二元信号量。。
    hetiansu5
        3
    hetiansu5   17 天前
    @makdon 哈哈,单纯输出而已,变相的加深理解
    kuro1
        4
    kuro1   15 天前
    拉到最后竟然没有公众号 /博客 /培训班 /招聘+1
    关于   ·   帮助文档   ·   FAQ   ·   API   ·   我们的愿景   ·   广告投放   ·   感谢   ·   实用小工具   ·   1327 人在线   最高记录 5497   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 16ms · UTC 23:46 · PVG 07:46 · LAX 16:46 · JFK 19:46
    ♥ Do have faith in what you're doing.