Skip to content

Zig 原子操作

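原子操作是多线程编程中确保数据一致性的重要工具。本章将详细介绍 Zig 中的原子操作和并发编程基础。本章示例基于 Zig 0.11 的 API:`std.atomic.Atomic(T)`、首字母大写的内存序(如 `.Monotonic`、`.SeqCst`)以及 `compareAndSwap`。从 Zig 0.12 起,该类型更名为 `std.atomic.Value(T)`,内存序改为小写(如 `.monotonic`、`.seq_cst`),`compareAndSwap` 改为 `cmpxchgStrong`;在新版本中运行示例时需要相应调整。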

原子操作基础

什么是原子操作?

原子操作是不可分割的操作:其他线程只能观察到它执行之前或之后的状态,看不到执行到一半的中间状态,因此在多线程环境中能够保证操作的完整性:

zig
const std = @import("std");

/// Contrasts a plain local with an atomic counter and applies a couple of
/// atomic read-modify-write operations to the latter.
pub fn main() void {
    // Plain variable: not safe to modify from several threads. It is only
    // read here, so it is declared `const`.
    const normal_counter: i32 = 0;

    // Atomic variable: safe to modify from several threads. It is `var`
    // because fetchAdd/fetchSub below modify it.
    var atomic_counter = std.atomic.Atomic(i32).init(0);

    std.debug.print("普通计数器: {}\n", .{normal_counter});
    std.debug.print("原子计数器: {}\n", .{atomic_counter.load(.Monotonic)});

    // Atomic read-modify-write; the returned previous value is discarded.
    _ = atomic_counter.fetchAdd(5, .Monotonic);
    std.debug.print("原子计数器 +5: {}\n", .{atomic_counter.load(.Monotonic)});

    _ = atomic_counter.fetchSub(2, .Monotonic);
    std.debug.print("原子计数器 -2: {}\n", .{atomic_counter.load(.Monotonic)});
}

内存序 (Memory Ordering)

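内存序定义了原子操作的同步和排序约束。注意:Acquire/Release 只有在不同线程对同一原子变量配对使用时(一个线程以 Release 写入,另一个线程以 Acquire 读到该值)才会建立同步关系;下面的示例在单线程中运行,只演示各种内存序的写法,输出结果不受内存序影响: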

zig
const std = @import("std");

/// Stores and loads one atomic value under each memory ordering.
///
/// Everything runs on a single thread, so the chosen ordering cannot change
/// what is printed; the example only shows how each ordering is spelled.
pub fn main() void {
    var atomic_value = std.atomic.Atomic(i32).init(0);

    // Different memory orderings. `std.debug.print` always needs an argument
    // tuple, even when the format string has no placeholders.
    std.debug.print("内存序示例:\n", .{});

    // Unordered: weakest ordering; only guarantees the access itself is atomic.
    atomic_value.store(10, .Unordered);
    const val1 = atomic_value.load(.Unordered);
    std.debug.print("Unordered: {}\n", .{val1});

    // Monotonic: operations on this same atomic variable are ordered.
    atomic_value.store(20, .Monotonic);
    const val2 = atomic_value.load(.Monotonic);
    std.debug.print("Monotonic: {}\n", .{val2});

    // Acquire/Release: acquire-release semantics.
    atomic_value.store(30, .Release);
    const val3 = atomic_value.load(.Acquire);
    std.debug.print("Acquire/Release: {}\n", .{val3});

    // SeqCst: strongest ordering, sequential consistency.
    atomic_value.store(40, .SeqCst);
    const val4 = atomic_value.load(.SeqCst);
    std.debug.print("SeqCst: {}\n", .{val4});
}

基本原子操作

加载和存储

zig
const std = @import("std");

/// Loads from and stores to atomic integer, bool and optional-pointer values.
pub fn main() void {
    var atomic_int = std.atomic.Atomic(i32).init(42);
    var atomic_bool = std.atomic.Atomic(bool).init(false);
    var atomic_ptr = std.atomic.Atomic(?*i32).init(null);

    // Load the initial values.
    const int_val = atomic_int.load(.Monotonic);
    const bool_val = atomic_bool.load(.Monotonic);
    const ptr_val = atomic_ptr.load(.Monotonic);

    // `std.debug.print` always takes an argument tuple, even an empty one.
    std.debug.print("加载的值:\n", .{});
    std.debug.print("  整数: {}\n", .{int_val});
    std.debug.print("  布尔: {}\n", .{bool_val});
    std.debug.print("  指针: {?*}\n", .{ptr_val});

    // Store new values.
    atomic_int.store(100, .Monotonic);
    atomic_bool.store(true, .Monotonic);

    // `target_value` lives until the end of main, which covers the
    // dereference below.
    var target_value: i32 = 999;
    atomic_ptr.store(&target_value, .Monotonic);

    std.debug.print("\n存储后的值:\n", .{});
    std.debug.print("  整数: {}\n", .{atomic_int.load(.Monotonic)});
    std.debug.print("  布尔: {}\n", .{atomic_bool.load(.Monotonic)});
    if (atomic_ptr.load(.Monotonic)) |ptr| {
        std.debug.print("  指针指向的值: {}\n", .{ptr.*});
    }
}

交换操作

zig
const std = @import("std");

/// Demonstrates `swap` and `compareAndSwap` on an atomic i32.
///
/// `compareAndSwap` yields null when the exchange happened, and otherwise the
/// value it actually found in the variable.
pub fn main() void {
    var value = std.atomic.Atomic(i32).init(10);

    std.debug.print("初始值: {}\n", .{value.load(.Monotonic)});

    // swap: unconditionally store the new value and hand back the previous one.
    const previous = value.swap(20, .Monotonic);
    const after_swap = value.load(.Monotonic);
    std.debug.print("交换后 - 旧值: {}, 新值: {}\n", .{ previous, after_swap });

    // CAS expected to succeed: the variable currently holds 20.
    const expected: i32 = 20;
    const replacement: i32 = 30;
    const first_result = value.compareAndSwap(expected, replacement, .Monotonic, .Monotonic);
    if (first_result) |found| {
        std.debug.print("CAS 失败 - 期望: {}, 实际: {}\n", .{ expected, found });
    } else {
        std.debug.print("CAS 成功 - 新值: {}\n", .{value.load(.Monotonic)});
    }

    // CAS expected to fail: 99 does not match the current value.
    const wrong_expected: i32 = 99;
    const second_result = value.compareAndSwap(wrong_expected, 40, .Monotonic, .Monotonic);
    if (second_result) |found| {
        std.debug.print("CAS 失败 - 期望: {}, 实际: {}\n", .{ wrong_expected, found });
    } else {
        std.debug.print("CAS 成功\n", .{});
    }
}

算术操作

zig
const std = @import("std");

/// Walks one atomic i32 through fetchAdd/fetchSub and the bitwise
/// fetchAnd/fetchOr/fetchXor operations, printing old and new values.
pub fn main() void {
    var counter = std.atomic.Atomic(i32).init(0);
    std.debug.print("初始计数器: {}\n", .{counter.load(.Monotonic)});

    // Each fetch* call returns the value held before the update; the
    // follow-up load shows the value after it.
    const before_add = counter.fetchAdd(5, .Monotonic);
    const after_add = counter.load(.Monotonic);
    std.debug.print("fetchAdd(5) - 旧值: {}, 新值: {}\n", .{ before_add, after_add });

    const before_sub = counter.fetchSub(2, .Monotonic);
    const after_sub = counter.load(.Monotonic);
    std.debug.print("fetchSub(2) - 旧值: {}, 新值: {}\n", .{ before_sub, after_sub });

    // The bitwise operations start from the known pattern 0b1111.
    counter.store(0b1111, .Monotonic);

    const before_and = counter.fetchAnd(0b1010, .Monotonic);
    const after_and = counter.load(.Monotonic);
    std.debug.print("fetchAnd(0b1010) - 旧值: 0b{b:0>4}, 新值: 0b{b:0>4}\n", .{ before_and, after_and });

    const before_or = counter.fetchOr(0b0101, .Monotonic);
    const after_or = counter.load(.Monotonic);
    std.debug.print("fetchOr(0b0101) - 旧值: 0b{b:0>4}, 新值: 0b{b:0>4}\n", .{ before_or, after_or });

    const before_xor = counter.fetchXor(0b1100, .Monotonic);
    const after_xor = counter.load(.Monotonic);
    std.debug.print("fetchXor(0b1100) - 旧值: 0b{b:0>4}, 新值: 0b{b:0>4}\n", .{ before_xor, after_xor });
}

多线程示例

简单的多线程计数器

zig
const std = @import("std");

/// Per-thread arguments handed to `workerThread`.
const ThreadData = struct {
    /// Identifier printed when the worker finishes.
    thread_id: u32,
    /// Number of increments the worker performs.
    iterations: u32,
    /// Atomic counter the worker increments.
    counter: *std.atomic.Atomic(i32),
};

/// Thread entry point: bumps the shared counter `data.iterations` times,
/// sleeping 1 ms after each bump, then reports completion.
fn workerThread(data: *ThreadData) void {
    const counter = data.counter;
    var done: u32 = 0;
    while (done < data.iterations) : (done += 1) {
        // Only atomicity matters here: main reads the total after joining
        // every thread, so Monotonic is enough.
        _ = counter.fetchAdd(1, .Monotonic);

        // Simulate some work.
        std.time.sleep(std.time.ns_per_ms); // 1ms
    }

    std.debug.print("线程 {} 完成,执行了 {} 次递增\n", .{ data.thread_id, data.iterations });
}

/// Spawns `num_threads` workers that each add 1 to a shared atomic counter
/// `iterations_per_thread` times, joins them, and compares the total with the
/// expected value.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var counter = std.atomic.Atomic(i32).init(0);
    const num_threads = 4;
    const iterations_per_thread = 100;

    // Per-thread arguments; each worker receives a pointer to its own element.
    const thread_data = try allocator.alloc(ThreadData, num_threads);
    defer allocator.free(thread_data);

    const threads = try allocator.alloc(std.Thread, num_threads);
    defer allocator.free(threads);

    // Number of threads started so far. If a spawn fails part-way, `try`
    // would otherwise return and free `thread_data` (and leave main, which
    // owns `counter`) while the running workers still use them. Deferred
    // statements run in reverse order, so this join happens before the frees.
    var spawned: usize = 0;
    errdefer {
        for (threads[0..spawned]) |thread| thread.join();
    }

    // Start the threads.
    for (0..num_threads) |i| {
        thread_data[i] = ThreadData{
            .counter = &counter,
            .thread_id = @intCast(i),
            .iterations = iterations_per_thread,
        };

        threads[i] = try std.Thread.spawn(.{}, workerThread, .{&thread_data[i]});
        spawned += 1;
    }

    // Wait for every worker before reading the final count.
    for (threads) |thread| {
        thread.join();
    }

    const final_count = counter.load(.Monotonic);
    const expected_count = num_threads * iterations_per_thread;

    std.debug.print("\n最终计数: {}\n", .{final_count});
    std.debug.print("期望计数: {}\n", .{expected_count});
    std.debug.print("计数正确: {}\n", .{final_count == expected_count});
}

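生产者-消费者模式

下面的 RingBuffer 是单生产者/单消费者(SPSC)环形缓冲区:tail 只由 push 写入,head 只由 pop 写入,二者都没有使用 CAS。因此同一时刻只能有一个线程调用 push、一个线程调用 pop;多个生产者或多个消费者并发使用会丢失或重复数据。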

zig
const std = @import("std");

/// Fixed-size ring buffer of `i32` for exactly ONE producer thread and ONE
/// consumer thread (SPSC).
///
/// `tail` is written only by `push` and `head` only by `pop`. Each side
/// publishes its own index with a Release store and reads the other side's
/// index with an Acquire load; that pairing is what makes the plain `buffer`
/// slot accesses safe. `push`/`pop` read an index and later store it back
/// without a CAS. With two or more concurrent producers (or consumers), two
/// calls can therefore claim the same slot, losing or duplicating an item.
/// That usage is not supported.
///
/// One slot is always left empty to tell "full" from "empty", so a buffer
/// created with `size` holds at most `size - 1` items.
const RingBuffer = struct {
    buffer: []i32,
    /// Index of the next slot to read; written only by `pop`.
    head: std.atomic.Atomic(usize),
    /// Index of the next slot to write; written only by `push`.
    tail: std.atomic.Atomic(usize),
    /// Number of slots in `buffer`.
    size: usize,

    const Self = @This();

    /// Allocates a buffer of `size` slots (capacity `size - 1`).
    /// Returns `error.CapacityTooSmall` when `size < 2`: a 1-slot buffer could
    /// never hold an item, and `size == 0` would divide by zero in `push`.
    pub fn init(allocator: std.mem.Allocator, size: usize) !Self {
        if (size < 2) return error.CapacityTooSmall;

        const buffer = try allocator.alloc(i32, size);
        return Self{
            .buffer = buffer,
            .head = std.atomic.Atomic(usize).init(0),
            .tail = std.atomic.Atomic(usize).init(0),
            .size = size,
        };
    }

    /// Frees the slot storage; `allocator` must be the one passed to `init`.
    pub fn deinit(self: *Self, allocator: std.mem.Allocator) void {
        allocator.free(self.buffer);
    }

    /// Producer side: appends `value`. Returns false (storing nothing) when full.
    pub fn push(self: *Self, value: i32) bool {
        const current_tail = self.tail.load(.Acquire);
        const next_tail = (current_tail + 1) % self.size;

        // Full when advancing tail would collide with head. Acquire pairs
        // with the consumer's Release store of `head`.
        if (next_tail == self.head.load(.Acquire)) {
            return false;
        }

        self.buffer[current_tail] = value;
        // Release: the slot write above is visible before the new tail is.
        self.tail.store(next_tail, .Release);
        return true;
    }

    /// Consumer side: removes and returns the oldest value, or null when empty.
    pub fn pop(self: *Self) ?i32 {
        const current_head = self.head.load(.Acquire);

        // Empty when head has caught up with tail. Acquire pairs with the
        // producer's Release store of `tail`.
        if (current_head == self.tail.load(.Acquire)) {
            return null;
        }

        const value = self.buffer[current_head];
        // Release: the slot read above finishes before the producer may reuse it.
        self.head.store((current_head + 1) % self.size, .Release);
        return value;
    }

    /// Snapshot check; may already be stale when the caller looks at it.
    pub fn isEmpty(self: *Self) bool {
        return self.head.load(.Acquire) == self.tail.load(.Acquire);
    }

    /// Snapshot check; may already be stale when the caller looks at it.
    pub fn isFull(self: *Self) bool {
        const current_tail = self.tail.load(.Acquire);
        const next_tail = (current_tail + 1) % self.size;
        return next_tail == self.head.load(.Acquire);
    }
};

/// Arguments for one `producer` thread.
const ProducerData = struct {
    /// Identifier used in log lines and to derive the produced values.
    producer_id: u32,
    /// How many values this producer pushes before exiting.
    items_to_produce: u32,
    /// Ring buffer the values are pushed into.
    buffer: *RingBuffer,
};

/// Arguments for one `consumer` thread.
const ConsumerData = struct {
    /// Identifier used in log lines.
    consumer_id: u32,
    /// How many values this consumer pops before exiting.
    items_to_consume: u32,
    /// Ring buffer the values are popped from.
    buffer: *RingBuffer,
};

/// Thread entry point: pushes `items_to_produce` values
/// (`producer_id * 1000 + n`) into the buffer. It retries every 1 μs while the
/// buffer is full and pauses 10 μs after each successful push.
fn producer(data: *ProducerData) void {
    const id = data.producer_id;
    var n: usize = 0;
    while (n < data.items_to_produce) : (n += 1) {
        const value: i32 = @intCast(id * 1000 + n);

        // Busy-wait (with a short sleep) until there is a free slot.
        while (!data.buffer.push(value)) {
            std.time.sleep(1 * std.time.ns_per_us); // 1μs
        }

        std.debug.print("生产者 {} 生产了: {}\n", .{ id, value });
        std.time.sleep(10 * std.time.ns_per_us); // 10μs
    }

    std.debug.print("生产者 {} 完成\n", .{id});
}

/// Thread entry point: pops until `items_to_consume` values have been taken.
/// It polls every 1 μs while the buffer is empty and pauses 15 μs after each
/// value.
fn consumer(data: *ConsumerData) void {
    var consumed: u32 = 0;

    while (consumed < data.items_to_consume) {
        // Nothing available yet: back off briefly and poll again.
        const value = data.buffer.pop() orelse {
            std.time.sleep(1 * std.time.ns_per_us); // 1μs
            continue;
        };

        std.debug.print("消费者 {} 消费了: {}\n", .{ data.consumer_id, value });
        consumed += 1;
        std.time.sleep(15 * std.time.ns_per_us); // 15μs
    }

    std.debug.print("消费者 {} 完成\n", .{data.consumer_id});
}

/// Runs one producer thread and one consumer thread over a shared 10-slot
/// RingBuffer and waits for both to finish.
///
/// RingBuffer's `push`/`pop` read an index and later store it back without a
/// CAS, so it is only correct with a single producer and a single consumer.
/// Two concurrent producers can write the same slot, losing an item. The
/// consumers would then wait forever for an item that never arrives. Two
/// concurrent consumers can return the same item twice. The whole workload
/// (10 items) therefore goes through one producer/consumer pair.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // Create the ring buffer.
    var ring_buffer = try RingBuffer.init(allocator, 10);
    defer ring_buffer.deinit(allocator);

    // Number of items sent from the producer to the consumer.
    const total_items = 10;

    var producer_data = ProducerData{
        .buffer = &ring_buffer,
        .producer_id = 0,
        .items_to_produce = total_items,
    };
    var consumer_data = ConsumerData{
        .buffer = &ring_buffer,
        .consumer_id = 0,
        .items_to_consume = total_items,
    };

    const producer_thread = try std.Thread.spawn(.{}, producer, .{&producer_data});
    const consumer_thread = try std.Thread.spawn(.{}, consumer, .{&consumer_data});

    // Wait for both threads.
    producer_thread.join();
    consumer_thread.join();

    std.debug.print("所有线程完成\n", .{});
}

原子指针操作

无锁栈

zig
const std = @import("std");

/// Generic lock-free LIFO stack (Treiber stack) built on compare-and-swap.
///
/// Memory reclamation: a popped node is NOT destroyed right away. Another
/// thread that loaded the same head may still be about to read `node.next`.
/// If the node were freed and its address handed to a later `push`, a stale
/// CAS in that thread could also succeed (the ABA problem). Popped nodes are
/// therefore parked on the `retired` list and destroyed only in `deinit`.
/// Because a retired node's address is never reused while the stack is live,
/// a plain pointer CAS on `head` is enough to detect concurrent changes.
/// Trade-off: memory grows with the total number of pushes until `deinit`.
fn LockFreeStack(comptime T: type) type {
    return struct {
        const Self = @This();

        const Node = struct {
            data: T,
            /// Link within the live stack; written only before the node is published.
            next: ?*Node,
            /// Link within the retired list; written only by the thread that popped the node.
            retired_next: ?*Node = null,
        };

        head: std.atomic.Atomic(?*Node),
        /// Popped nodes waiting to be destroyed in `deinit`.
        retired: std.atomic.Atomic(?*Node),
        allocator: std.mem.Allocator,

        pub fn init(allocator: std.mem.Allocator) Self {
            return Self{
                .head = std.atomic.Atomic(?*Node).init(null),
                .retired = std.atomic.Atomic(?*Node).init(null),
                .allocator = allocator,
            };
        }

        /// Destroys every remaining and retired node. Call only once no other
        /// thread can access the stack any more.
        pub fn deinit(self: *Self) void {
            // Popping moves the remaining live nodes onto the retired list.
            while (self.pop()) |_| {}

            var node = self.retired.load(.Acquire);
            self.retired.store(null, .Monotonic);
            while (node) |n| {
                node = n.retired_next;
                self.allocator.destroy(n);
            }
        }

        /// Pushes `data`; fails only if allocating the node fails.
        pub fn push(self: *Self, data: T) !void {
            const new_node = try self.allocator.create(Node);
            new_node.* = .{ .data = data, .next = null };

            // Lock-free push: link to the observed head, then try to swing head.
            var current_head = self.head.load(.Monotonic);
            while (true) {
                new_node.next = current_head;
                // Release publishes `data`/`next` to whoever pops this node.
                // The failure ordering may not be stronger than the success
                // ordering, hence Monotonic rather than Acquire here.
                if (self.head.compareAndSwap(current_head, new_node, .Release, .Monotonic)) |actual| {
                    current_head = actual; // lost the race; retry against the head we saw
                } else {
                    return;
                }
            }
        }

        /// Pops the most recently pushed value, or returns null if the stack is empty.
        pub fn pop(self: *Self) ?T {
            // Acquire pairs with the Release CAS in `push`, so the node's
            // `data` and `next` are visible before they are read.
            var current_head = self.head.load(.Acquire);
            while (current_head) |node| {
                // `node` is never freed while the stack is live, so reading
                // `next` is safe even if another thread already popped it. In
                // that case `head` no longer equals `node` and the CAS fails.
                if (self.head.compareAndSwap(current_head, node.next, .Acquire, .Acquire)) |actual| {
                    current_head = actual; // lost the race; retry
                } else {
                    const data = node.data;
                    self.retire(node);
                    return data;
                }
            }
            return null; // stack is empty
        }

        pub fn isEmpty(self: *Self) bool {
            return self.head.load(.Acquire) == null;
        }

        /// Adds a popped node to the retired list. The list is only ever
        /// pushed to concurrently (it is drained in `deinit`), so a pointer CAS
        /// is sufficient here.
        fn retire(self: *Self, node: *Node) void {
            var current = self.retired.load(.Monotonic);
            while (true) {
                node.retired_next = current;
                if (self.retired.compareAndSwap(current, node, .Release, .Monotonic)) |actual| {
                    current = actual;
                } else {
                    return;
                }
            }
        }
    };
}

/// Arguments for one `stackWorker` thread.
const StackData = struct {
    /// Identifier used in log lines and to derive pushed values.
    thread_id: u32,
    /// Number of random push/pop operations to perform.
    operations: u32,
    /// Stack the worker operates on.
    stack: *LockFreeStack(i32),
};

/// Thread entry point: performs `data.operations` randomly chosen push/pop
/// operations on the shared stack, logging each and sleeping 1 ms in between.
fn stackWorker(data: *StackData) !void {
    // Seed each thread from the OS-backed CSPRNG. `std.time.timestamp()` has
    // one-second resolution, so threads started in the same second would share
    // a seed and make exactly the same push/pop choices.
    var rng = std.rand.DefaultPrng.init(std.crypto.random.int(u64));
    const random = rng.random();

    for (0..data.operations) |i| {
        if (random.boolean()) {
            // Push operation.
            const value: i32 = @intCast(data.thread_id * 1000 + i);
            try data.stack.push(value);
            std.debug.print("线程 {} 推入: {}\n", .{ data.thread_id, value });
        } else {
            // Pop operation.
            if (data.stack.pop()) |value| {
                std.debug.print("线程 {} 弹出: {}\n", .{ data.thread_id, value });
            } else {
                std.debug.print("线程 {} 弹出: 栈为空\n", .{data.thread_id});
            }
        }

        std.time.sleep(1000000); // 1ms
    }

    std.debug.print("线程 {} 完成\n", .{data.thread_id});
}

/// Runs `num_threads` workers doing random push/pop on a shared
/// LockFreeStack, then drains and prints whatever is left.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var stack = LockFreeStack(i32).init(allocator);
    defer stack.deinit();

    const num_threads = 3;
    const operations_per_thread = 10;

    const thread_data = try allocator.alloc(StackData, num_threads);
    defer allocator.free(thread_data);

    const threads = try allocator.alloc(std.Thread, num_threads);
    defer allocator.free(threads);

    // Number of threads started so far. If a spawn fails part-way, join the
    // running workers before the deferred frees/deinit above release memory
    // they still use (deferred statements run in reverse order).
    var spawned: usize = 0;
    errdefer {
        for (threads[0..spawned]) |thread| thread.join();
    }

    // Start the threads.
    for (0..num_threads) |i| {
        thread_data[i] = StackData{
            .stack = &stack,
            .thread_id = @intCast(i),
            .operations = operations_per_thread,
        };
        threads[i] = try std.Thread.spawn(.{}, stackWorker, .{&thread_data[i]});
        spawned += 1;
    }

    // Wait for all threads to finish.
    for (threads) |thread| {
        thread.join();
    }

    // Drain the remaining elements.
    std.debug.print("\n清空栈中剩余元素:\n", .{});
    while (stack.pop()) |value| {
        std.debug.print("剩余: {}\n", .{value});
    }

    std.debug.print("栈为空: {}\n", .{stack.isEmpty()});
}

内存屏障和同步

内存屏障示例

zig
const std = @import("std");

/// Payload written by `writer` and read by `reader`; ordinary non-atomic memory.
var data: i32 = 0;
/// Set to true by `writer`; `reader` polls it before reading `data`.
var flag: std.atomic.Atomic(bool) = std.atomic.Atomic(bool).init(false);

/// Publishes `data` to the reader using a release fence.
///
/// The Release fence keeps the plain store to `data` from being reordered
/// after any later atomic store. The flag can therefore be stored with
/// Monotonic ordering and still pair with the reader's Acquire fence (or an
/// Acquire load). A `.Release` store of the flag alone would provide the same
/// guarantee; combining both would make the fence redundant, which is why the
/// store here is Monotonic.
fn writer() void {
    // Write the payload with an ordinary, non-atomic store.
    data = 42;

    // Release fence: every memory access above happens before the flag store below.
    std.atomic.fence(.Release);

    // Set the flag; the fence above supplies the ordering.
    flag.store(true, .Monotonic);

    std.debug.print("写入线程: 数据已写入并设置标志\n", .{});
}

/// Waits for the writer's flag, then reads `data` after an acquire fence.
///
/// The flag is polled with a Monotonic load. Once it reads `true`, the
/// Acquire fence synchronizes with the writer's release operation on the
/// flag, so the subsequent read of `data` sees 42. An `.Acquire` load of the
/// flag alone would provide the same guarantee; this version shows the
/// fence-based form.
fn reader() void {
    // Wait until the flag is set, sleeping briefly between polls.
    while (!flag.load(.Monotonic)) {
        std.time.sleep(1000); // 1μs
    }

    // Acquire fence: no memory access below may move before the flag load above.
    std.atomic.fence(.Acquire);

    // Read the payload.
    const value = data;
    std.debug.print("读取线程: 读取到数据 {}\n", .{value});
}

/// Starts the writer and reader threads and waits for both to finish.
pub fn main() !void {
    // `std.debug.print` always takes an argument tuple, even an empty one.
    std.debug.print("内存屏障示例:\n", .{});

    const writer_thread = try std.Thread.spawn(.{}, writer, .{});
    const reader_thread = try std.Thread.spawn(.{}, reader, .{});

    writer_thread.join();
    reader_thread.join();

    std.debug.print("同步完成\n", .{});
}

原子操作的最佳实践

1. 选择合适的内存序

zig
const std = @import("std");

/// Pairs three typical uses of atomics with a suitable memory ordering.
pub fn main() void {
    // Simple counter: only atomicity is needed, so Monotonic is usually enough.
    var counter = std.atomic.Atomic(i32).init(0);
    _ = counter.fetchAdd(1, .Monotonic);

    // Flag variable: Release when setting it, Acquire when checking it.
    var ready_flag = std.atomic.Atomic(bool).init(false);
    ready_flag.store(true, .Release);
    const is_ready = ready_flag.load(.Acquire);

    // Operations that need a strict global order: SeqCst.
    var sequence_counter = std.atomic.Atomic(u64).init(0);
    _ = sequence_counter.fetchAdd(1, .SeqCst);

    // Read the results back before reporting them.
    const counter_value = counter.load(.Monotonic);
    const sequence_value = sequence_counter.load(.SeqCst);

    std.debug.print("计数器: {}\n", .{counter_value});
    std.debug.print("就绪标志: {}\n", .{is_ready});
    std.debug.print("序列计数器: {}\n", .{sequence_value});
}

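2. 避免 ABA 问题

ABA 问题:线程 T1 读到栈顶 A 及其后继 B 后被挂起;其间其他线程弹出 A、弹出 B,又把 A(或恰好复用了 A 地址的新节点)压回栈顶。T1 恢复后执行 CAS,“栈顶是否仍为 A”的比较会成功,于是把已经不在栈中的 B 设为栈顶。给栈顶附加一个每次修改都递增的版本号,就能让这种基于过期数据的 CAS 失败。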

zig
const std = @import("std");

// Use a version number to avoid the ABA problem.

/// A (node index, version) pair packed into a single `u64`, so both halves
/// can be compared and replaced by one atomic compare-and-swap.
///
/// Nodes are referenced by index into `SafeStack.nodes` rather than by
/// pointer. A pointer plus a version is 128 bits on 64-bit targets, which not
/// every target can update atomically. `std.atomic.Atomic` also cannot wrap a
/// plain (non-packed) struct directly.
const VersionedPointer = struct {
    /// Slot index in `SafeStack.nodes`, or `nil` for "no node".
    index: u32,
    /// Bumped (with wrap-around) on every successful update of the word.
    version: u32,

    /// Sentinel index meaning "empty list".
    const nil: u32 = std.math.maxInt(u32);

    /// Packed word for an empty list with version 0.
    const empty: u64 = pack(.{ .index = nil, .version = 0 });

    fn pack(self: VersionedPointer) u64 {
        return (@as(u64, self.version) << 32) | self.index;
    }

    fn unpack(word: u64) VersionedPointer {
        return .{
            .index = @truncate(word),
            .version = @truncate(word >> 32),
        };
    }

    const Node = struct {
        data: i32,
        /// Index of the next slot in whichever list the node is on. Atomic
        /// because a thread holding a stale head may read it while the node's
        /// new owner rewrites it; that thread's CAS then fails on the version.
        next: std.atomic.Atomic(u32),
    };
};

/// Lock-free LIFO stack of `i32` whose head carries a version to defeat ABA.
///
/// Storage is a fixed array of `capacity` nodes owned by the stack. Unused
/// nodes sit on a second versioned free list. Popped nodes go back onto that
/// list instead of being returned to an allocator, so a thread still holding
/// a stale index only ever reads live memory. `push` fails with
/// `error.OutOfMemory` while all `capacity` nodes are in use.
const SafeStack = struct {
    /// Maximum number of elements the stack can hold at once.
    pub const capacity: u32 = 1024;

    /// Top of the stack, as a packed `VersionedPointer`.
    head: std.atomic.Atomic(u64),
    /// Head of the free-node list, as a packed `VersionedPointer`.
    free: std.atomic.Atomic(u64),
    nodes: [capacity]VersionedPointer.Node,
    /// Kept so `init(allocator)` stays source-compatible; node storage is
    /// embedded in `nodes`, so nothing is allocated from it.
    allocator: std.mem.Allocator,

    const Self = @This();

    comptime {
        // Every real slot index must differ from the `nil` sentinel.
        std.debug.assert(capacity < VersionedPointer.nil);
    }

    pub fn init(allocator: std.mem.Allocator) Self {
        var self = Self{
            .head = std.atomic.Atomic(u64).init(VersionedPointer.empty),
            .free = std.atomic.Atomic(u64).init(VersionedPointer.pack(.{ .index = 0, .version = 0 })),
            .nodes = undefined,
            .allocator = allocator,
        };
        // Chain every slot into the free list: 0 -> 1 -> ... -> capacity-1 -> nil.
        for (&self.nodes, 0..) |*node, i| {
            const next: u32 = if (i + 1 < capacity) @intCast(i + 1) else VersionedPointer.nil;
            node.* = .{ .data = 0, .next = std.atomic.Atomic(u32).init(next) };
        }
        return self;
    }

    /// Pushes `data`. Fails with `error.OutOfMemory` when all nodes are in use.
    pub fn push(self: *Self, data: i32) !void {
        const index = self.popIndex(&self.free) orelse return error.OutOfMemory;
        // The node is exclusively ours until pushIndex publishes it.
        self.nodes[index].data = data;
        self.pushIndex(&self.head, index);
    }

    /// Pops the most recently pushed value, or returns null if the stack is empty.
    pub fn pop(self: *Self) ?i32 {
        const index = self.popIndex(&self.head) orelse return null;
        // Read the value before the node is recycled through the free list.
        const data = self.nodes[index].data;
        self.pushIndex(&self.free, index);
        return data;
    }

    /// Links node `index` in front of `list` (either `head` or `free`).
    fn pushIndex(self: *Self, list: *std.atomic.Atomic(u64), index: u32) void {
        var word = list.load(.Monotonic);
        while (true) {
            const current = VersionedPointer.unpack(word);
            self.nodes[index].next.store(current.index, .Monotonic);
            const updated = VersionedPointer{ .index = index, .version = current.version +% 1 };
            // Release makes the node's `data`/`next` visible to whoever pops it.
            if (list.compareAndSwap(word, updated.pack(), .Release, .Monotonic)) |actual| {
                word = actual; // lost the race; retry against the word we saw
            } else {
                return;
            }
        }
    }

    /// Unlinks and returns the front node of `list`, or null if it is empty.
    fn popIndex(self: *Self, list: *std.atomic.Atomic(u64)) ?u32 {
        var word = list.load(.Acquire);
        while (true) {
            const current = VersionedPointer.unpack(word);
            if (current.index == VersionedPointer.nil) return null;
            // If another thread pops (and possibly reuses) this node meanwhile,
            // `next` may be stale; the version bump then makes the CAS fail.
            const next = self.nodes[current.index].next.load(.Monotonic);
            const updated = VersionedPointer{ .index = next, .version = current.version +% 1 };
            if (list.compareAndSwap(word, updated.pack(), .Acquire, .Acquire)) |actual| {
                word = actual;
            } else {
                return current.index;
            }
        }
    }
};

/// Pushes three values onto a SafeStack from one thread, then pops and
/// prints values until the stack is empty.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    var safe_stack = SafeStack.init(allocator);

    // Exercise the stack operations.
    try safe_stack.push(10);
    try safe_stack.push(20);
    try safe_stack.push(30);

    // `std.debug.print` always takes an argument tuple, even an empty one.
    std.debug.print("安全栈测试:\n", .{});
    while (safe_stack.pop()) |value| {
        std.debug.print("弹出: {}\n", .{value});
    }
}

3. 性能考虑

zig
const std = @import("std");

/// Times `iterations` Monotonic fetchAdd calls on a single thread and prints
/// the total time, the average per operation, and the final count.
pub fn main() void {
    const iterations = 1000000;

    // Measure the cost of an atomic increment.
    var counter = std.atomic.Atomic(i64).init(0);

    const start_time = std.time.nanoTimestamp();

    // Use the weakest memory ordering for a simple count.
    for (0..iterations) |_| {
        _ = counter.fetchAdd(1, .Monotonic);
    }

    const end_time = std.time.nanoTimestamp();
    const duration = end_time - start_time;

    // `std.debug.print` always takes an argument tuple, even an empty one.
    std.debug.print("性能测试结果:\n", .{});
    std.debug.print("操作次数: {}\n", .{iterations});
    std.debug.print("总时间: {} ns\n", .{duration});
    std.debug.print("平均每次操作: {d:.2} ns\n", .{@as(f64, @floatFromInt(duration)) / @as(f64, @floatFromInt(iterations))});
    std.debug.print("最终计数: {}\n", .{counter.load(.Monotonic)});
}

总结

本章详细介绍了 Zig 的原子操作:

  • ✅ 原子操作的基本概念和内存序
  • ✅ 基本的原子操作:加载、存储、交换
  • ✅ 算术和位运算的原子操作
  • ✅ 多线程编程示例
  • ✅ 无锁数据结构的实现
  • ✅ 内存屏障和同步机制
  • ✅ 最佳实践和性能考虑

原子操作是构建高性能并发程序的基础工具。正确使用原子操作可以避免数据竞争,确保程序在多线程环境下的正确性。在下一章中,我们将学习 Zig 的异步编程。

本站内容仅供学习和研究使用。