Skip to content

Go 并发模型和并发安全

Go 语言的并发模型是其最突出的特性之一。通过 goroutines 和 channels,Go 提供了简洁而强大的并发编程支持。

📋 并发基础概念

并发 vs 并行

go
package main

import (
    "fmt"
    "runtime"
    "time"
)

// 演示并发概念
func demonstrateConcurrency() {
    fmt.Printf("CPU 核心数: %d\n", runtime.NumCPU())
    fmt.Printf("当前 Goroutines 数量: %d\n", runtime.NumGoroutine())
    
    // 创建多个 goroutines
    for i := 1; i <= 3; i++ {
        go func(id int) {
            for j := 1; j <= 3; j++ {
                fmt.Printf("Goroutine %d: 任务 %d\n", id, j)
                time.Sleep(100 * time.Millisecond)
            }
        }(i)
    }
    
    // 主 goroutine 等待
    time.Sleep(500 * time.Millisecond)
    fmt.Printf("结束时 Goroutines 数量: %d\n", runtime.NumGoroutine())
}

// 顺序执行 vs 并发执行
func sequentialExecution() {
    fmt.Println("=== 顺序执行 ===")
    start := time.Now()
    
    for i := 1; i <= 3; i++ {
        func(id int) {
            for j := 1; j <= 3; j++ {
                fmt.Printf("任务 %d-%d\n", id, j)
                time.Sleep(100 * time.Millisecond)
            }
        }(i)
    }
    
    fmt.Printf("顺序执行耗时: %v\n", time.Since(start))
}

func concurrentExecution() {
    fmt.Println("\n=== 并发执行 ===")
    start := time.Now()
    
    done := make(chan bool, 3)
    
    for i := 1; i <= 3; i++ {
        go func(id int) {
            for j := 1; j <= 3; j++ {
                fmt.Printf("任务 %d-%d\n", id, j)
                time.Sleep(100 * time.Millisecond)
            }
            done <- true
        }(i)
    }
    
    // 等待所有 goroutines 完成
    for i := 0; i < 3; i++ {
        <-done
    }
    
    fmt.Printf("并发执行耗时: %v\n", time.Since(start))
}

func main() {
    demonstrateConcurrency()
    
    sequentialExecution()
    concurrentExecution()
}

🚦 数据竞争和竞态条件

数据竞争示例

go
package main

import (
    "fmt"
    "sync"
    "time"
)

// 不安全的计数器(存在数据竞争)
type UnsafeCounter struct {
    count int
}

func (c *UnsafeCounter) Increment() {
    c.count++
}

func (c *UnsafeCounter) Get() int {
    return c.count
}

// 安全的计数器(使用互斥锁)
type SafeCounter struct {
    count int
    mutex sync.Mutex
}

func (c *SafeCounter) Increment() {
    c.mutex.Lock()
    defer c.mutex.Unlock()
    c.count++
}

func (c *SafeCounter) Get() int {
    c.mutex.Lock()
    defer c.mutex.Unlock()
    return c.count
}

// 使用原子操作的计数器
type AtomicCounter struct {
    count int64
}

func (c *AtomicCounter) Increment() {
    // 这里应该使用 atomic.AddInt64(&c.count, 1)
    // 为了演示,这里仍然不安全
    c.count++
}

func (c *AtomicCounter) Get() int64 {
    // 这里应该使用 atomic.LoadInt64(&c.count)
    return c.count
}

func demonstrateDataRace() {
    fmt.Println("=== 数据竞争演示 ===")
    
    // 不安全的计数器
    unsafeCounter := &UnsafeCounter{}
    var wg sync.WaitGroup
    
    const numGoroutines = 1000
    const incrementsPerGoroutine = 1000
    
    start := time.Now()
    
    for i := 0; i < numGoroutines; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for j := 0; j < incrementsPerGoroutine; j++ {
                unsafeCounter.Increment()
            }
        }()
    }
    
    wg.Wait()
    
    expected := numGoroutines * incrementsPerGoroutine
    actual := unsafeCounter.Get()
    
    fmt.Printf("不安全计数器:\n")
    fmt.Printf("  期望值: %d\n", expected)
    fmt.Printf("  实际值: %d\n", actual)
    fmt.Printf("  丢失: %d\n", expected-actual)
    fmt.Printf("  耗时: %v\n", time.Since(start))
    
    // 安全的计数器
    safeCounter := &SafeCounter{}
    start = time.Now()
    
    for i := 0; i < numGoroutines; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for j := 0; j < incrementsPerGoroutine; j++ {
                safeCounter.Increment()
            }
        }()
    }
    
    wg.Wait()
    
    fmt.Printf("\n安全计数器:\n")
    fmt.Printf("  期望值: %d\n", expected)
    fmt.Printf("  实际值: %d\n", safeCounter.Get())
    fmt.Printf("  耗时: %v\n", time.Since(start))
}

func main() {
    demonstrateDataRace()
}

🔒 同步原语

Mutex(互斥锁)

go
package main

import (
    "fmt"
    "sync"
    "time"
)

// 银行账户示例
type BankAccount struct {
    balance int
    mutex   sync.Mutex
}

func NewBankAccount(initialBalance int) *BankAccount {
    return &BankAccount{
        balance: initialBalance,
    }
}

func (account *BankAccount) Deposit(amount int) {
    account.mutex.Lock()
    defer account.mutex.Unlock()
    
    fmt.Printf("存入 %d 元,", amount)
    oldBalance := account.balance
    account.balance += amount
    fmt.Printf("余额从 %d 变为 %d\n", oldBalance, account.balance)
}

func (account *BankAccount) Withdraw(amount int) bool {
    account.mutex.Lock()
    defer account.mutex.Unlock()
    
    if account.balance >= amount {
        fmt.Printf("取出 %d 元,", amount)
        oldBalance := account.balance
        account.balance -= amount
        fmt.Printf("余额从 %d 变为 %d\n", oldBalance, account.balance)
        return true
    } else {
        fmt.Printf("取出 %d 元失败,余额不足 (当前: %d)\n", amount, account.balance)
        return false
    }
}

func (account *BankAccount) GetBalance() int {
    account.mutex.Lock()
    defer account.mutex.Unlock()
    return account.balance
}

// 转账操作(需要避免死锁)
func Transfer(from, to *BankAccount, amount int) bool {
    // 为了避免死锁,总是按照相同的顺序获取锁
    if from == to {
        return false
    }
    
    // 简单的排序策略:按内存地址排序
    var first, second *BankAccount
    if from < to {
        first, second = from, to
    } else {
        first, second = to, from
    }
    
    first.mutex.Lock()
    defer first.mutex.Unlock()
    
    second.mutex.Lock()
    defer second.mutex.Unlock()
    
    if from.balance >= amount {
        from.balance -= amount
        to.balance += amount
        fmt.Printf("转账成功: %d\n", amount)
        return true
    } else {
        fmt.Printf("转账失败: 余额不足\n")
        return false
    }
}

func demonstrateMutex() {
    fmt.Println("=== Mutex 演示 ===")
    
    account := NewBankAccount(1000)
    var wg sync.WaitGroup
    
    // 并发存取操作
    for i := 0; i < 5; i++ {
        wg.Add(2)
        
        // 存款 goroutine
        go func(id int) {
            defer wg.Done()
            account.Deposit(100)
            time.Sleep(10 * time.Millisecond)
        }(i)
        
        // 取款 goroutine
        go func(id int) {
            defer wg.Done()
            account.Withdraw(50)
            time.Sleep(10 * time.Millisecond)
        }(i)
    }
    
    wg.Wait()
    fmt.Printf("最终余额: %d\n", account.GetBalance())
    
    // 转账演示
    fmt.Println("\n转账演示:")
    account1 := NewBankAccount(500)
    account2 := NewBankAccount(300)
    
    fmt.Printf("转账前 - 账户1: %d, 账户2: %d\n", 
              account1.GetBalance(), account2.GetBalance())
    
    Transfer(account1, account2, 200)
    
    fmt.Printf("转账后 - 账户1: %d, 账户2: %d\n", 
              account1.GetBalance(), account2.GetBalance())
}

func main() {
    demonstrateMutex()
}

RWMutex(读写锁)

go
package main

import (
    "fmt"
    "sync"
    "time"
)

// 配置管理器(读多写少的场景)
type ConfigManager struct {
    config map[string]string
    mutex  sync.RWMutex
}

func NewConfigManager() *ConfigManager {
    return &ConfigManager{
        config: make(map[string]string),
    }
}

func (cm *ConfigManager) Set(key, value string) {
    cm.mutex.Lock()
    defer cm.mutex.Unlock()
    
    fmt.Printf("设置配置: %s = %s\n", key, value)
    cm.config[key] = value
    time.Sleep(10 * time.Millisecond) // 模拟写操作耗时
}

func (cm *ConfigManager) Get(key string) (string, bool) {
    cm.mutex.RLock()
    defer cm.mutex.RUnlock()
    
    value, exists := cm.config[key]
    if exists {
        fmt.Printf("读取配置: %s = %s\n", key, value)
    } else {
        fmt.Printf("配置不存在: %s\n", key)
    }
    
    time.Sleep(1 * time.Millisecond) // 模拟读操作耗时
    return value, exists
}

func (cm *ConfigManager) GetAll() map[string]string {
    cm.mutex.RLock()
    defer cm.mutex.RUnlock()
    
    // 创建副本避免外部修改
    result := make(map[string]string)
    for k, v := range cm.config {
        result[k] = v
    }
    
    fmt.Printf("获取所有配置: %d\n", len(result))
    return result
}

func demonstrateRWMutex() {
    fmt.Println("=== RWMutex 演示 ===")
    
    config := NewConfigManager()
    var wg sync.WaitGroup
    
    // 初始化一些配置
    config.Set("app_name", "Go Tutorial")
    config.Set("version", "1.0")
    config.Set("debug", "false")
    
    // 启动多个读操作
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            
            keys := []string{"app_name", "version", "debug", "missing_key"}
            for _, key := range keys {
                config.Get(key)
                time.Sleep(5 * time.Millisecond)
            }
        }(i)
    }
    
    // 启动少量写操作
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            
            config.Set(fmt.Sprintf("dynamic_key_%d", id), fmt.Sprintf("value_%d", id))
            time.Sleep(20 * time.Millisecond)
        }(i)
    }
    
    // 启动一个获取所有配置的操作
    wg.Add(1)
    go func() {
        defer wg.Done()
        time.Sleep(50 * time.Millisecond)
        allConfig := config.GetAll()
        fmt.Printf("配置总数: %d\n", len(allConfig))
    }()
    
    wg.Wait()
    
    fmt.Println("\n最终配置:")
    finalConfig := config.GetAll()
    for k, v := range finalConfig {
        fmt.Printf("  %s: %s\n", k, v)
    }
}

func main() {
    demonstrateRWMutex()
}

WaitGroup(等待组)

go
package main

import (
    "fmt"
    "sync"
    "time"
)

// 任务处理器
type TaskProcessor struct {
    name string
}

func (tp TaskProcessor) Process(taskID int, duration time.Duration) {
    fmt.Printf("[%s] 开始处理任务 %d\n", tp.name, taskID)
    time.Sleep(duration)
    fmt.Printf("[%s] 完成处理任务 %d (耗时: %v)\n", tp.name, taskID, duration)
}

// 批量任务处理
func processBatchTasks() {
    fmt.Println("=== 批量任务处理 ===")
    
    var wg sync.WaitGroup
    processor := TaskProcessor{name: "Worker"}
    
    tasks := []struct {
        id       int
        duration time.Duration
    }{
        {1, 100 * time.Millisecond},
        {2, 200 * time.Millisecond},
        {3, 150 * time.Millisecond},
        {4, 80 * time.Millisecond},
        {5, 300 * time.Millisecond},
    }
    
    start := time.Now()
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t struct {
            id       int
            duration time.Duration
        }) {
            defer wg.Done()
            processor.Process(t.id, t.duration)
        }(task)
    }
    
    fmt.Println("所有任务已启动,等待完成...")
    wg.Wait()
    
    fmt.Printf("所有任务完成,总耗时: %v\n", time.Since(start))
}

// 工作池模式
func workerPoolPattern() {
    fmt.Println("\n=== 工作池模式 ===")
    
    const numWorkers = 3
    const numJobs = 10
    
    jobs := make(chan int, numJobs)
    results := make(chan string, numJobs)
    
    var wg sync.WaitGroup
    
    // 启动工作者
    for w := 1; w <= numWorkers; w++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            
            for jobID := range jobs {
                fmt.Printf("Worker %d 开始处理 Job %d\n", id, jobID)
                
                // 模拟工作
                time.Sleep(100 * time.Millisecond)
                
                result := fmt.Sprintf("Worker %d 完成 Job %d", id, jobID)
                results <- result
                
                fmt.Printf("Worker %d 完成 Job %d\n", id, jobID)
            }
        }(w)
    }
    
    // 发送任务
    go func() {
        for j := 1; j <= numJobs; j++ {
            jobs <- j
        }
        close(jobs)
    }()
    
    // 等待所有工作者完成
    go func() {
        wg.Wait()
        close(results)
    }()
    
    // 收集结果
    fmt.Println("收集结果:")
    for result := range results {
        fmt.Printf("  %s\n", result)
    }
}

// 分阶段处理
func pipelineProcessing() {
    fmt.Println("\n=== 流水线处理 ===")
    
    var wg sync.WaitGroup
    
    // 创建管道
    numbers := make(chan int, 10)
    squares := make(chan int, 10)
    
    // 阶段1:生成数字
    wg.Add(1)
    go func() {
        defer wg.Done()
        defer close(numbers)
        
        for i := 1; i <= 5; i++ {
            fmt.Printf("生成数字: %d\n", i)
            numbers <- i
            time.Sleep(50 * time.Millisecond)
        }
    }()
    
    // 阶段2:计算平方
    wg.Add(1)
    go func() {
        defer wg.Done()
        defer close(squares)
        
        for num := range numbers {
            square := num * num
            fmt.Printf("计算平方: %d -> %d\n", num, square)
            squares <- square
            time.Sleep(30 * time.Millisecond)
        }
    }()
    
    // 阶段3:输出结果
    wg.Add(1)
    go func() {
        defer wg.Done()
        
        total := 0
        count := 0
        
        for square := range squares {
            fmt.Printf("结果: %d\n", square)
            total += square
            count++
        }
        
        if count > 0 {
            fmt.Printf("平均值: %.2f\n", float64(total)/float64(count))
        }
    }()
    
    wg.Wait()
    fmt.Println("流水线处理完成")
}

func main() {
    processBatchTasks()
    workerPoolPattern()
    pipelineProcessing()
}

🎯 并发安全的数据结构

安全的映射

go
package main

import (
    "fmt"
    "sync"
    "time"
)

// 并发安全的映射
type SafeMap struct {
    data  map[string]int
    mutex sync.RWMutex
}

func NewSafeMap() *SafeMap {
    return &SafeMap{
        data: make(map[string]int),
    }
}

func (sm *SafeMap) Set(key string, value int) {
    sm.mutex.Lock()
    defer sm.mutex.Unlock()
    sm.data[key] = value
}

func (sm *SafeMap) Get(key string) (int, bool) {
    sm.mutex.RLock()
    defer sm.mutex.RUnlock()
    value, exists := sm.data[key]
    return value, exists
}

func (sm *SafeMap) Delete(key string) {
    sm.mutex.Lock()
    defer sm.mutex.Unlock()
    delete(sm.data, key)
}

func (sm *SafeMap) Keys() []string {
    sm.mutex.RLock()
    defer sm.mutex.RUnlock()
    
    keys := make([]string, 0, len(sm.data))
    for k := range sm.data {
        keys = append(keys, k)
    }
    return keys
}

func (sm *SafeMap) Size() int {
    sm.mutex.RLock()
    defer sm.mutex.RUnlock()
    return len(sm.data)
}

// 使用 sync.Map(Go 1.9+)
func demonstrateSyncMap() {
    fmt.Println("=== sync.Map 演示 ===")
    
    var sm sync.Map
    var wg sync.WaitGroup
    
    // 并发写入
    for i := 0; i < 10; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            
            key := fmt.Sprintf("key_%d", id)
            value := id * 10
            
            sm.Store(key, value)
            fmt.Printf("存储: %s = %d\n", key, value)
        }(i)
    }
    
    wg.Wait()
    
    // 并发读取
    fmt.Println("\n读取所有数据:")
    sm.Range(func(key, value interface{}) bool {
        fmt.Printf("  %s: %d\n", key.(string), value.(int))
        return true
    })
    
    // 删除一些数据
    fmt.Println("\n删除数据:")
    for i := 0; i < 5; i++ {
        key := fmt.Sprintf("key_%d", i)
        sm.Delete(key)
        fmt.Printf("删除: %s\n", key)
    }
    
    // 再次读取
    fmt.Println("\n删除后的数据:")
    sm.Range(func(key, value interface{}) bool {
        fmt.Printf("  %s: %d\n", key.(string), value.(int))
        return true
    })
}

func main() {
    fmt.Println("=== 自定义安全映射 ===")
    
    safeMap := NewSafeMap()
    var wg sync.WaitGroup
    
    // 并发操作
    for i := 0; i < 5; i++ {
        wg.Add(3)
        
        // 写入操作
        go func(id int) {
            defer wg.Done()
            key := fmt.Sprintf("item_%d", id)
            safeMap.Set(key, id*100)
            fmt.Printf("设置: %s = %d\n", key, id*100)
        }(i)
        
        // 读取操作
        go func(id int) {
            defer wg.Done()
            time.Sleep(10 * time.Millisecond) // 稍微延迟确保有数据可读
            
            key := fmt.Sprintf("item_%d", id)
            if value, exists := safeMap.Get(key); exists {
                fmt.Printf("读取: %s = %d\n", key, value)
            } else {
                fmt.Printf("读取: %s 不存在\n", key)
            }
        }(i)
        
        // 获取所有键
        go func(id int) {
            defer wg.Done()
            time.Sleep(20 * time.Millisecond)
            
            keys := safeMap.Keys()
            fmt.Printf("当前键列表 (线程%d): %v\n", id, keys)
        }(i)
    }
    
    wg.Wait()
    
    fmt.Printf("\n最终大小: %d\n", safeMap.Size())
    fmt.Printf("最终键列表: %v\n", safeMap.Keys())
    
    demonstrateSyncMap()
}

🎓 小结

本章我们深入学习了 Go 语言的并发模型和并发安全:

  • 并发基础:并发 vs 并行,Goroutines 基本概念
  • 数据竞争:竞态条件的识别和避免
  • 同步原语:Mutex、RWMutex、WaitGroup 的使用
  • 并发模式:工作池、流水线等常见模式
  • 安全数据结构:并发安全的数据结构设计
  • 最佳实践:死锁避免、性能优化等技巧

掌握并发编程是使用 Go 语言的关键技能,为构建高性能应用奠定基础。


接下来,我们将学习 Go 语言协程(goroutine),深入探索 Go 并发编程的核心。

并发安全建议

  • 识别和避免数据竞争
  • 合理选择同步原语(Mutex vs RWMutex)
  • 避免死锁,注意锁的获取顺序
  • 使用 go race detector 检测竞态条件

本站内容仅供学习和研究使用。