Skip to content

Rust 并发编程

概述

Rust 提供了安全且高效的并发编程支持。通过所有权系统和类型系统,Rust 能够在编译时防止数据竞争和其他并发问题。本章将学习 Rust 中的线程、消息传递、共享状态等并发编程概念。

🧵 线程基础

创建和管理线程

rust
use std::thread;
use std::time::Duration;

fn basic_threading() {
    // 创建新线程
    let handle = thread::spawn(|| {
        for i in 1..10 {
            println!("子线程: {}", i);
            thread::sleep(Duration::from_millis(1));
        }
    });
    
    // 主线程继续执行
    for i in 1..5 {
        println!("主线程: {}", i);
        thread::sleep(Duration::from_millis(1));
    }
    
    // 等待子线程完成
    handle.join().unwrap();
    println!("所有线程执行完成");
}

线程间数据传递

rust
fn thread_data_transfer() {
    let v = vec![1, 2, 3];
    
    // 使用 move 转移所有权
    let handle = thread::spawn(move || {
        println!("向量: {:?}", v);
    });
    
    handle.join().unwrap();
    
    // println!("v: {:?}", v); // 编译错误!v 已被移动
}

fn multiple_threads() {
    let mut handles = vec![];
    
    for i in 0..10 {
        let handle = thread::spawn(move || {
            let thread_id = thread::current().id();
            println!("线程 {} (ID: {:?}): 计算 {} * 2 = {}", i, thread_id, i, i * 2);
            thread::sleep(Duration::from_millis(100));
            i * 2 // 返回值
        });
        handles.push(handle);
    }
    
    // 收集所有结果
    let results: Vec<i32> = handles
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .collect();
    
    println!("所有结果: {:?}", results);
}

📨 消息传递

基本通道(Channel)

rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

fn basic_channels() {
    // 创建通道
    let (tx, rx) = mpsc::channel();
    
    thread::spawn(move || {
        let val = String::from("你好");
        tx.send(val).unwrap();
        // println!("val: {}", val); // 编译错误!val 已被移动
    });
    
    // 接收消息
    let received = rx.recv().unwrap();
    println!("收到消息: {}", received);
}

fn multiple_messages() {
    let (tx, rx) = mpsc::channel();
    
    thread::spawn(move || {
        let vals = vec![
            String::from("你好"),
            String::from("来自"),
            String::from("线程"),
        ];
        
        for val in vals {
            tx.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });
    
    // 迭代接收消息
    for received in rx {
        println!("收到: {}", received);
    }
}

多生产者单消费者

rust
fn multiple_producers() {
    let (tx, rx) = mpsc::channel();
    
    // 克隆发送者
    let tx1 = tx.clone();
    thread::spawn(move || {
        let vals = vec![
            String::from("线程1: 消息1"),
            String::from("线程1: 消息2"),
        ];
        
        for val in vals {
            tx1.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });
    
    thread::spawn(move || {
        let vals = vec![
            String::from("线程2: 消息1"),
            String::from("线程2: 消息2"),
        ];
        
        for val in vals {
            tx.send(val).unwrap();
            thread::sleep(Duration::from_secs(1));
        }
    });
    
    // 接收所有消息
    for received in rx {
        println!("收到: {}", received);
    }
}

异步通道

rust
use std::sync::mpsc::{sync_channel, TryRecvError};

fn sync_channels() {
    // 同步通道,容量为 0(阻塞)
    let (tx, rx) = sync_channel(0);
    
    let handle = thread::spawn(move || {
        println!("发送消息前");
        tx.send(1).unwrap(); // 阻塞直到接收者准备好
        println!("发送消息后");
    });
    
    thread::sleep(Duration::from_secs(2));
    println!("准备接收");
    let msg = rx.recv().unwrap();
    println!("收到: {}", msg);
    
    handle.join().unwrap();
}

fn non_blocking_receive() {
    let (tx, rx) = mpsc::channel();
    
    // 非阻塞接收
    match rx.try_recv() {
        Ok(msg) => println!("收到: {}", msg),
        Err(TryRecvError::Empty) => println!("通道为空"),
        Err(TryRecvError::Disconnected) => println!("通道已断开"),
    }
    
    // 发送消息
    tx.send("Hello").unwrap();
    
    // 现在可以接收
    match rx.try_recv() {
        Ok(msg) => println!("收到: {}", msg),
        Err(TryRecvError::Empty) => println!("通道为空"),
        Err(TryRecvError::Disconnected) => println!("通道已断开"),
    }
}

🔒 共享状态

互斥锁(Mutex)

rust
use std::sync::{Arc, Mutex};
use std::thread;

fn basic_mutex() {
    // 互斥锁保护数据
    let m = Mutex::new(5);
    
    {
        let mut num = m.lock().unwrap();
        *num = 6;
    } // 锁在这里被释放
    
    println!("m = {:?}", m);
}

fn shared_mutex() {
    // 使用 Arc 在多个线程间共享 Mutex
    let counter = Arc::new(Mutex::new(0));
    let mut handles = vec![];
    
    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            let mut num = counter.lock().unwrap();
            *num += 1;
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("最终计数: {}", *counter.lock().unwrap());
}

读写锁(RwLock)

rust
use std::sync::RwLock;
use std::collections::HashMap;

fn rwlock_example() {
    let data = Arc::new(RwLock::new(HashMap::new()));
    let mut handles = vec![];
    
    // 写入数据的线程
    for i in 0..5 {
        let data = Arc::clone(&data);
        let handle = thread::spawn(move || {
            let mut map = data.write().unwrap();
            map.insert(i, i * i);
            println!("写入: {} -> {}", i, i * i);
        });
        handles.push(handle);
    }
    
    // 读取数据的线程
    for i in 0..10 {
        let data = Arc::clone(&data);
        let handle = thread::spawn(move || {
            let map = data.read().unwrap();
            if let Some(value) = map.get(&(i % 5)) {
                println!("读取: {} -> {}", i % 5, value);
            }
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
}

原子类型

rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

fn atomic_operations() {
    let counter = Arc::new(AtomicUsize::new(0));
    let mut handles = vec![];
    
    for _ in 0..10 {
        let counter = Arc::clone(&counter);
        let handle = thread::spawn(move || {
            for _ in 0..1000 {
                counter.fetch_add(1, Ordering::SeqCst);
            }
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("原子计数器: {}", counter.load(Ordering::SeqCst));
}

fn compare_and_swap() {
    let value = Arc::new(AtomicUsize::new(0));
    let mut handles = vec![];
    
    for i in 0..10 {
        let value = Arc::clone(&value);
        let handle = thread::spawn(move || {
            let expected = i;
            let new_value = i + 10;
            
            match value.compare_exchange_weak(
                expected, 
                new_value, 
                Ordering::SeqCst, 
                Ordering::SeqCst
            ) {
                Ok(prev) => println!("线程 {} 成功更新: {} -> {}", i, prev, new_value),
                Err(current) => println!("线程 {} 更新失败,当前值: {}", i, current),
            }
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
    
    println!("最终值: {}", value.load(Ordering::SeqCst));
}

🔄 并发模式

生产者-消费者模式

rust
use std::sync::{Arc, Condvar, Mutex};
use std::collections::VecDeque;

struct Buffer<T> {
    queue: Mutex<VecDeque<T>>,
    not_empty: Condvar,
    not_full: Condvar,
    capacity: usize,
}

impl<T> Buffer<T> {
    fn new(capacity: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            not_empty: Condvar::new(),
            not_full: Condvar::new(),
            capacity,
        }
    }
    
    fn push(&self, item: T) {
        let mut queue = self.queue.lock().unwrap();
        
        // 等待缓冲区有空间
        while queue.len() >= self.capacity {
            queue = self.not_full.wait(queue).unwrap();
        }
        
        queue.push_back(item);
        self.not_empty.notify_one();
    }
    
    fn pop(&self) -> T {
        let mut queue = self.queue.lock().unwrap();
        
        // 等待缓冲区有数据
        while queue.is_empty() {
            queue = self.not_empty.wait(queue).unwrap();
        }
        
        let item = queue.pop_front().unwrap();
        self.not_full.notify_one();
        item
    }
}

fn producer_consumer_pattern() {
    let buffer = Arc::new(Buffer::new(5));
    
    // 生产者
    let buffer_producer = Arc::clone(&buffer);
    let producer = thread::spawn(move || {
        for i in 0..10 {
            println!("生产: {}", i);
            buffer_producer.push(i);
            thread::sleep(Duration::from_millis(100));
        }
    });
    
    // 消费者
    let buffer_consumer = Arc::clone(&buffer);
    let consumer = thread::spawn(move || {
        for _ in 0..10 {
            let item = buffer_consumer.pop();
            println!("消费: {}", item);
            thread::sleep(Duration::from_millis(150));
        }
    });
    
    producer.join().unwrap();
    consumer.join().unwrap();
}

工作窃取模式

rust
use std::sync::{Arc, Mutex};
use std::collections::VecDeque;

struct WorkStealing<T> {
    local_queue: Mutex<VecDeque<T>>,
    global_queue: Arc<Mutex<VecDeque<T>>>,
}

impl<T> WorkStealing<T> {
    fn new(global_queue: Arc<Mutex<VecDeque<T>>>) -> Self {
        Self {
            local_queue: Mutex::new(VecDeque::new()),
            global_queue,
        }
    }
    
    fn push_local(&self, item: T) {
        let mut queue = self.local_queue.lock().unwrap();
        queue.push_back(item);
    }
    
    fn pop_local(&self) -> Option<T> {
        let mut queue = self.local_queue.lock().unwrap();
        queue.pop_back()
    }
    
    fn steal_work(&self) -> Option<T> {
        // 先尝试本地队列
        if let Some(item) = self.pop_local() {
            return Some(item);
        }
        
        // 再尝试全局队列
        let mut global = self.global_queue.lock().unwrap();
        global.pop_front()
    }
}

fn work_stealing_pattern() {
    let global_queue = Arc::new(Mutex::new(VecDeque::new()));
    
    // 添加一些工作
    {
        let mut queue = global_queue.lock().unwrap();
        for i in 0..20 {
            queue.push_back(i);
        }
    }
    
    let mut handles = vec![];
    
    // 创建工作线程
    for thread_id in 0..4 {
        let global_queue = Arc::clone(&global_queue);
        let worker = WorkStealing::new(global_queue);
        
        let handle = thread::spawn(move || {
            let mut processed = 0;
            
            while let Some(work) = worker.steal_work() {
                println!("线程 {} 处理工作: {}", thread_id, work);
                thread::sleep(Duration::from_millis(10));
                processed += 1;
            }
            
            println!("线程 {} 完成,处理了 {} 个任务", thread_id, processed);
        });
        
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
}

⚡ 并发性能优化

线程池

rust
use std::sync::{mpsc, Arc, Mutex};

pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Job>,
}

type Job = Box<dyn FnOnce() + Send + 'static>;

impl ThreadPool {
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);
        
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        
        let mut workers = Vec::with_capacity(size);
        
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }
        
        ThreadPool { workers, sender }
    }
    
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender.send(job).unwrap();
    }
}

struct Worker {
    id: usize,
    thread: thread::JoinHandle<()>,
}

impl Worker {
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            let job = receiver.lock().unwrap().recv().unwrap();
            println!("工作线程 {} 执行任务", id);
            job();
        });
        
        Worker { id, thread }
    }
}

fn thread_pool_example() {
    let pool = ThreadPool::new(4);
    
    for i in 0..8 {
        pool.execute(move || {
            println!("任务 {} 开始执行", i);
            thread::sleep(Duration::from_secs(1));
            println!("任务 {} 执行完成", i);
        });
    }
    
    thread::sleep(Duration::from_secs(3));
}

无锁数据结构

rust
use std::sync::atomic::{AtomicPtr, Ordering};
use std::ptr;

pub struct LockFreeStack<T> {
    head: AtomicPtr<Node<T>>,
}

struct Node<T> {
    data: T,
    next: *mut Node<T>,
}

impl<T> LockFreeStack<T> {
    pub fn new() -> Self {
        Self {
            head: AtomicPtr::new(ptr::null_mut()),
        }
    }
    
    pub fn push(&self, data: T) {
        let new_node = Box::into_raw(Box::new(Node {
            data,
            next: ptr::null_mut(),
        }));
        
        loop {
            let head = self.head.load(Ordering::Acquire);
            unsafe {
                (*new_node).next = head;
            }
            
            match self.head.compare_exchange_weak(
                head,
                new_node,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(_) => continue,
            }
        }
    }
    
    pub fn pop(&self) -> Option<T> {
        loop {
            let head = self.head.load(Ordering::Acquire);
            if head.is_null() {
                return None;
            }
            
            let next = unsafe { (*head).next };
            
            match self.head.compare_exchange_weak(
                head,
                next,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    let data = unsafe { Box::from_raw(head).data };
                    return Some(data);
                }
                Err(_) => continue,
            }
        }
    }
}

unsafe impl<T: Send> Send for LockFreeStack<T> {}
unsafe impl<T: Send> Sync for LockFreeStack<T> {}

fn lock_free_example() {
    let stack = Arc::new(LockFreeStack::new());
    let mut handles = vec![];
    
    // 推入数据
    for i in 0..10 {
        let stack = Arc::clone(&stack);
        let handle = thread::spawn(move || {
            stack.push(i);
            println!("推入: {}", i);
        });
        handles.push(handle);
    }
    
    // 弹出数据
    for _ in 0..10 {
        let stack = Arc::clone(&stack);
        let handle = thread::spawn(move || {
            if let Some(value) = stack.pop() {
                println!("弹出: {}", value);
            }
        });
        handles.push(handle);
    }
    
    for handle in handles {
        handle.join().unwrap();
    }
}

📝 本章小结

通过本章学习,你应该掌握了:

基础并发

  • ✅ 线程的创建和管理
  • ✅ 消息传递机制(Channel)
  • ✅ 共享状态的同步(Mutex、RwLock)
  • ✅ 原子操作的使用

高级并发模式

  • ✅ 生产者-消费者模式
  • ✅ 工作窃取模式
  • ✅ 线程池的实现
  • ✅ 无锁数据结构

性能和安全

  • ✅ Rust 的内存安全保证
  • ✅ 数据竞争的预防
  • ✅ 并发性能优化技巧
  • ✅ 选择合适的并发原语

最佳实践

  1. 优先选择消息传递
  2. 谨慎使用共享状态
  3. 利用类型系统保证安全
  4. 根据场景选择同步原语

继续学习下一章 - Rust 宏

本站内容仅供学习和研究使用。