Rust Concurrency
Overview
Rust provides safe and efficient support for concurrent programming. Through its ownership system and type system, Rust prevents data races and other concurrency problems at compile time. This chapter covers threads, message passing, shared state, and other concurrency concepts in Rust.
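To make that concrete, here is a minimal sketch (the standalone main function and the small counter are illustrative, not code from this chapter) of what "prevented at compile time" looks like in practice: an Arc<Mutex<T>> can be shared between threads because it is Send + Sync, while a non-thread-safe type such as Rc<T> is rejected by the compiler.
```rust
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    // Arc<Mutex<T>> is Send + Sync, so the compiler allows sharing it across threads.
    let counter = Arc::new(Mutex::new(0));
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let counter = Arc::clone(&counter);
            thread::spawn(move || {
                *counter.lock().unwrap() += 1;
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    assert_eq!(*counter.lock().unwrap(), 4);

    // Rc<T> is not Send, so the equivalent code with Rc does not compile:
    let rc_counter = Rc::new(0);
    // thread::spawn(move || println!("{}", rc_counter)); // error: Rc<i32> cannot be sent between threads safely
    drop(rc_counter);
}
```
The commented-out line is exactly the kind of mistake that Rust turns into a compile error instead of a runtime data race.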
🧵 Thread Basics
Creating and Managing Threads
```rust
use std::thread;
use std::time::Duration;
fn basic_threading() {
// Spawn a new thread
let handle = thread::spawn(|| {
for i in 1..10 {
println!("spawned thread: {}", i);
thread::sleep(Duration::from_millis(1));
}
});
// The main thread keeps running concurrently
for i in 1..5 {
println!("main thread: {}", i);
thread::sleep(Duration::from_millis(1));
}
// Wait for the spawned thread to finish
handle.join().unwrap();
println!("all threads finished");
}
```
Passing Data Between Threads
```rust
fn thread_data_transfer() {
let v = vec![1, 2, 3];
// Use move to transfer ownership of v into the thread
let handle = thread::spawn(move || {
println!("vector: {:?}", v);
});
handle.join().unwrap();
// println!("v: {:?}", v); // compile error: v has already been moved
}
fn multiple_threads() {
let mut handles = vec![];
for i in 0..10 {
let handle = thread::spawn(move || {
let thread_id = thread::current().id();
println!("thread {} (ID: {:?}): computing {} * 2 = {}", i, thread_id, i, i * 2);
thread::sleep(Duration::from_millis(100));
i * 2 // return value
});
handles.push(handle);
}
// Collect all the results
let results: Vec<i32> = handles
.into_iter()
.map(|handle| handle.join().unwrap())
.collect();
println!("all results: {:?}", results);
}
```
📨 Message Passing
Basic Channels
```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
fn basic_channels() {
// Create a channel
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let val = String::from("hello");
tx.send(val).unwrap();
// println!("val: {}", val); // compile error: val has been moved
});
// Receive the message
let received = rx.recv().unwrap();
println!("received: {}", received);
}
fn multiple_messages() {
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let vals = vec![
String::from("hi"),
String::from("from"),
String::from("the thread"),
];
for val in vals {
tx.send(val).unwrap();
thread::sleep(Duration::from_secs(1));
}
});
// Receive messages by iterating over the receiver
for received in rx {
println!("got: {}", received);
}
}
```
Multiple Producers, Single Consumer
```rust
fn multiple_producers() {
let (tx, rx) = mpsc::channel();
// Clone the sender so a second thread can send too
let tx1 = tx.clone();
thread::spawn(move || {
let vals = vec![
String::from("thread 1: message 1"),
String::from("thread 1: message 2"),
];
for val in vals {
tx1.send(val).unwrap();
thread::sleep(Duration::from_secs(1));
}
});
thread::spawn(move || {
let vals = vec![
String::from("thread 2: message 1"),
String::from("thread 2: message 2"),
];
for val in vals {
tx.send(val).unwrap();
thread::sleep(Duration::from_secs(1));
}
});
// Receive messages from both producers
for received in rx {
println!("got: {}", received);
}
}
```
Synchronous Channels and Non-Blocking Receives
mpsc::channel is asynchronous (unbounded, so send never blocks), while sync_channel(n) is bounded: with a capacity of 0, every send blocks until a receiver is ready to take the value.
```rust
use std::sync::mpsc::{sync_channel, TryRecvError};
fn sync_channels() {
// Synchronous (rendezvous) channel with capacity 0: send blocks until a receiver is ready
let (tx, rx) = sync_channel(0);
let handle = thread::spawn(move || {
println!("before send");
tx.send(1).unwrap(); // blocks until the receiver is ready
println!("after send");
});
thread::sleep(Duration::from_secs(2));
println!("ready to receive");
let msg = rx.recv().unwrap();
println!("got: {}", msg);
handle.join().unwrap();
}
fn non_blocking_receive() {
let (tx, rx) = mpsc::channel();
// Non-blocking receive: nothing has been sent yet, so the channel is empty
match rx.try_recv() {
Ok(msg) => println!("got: {}", msg),
Err(TryRecvError::Empty) => println!("channel is empty"),
Err(TryRecvError::Disconnected) => println!("channel is disconnected"),
}
// Send a message
tx.send("Hello").unwrap();
// Now the receive succeeds
match rx.try_recv() {
Ok(msg) => println!("got: {}", msg),
Err(TryRecvError::Empty) => println!("channel is empty"),
Err(TryRecvError::Disconnected) => println!("channel is disconnected"),
}
}
```
🔒 Shared State
Mutual Exclusion (Mutex)
```rust
use std::sync::{Arc, Mutex};
use std::thread;
fn basic_mutex() {
// A Mutex protects the data it wraps
let m = Mutex::new(5);
{
let mut num = m.lock().unwrap();
*num = 6;
} // the lock is released here, when the guard goes out of scope
println!("m = {:?}", m);
}
fn shared_mutex() {
// Use Arc to share the Mutex across threads
let counter = Arc::new(Mutex::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
let mut num = counter.lock().unwrap();
*num += 1;
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("final count: {}", *counter.lock().unwrap());
}
```
Read-Write Locks (RwLock)
```rust
use std::sync::RwLock;
use std::collections::HashMap;
fn rwlock_example() {
let data = Arc::new(RwLock::new(HashMap::new()));
let mut handles = vec![];
// Writer threads
for i in 0..5 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
let mut map = data.write().unwrap();
map.insert(i, i * i);
println!("wrote: {} -> {}", i, i * i);
});
handles.push(handle);
}
// Reader threads
for i in 0..10 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
let map = data.read().unwrap();
if let Some(value) = map.get(&(i % 5)) {
println!("read: {} -> {}", i % 5, value);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
```
Atomic Types
```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
fn atomic_operations() {
let counter = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
for _ in 0..1000 {
counter.fetch_add(1, Ordering::SeqCst);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("atomic counter: {}", counter.load(Ordering::SeqCst));
}
fn compare_and_swap() {
let value = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for i in 0..10 {
let value = Arc::clone(&value);
let handle = thread::spawn(move || {
let expected = i;
let new_value = i + 10;
// This is a single attempt rather than a retry loop, so use compare_exchange here;
// compare_exchange_weak may fail spuriously and belongs inside a loop
match value.compare_exchange(
expected,
new_value,
Ordering::SeqCst,
Ordering::SeqCst
) {
Ok(prev) => println!("thread {} updated successfully: {} -> {}", i, prev, new_value),
Err(current) => println!("thread {} failed to update, current value: {}", i, current),
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("final value: {}", value.load(Ordering::SeqCst));
}
```
🔄 Concurrency Patterns
The Producer-Consumer Pattern
```rust
use std::sync::{Arc, Condvar, Mutex};
use std::collections::VecDeque;
struct Buffer<T> {
queue: Mutex<VecDeque<T>>,
not_empty: Condvar,
not_full: Condvar,
capacity: usize,
}
impl<T> Buffer<T> {
fn new(capacity: usize) -> Self {
Self {
queue: Mutex::new(VecDeque::new()),
not_empty: Condvar::new(),
not_full: Condvar::new(),
capacity,
}
}
fn push(&self, item: T) {
let mut queue = self.queue.lock().unwrap();
// Wait until the buffer has free space
while queue.len() >= self.capacity {
queue = self.not_full.wait(queue).unwrap();
}
queue.push_back(item);
self.not_empty.notify_one();
}
fn pop(&self) -> T {
let mut queue = self.queue.lock().unwrap();
// Wait until the buffer has data
while queue.is_empty() {
queue = self.not_empty.wait(queue).unwrap();
}
let item = queue.pop_front().unwrap();
self.not_full.notify_one();
item
}
}
fn producer_consumer_pattern() {
let buffer = Arc::new(Buffer::new(5));
// Producer
let buffer_producer = Arc::clone(&buffer);
let producer = thread::spawn(move || {
for i in 0..10 {
println!("produced: {}", i);
buffer_producer.push(i);
thread::sleep(Duration::from_millis(100));
}
});
// Consumer
let buffer_consumer = Arc::clone(&buffer);
let consumer = thread::spawn(move || {
for _ in 0..10 {
let item = buffer_consumer.pop();
println!("consumed: {}", item);
thread::sleep(Duration::from_millis(150));
}
});
producer.join().unwrap();
consumer.join().unwrap();
}
```
The Work-Stealing Pattern
```rust
use std::sync::{Arc, Mutex};
use std::collections::VecDeque;
struct WorkStealing<T> {
local_queue: Mutex<VecDeque<T>>,
global_queue: Arc<Mutex<VecDeque<T>>>,
}
impl<T> WorkStealing<T> {
fn new(global_queue: Arc<Mutex<VecDeque<T>>>) -> Self {
Self {
local_queue: Mutex::new(VecDeque::new()),
global_queue,
}
}
fn push_local(&self, item: T) {
let mut queue = self.local_queue.lock().unwrap();
queue.push_back(item);
}
fn pop_local(&self) -> Option<T> {
let mut queue = self.local_queue.lock().unwrap();
queue.pop_back()
}
fn steal_work(&self) -> Option<T> {
// Try the local queue first
if let Some(item) = self.pop_local() {
return Some(item);
}
// Then fall back to the global queue
let mut global = self.global_queue.lock().unwrap();
global.pop_front()
}
}
fn work_stealing_pattern() {
let global_queue = Arc::new(Mutex::new(VecDeque::new()));
// Seed the global queue with some work
{
let mut queue = global_queue.lock().unwrap();
for i in 0..20 {
queue.push_back(i);
}
}
let mut handles = vec![];
// Spawn the worker threads
for thread_id in 0..4 {
let global_queue = Arc::clone(&global_queue);
let worker = WorkStealing::new(global_queue);
let handle = thread::spawn(move || {
let mut processed = 0;
while let Some(work) = worker.steal_work() {
println!("thread {} processing item: {}", thread_id, work);
thread::sleep(Duration::from_millis(10));
processed += 1;
}
println!("thread {} done, processed {} items", thread_id, processed);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
```
⚡ Optimizing Concurrent Performance
Thread Pools
```rust
use std::sync::{mpsc, Arc, Mutex};
pub struct ThreadPool {
workers: Vec<Worker>,
sender: mpsc::Sender<Job>,
}
type Job = Box<dyn FnOnce() + Send + 'static>;
impl ThreadPool {
pub fn new(size: usize) -> ThreadPool {
assert!(size > 0);
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
for id in 0..size {
workers.push(Worker::new(id, Arc::clone(&receiver)));
}
ThreadPool { workers, sender }
}
pub fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.sender.send(job).unwrap();
}
}
struct Worker {
id: usize,
thread: thread::JoinHandle<()>,
}
impl Worker {
fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
let thread = thread::spawn(move || loop {
// recv() blocks until a job arrives; once every sender is dropped it returns Err,
// which this simplified worker just unwraps (a production pool would shut down gracefully)
let job = receiver.lock().unwrap().recv().unwrap();
println!("worker {} executing a job", id);
job();
});
Worker { id, thread }
}
}
fn thread_pool_example() {
let pool = ThreadPool::new(4);
for i in 0..8 {
pool.execute(move || {
println!("task {} started", i);
thread::sleep(Duration::from_secs(1));
println!("task {} finished", i);
});
}
// Crude wait: this pool has no join/shutdown mechanism, so just give the tasks time to finish
thread::sleep(Duration::from_secs(3));
}
```
Lock-Free Data Structures
```rust
use std::sync::atomic::{AtomicPtr, Ordering};
use std::ptr;
pub struct LockFreeStack<T> {
head: AtomicPtr<Node<T>>,
}
struct Node<T> {
data: T,
next: *mut Node<T>,
}
impl<T> LockFreeStack<T> {
pub fn new() -> Self {
Self {
head: AtomicPtr::new(ptr::null_mut()),
}
}
pub fn push(&self, data: T) {
let new_node = Box::into_raw(Box::new(Node {
data,
next: ptr::null_mut(),
}));
loop {
let head = self.head.load(Ordering::Acquire);
unsafe {
(*new_node).next = head;
}
match self.head.compare_exchange_weak(
head,
new_node,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(_) => continue,
}
}
}
pub fn pop(&self) -> Option<T> {
loop {
let head = self.head.load(Ordering::Acquire);
if head.is_null() {
return None;
}
let next = unsafe { (*head).next };
match self.head.compare_exchange_weak(
head,
next,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => {
let data = unsafe { Box::from_raw(head).data };
return Some(data);
}
Err(_) => continue,
}
}
}
}
unsafe impl<T: Send> Send for LockFreeStack<T> {}
unsafe impl<T: Send> Sync for LockFreeStack<T> {}
fn lock_free_example() {
let stack = Arc::new(LockFreeStack::new());
let mut handles = vec![];
// Push values from several threads
for i in 0..10 {
let stack = Arc::clone(&stack);
let handle = thread::spawn(move || {
stack.push(i);
println!("pushed: {}", i);
});
handles.push(handle);
}
// Pop values from several threads
for _ in 0..10 {
let stack = Arc::clone(&stack);
let handle = thread::spawn(move || {
if let Some(value) = stack.pop() {
println!("popped: {}", value);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
```
Note: this stack is a minimal teaching example. Concurrent pops can run into the ABA problem and use-after-free, because a node is freed as soon as it is popped; production lock-free structures rely on safe memory reclamation such as hazard pointers or epoch-based reclamation (for example the crossbeam-epoch crate).
📝 Chapter Summary
After working through this chapter, you should have a grasp of:
Concurrency Basics
- ✅ Creating and managing threads
- ✅ Message passing with channels
- ✅ Synchronizing shared state (Mutex, RwLock)
- ✅ Using atomic operations
Advanced Concurrency Patterns
- ✅ The producer-consumer pattern
- ✅ The work-stealing pattern
- ✅ Implementing a thread pool
- ✅ Lock-free data structures
Performance and Safety
- ✅ Rust's memory-safety guarantees
- ✅ Preventing data races
- ✅ Techniques for optimizing concurrent performance
- ✅ Choosing appropriate concurrency primitives
Best Practices
- Prefer message passing (see the sketch below)
- Use shared state sparingly
- Lean on the type system to guarantee safety
- Choose synchronization primitives to fit the scenario
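As a closing illustration of the first guideline, here is a minimal sketch (the chunking scheme and names are illustrative, not code from earlier sections): each worker owns its slice of the work and reports only a partial sum over a channel, so no Mutex or shared counter is needed.
```rust
use std::sync::mpsc;
use std::thread;

fn main() {
    let (tx, rx) = mpsc::channel();
    for chunk in 0u64..4 {
        let tx = tx.clone();
        thread::spawn(move || {
            // Each worker owns its data and only communicates the result
            let partial: u64 = (chunk * 1000..(chunk + 1) * 1000).sum();
            tx.send(partial).unwrap();
        });
    }
    // Drop the original sender so the receiving iterator ends once all workers are done
    drop(tx);
    let total: u64 = rx.iter().sum();
    println!("total = {}", total);
}
```
Compared with the Arc<Mutex<_>> counter earlier in the chapter, nothing here is ever locked, yet the compiler still guarantees that each piece of data has exactly one owner at a time.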
Continue learning: next chapter - Rust Macros