掘金 后端 ( ) • 2024-03-28 11:20

前言

AQS_封面.png AbstractQueuedSynchronizer(AQS)是Java并发包中的一个抽象类, 它为实现阻塞锁和相关的同步器(如信号量、事件等)提供了一个框架, Java内置的线程同步工具如ReentrantLockSemaphoreCountDownLatch等都基于AQS实现

AQS的实现原理

  • 一个变量state, 表示加锁的次数, 如0表示未加锁、1表示已加锁、N表示重入加锁了N次
  • 一个队列来存储等待获取锁的线程, 通过持有队列的headtail来实现队列访问

特性

  • 可重入
  • 独占模式和共享模式
  • 支持超时、支持可中断
  • 支持通过锁获取condition对象, 实现等待/唤醒功能

简单锁示例

在锁工具内部, 我们还需要定义一个类Sync来继承AQS,并覆写tryAcquiretryRelease两个方法即可

  • tryAcquire CAS设置state变量从0到1, 成功则视为拿到锁资源
  • tryRelease 将state变量恢复为0, 即释放锁资源
public class DemoLock {

    public void lock() {
        // 代理给sync的acquire方法
        sync.acquire(1);
    }

    public void unlock() {
        // 代理给sync的release方法
        sync.release(0);
    }

    private static final Sync sync = new Sync();

    private static class Sync extends AbstractQueuedSynchronizer {

        @Override
        protected boolean tryAcquire(int arg) {
            return compareAndSetState(0, 1);
        }

        @Override
        protected boolean tryRelease(int arg) {
            setState(0);
            return true;
        }
    }

}

测试用例如下所示,启动10个线程并发调用increment, 注意unlock要在finally块中执行

class DemoLockTest {
    private static int count = 0;
    private static final DemoLock lock = new DemoLock();

    @Test
    void testLock() throws InterruptedException {
        // 10个线程, 执行内容都是调用increment方法
        List<Thread> threads = new ArrayList<>(10);
        for (int i = 0; i < 10; i++) {
            threads.add(new Thread(this::increment));
        }
        
        // 启动所有线程
        for (Thread thread : threads) {
            thread.start();
        }
        
        // main线程等待所有线程结束
        for (Thread thread : threads) {
            thread.join();
        }
        System.out.println("count = " + count);
    }

    private void increment() {
        for (int i = 0; i < 100000; i++) {
            try {
                lock.lock();
                count++;
            } finally {
                lock.unlock();
            }
        }
    }
}

lock 流程

lock实现就是代理给内部syncacquire方法 acquire方法的逻辑为

  • tryAcquire, 即我们DemoLock中Sync覆写的方法, 实现为CAS设置state从0变1
  • 如果tryAcquire失败
    • addWaiter, 将当前线程包装为链表Node,并加入到等待队列中
    • acquireQueued, 对已在队列中的Node, 循环的尝试获取锁
  • 如果循环获取锁时发现线程被中断, 会执行selfInterrupt重置中断标记

AQS_acquire.png

public final void acquire(int arg) {
    // tryAcquire 尝试获取锁
    if (!tryAcquire(arg) &&
            // addWaiter 将当前线程加入到等待队列中
            // acquireQueued 从等待队列中循环锁, 返回结果为当前节点关联的线程是否被中断了
            acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
        // 如果当前节点关联的线程被中断了, 则重置中断状态
        selfInterrupt();
}

addWaiter方法

private Node addWaiter(Node mode) {
    // 将当前线程包装为链表节点
    Node node = new Node(Thread.currentThread(), mode);
    // 尝试快速入队,如果失败则使用enq方法
    Node pred = tail; // 原链表尾节点
    if (pred != null) {
        // 将当前节点作为新的尾节点
        node.prev = pred;
        if (compareAndSetTail(pred, node)) {
            pred.next = node;
            return node;
        }
    }

    // 原链表尾节点为空或者CAS失败, enq入队
    enq(node);
    return node;
}

enq

private Node enq(final Node node) {
    for (; ; ) {
        Node t = tail;
        if (t == null) { // Must initialize
            // tail为null时,head也必为null,初始化head和tail节点
            if (compareAndSetHead(new Node()))
                tail = head;
        } else {
            // 将当前节点插入到队列尾部
            node.prev = t;
            if (compareAndSetTail(t, node)) {
                t.next = node;
                return t;
            }
        }
    }
}

acquireQueued方法

final boolean acquireQueued(final Node node, int arg) {
    boolean failed = true;
    try {
        boolean interrupted = false;
        for (; ; ) {
            final Node p = node.predecessor(); // p为当前节点的前驱节点
            if (p == head && tryAcquire(arg)) { // 前驱节点为head节点,尝试获取锁
                setHead(node); // 设置当前节点为head节点
                p.next = null; // help GC
                failed = false;
                return interrupted;
            }
            // 挂起即阻塞当前线程,
            // 根据前驱节点的状态来判断当前节点是否需要挂起
            if (shouldParkAfterFailedAcquire(p, node) &&
                    // 将当前线程挂起,并返回当前线程是否被中断
                    parkAndCheckInterrupt())
                // 若当前线程被中断,则设置中断标志为true
                interrupted = true;
        }
    } finally {
        if (failed) // 如果获取锁失败,则取消当前节点的获取锁操作
            cancelAcquire(node);
    }
}

setHead

private void setHead(Node node) {
    head = node;
    node.thread = null;
    node.prev = null;
}

shouldParkAfterFailedAcquire

源码之前, 要了解下Node有哪些状态

/**
 * 当前线程已取消
 */
static final int CANCELLED = 1;
/**
 * 后继线程需要唤醒
 */
static final int SIGNAL = -1;
/**
 * 当前等待条件中
 */
static final int CONDITION = -2;
/**
 * CAS获取锁失败后, 判断是否要挂起当前线程
 */
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
    int ws = pred.waitStatus; // 获取前驱节点的状态
    if (ws == Node.SIGNAL) // 前驱节点的状态为SIGNAL,表示当前节点需要挂起
        return true;
    if (ws > 0) {
        // 前驱节点的状态大于0,则只能是1(CANCELLED),表示前驱节点已取消,则跳过N个已取消的前驱节点
        do {
            node.prev = pred = pred.prev;
        } while (pred.waitStatus > 0);
        pred.next = node;
    } else {
        // 尝试将前驱节点的状态设置为SIGNAL, 保证下一次循环进入此方法时,返回true
        compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
    }
    return false;
}

parkAndCheckInterrupt

/**
 * 挂起当前线程,并返回当前线程是否被中断
 */
private final boolean parkAndCheckInterrupt() {
    LockSupport.park(this);
    return Thread.interrupted();
}

cancelAcquire

/**
 * 取消获取锁的尝试
 */
private void cancelAcquire(Node node) {
    if (node == null) return;
    node.thread = null;

    // 跳过已取消的前驱节点
    Node pred = node.prev;
    while (pred.waitStatus > 0)
        node.prev = pred = pred.prev;
    // 找到未取消的前驱节点
    Node predNext = pred.next;

    // 设置当前节点的状态为取消
    node.waitStatus = Node.CANCELLED;

    // 如果当前节点是尾节点 则将找到的前驱节点设置为新的尾节点, 即将当前节点从链表中删除
    if (node == tail && compareAndSetTail(node, pred)) {
        compareAndSetNext(pred, predNext, null);
    } else {
        int ws;
        if (
            // 1. 前驱节点不是头节点
                pred != head &&
                        // 2. 并且前驱节点的状态是SIGNAL 或能CAS设置成SINGNAL
                        ((ws = pred.waitStatus) == Node.SIGNAL ||
                                (ws <= 0 && compareAndSetWaitStatus(pred, ws, Node.SIGNAL)))
                        // 3. 并且前驱节点的线程不为空
                        && pred.thread != null) {

            // 将当前节点的后继链到找到的新前驱上, 即将当前节点从链表中删除
            Node next = node.next;
            if (next != null && next.waitStatus <= 0)
                compareAndSetNext(pred, predNext, next);
        } else {
            // 唤醒后继节点的条件为
            // 1. 前驱节点是头节点
            // 2. 前驱节点的状态不是SIGNAL 或者不能CAS设置成SINGNAL
            // 3. 前驱节点的线程为空
            unparkSuccessor(node);
        }

        node.next = node; // help GC
    }
}

unparkSuccessor

/**
 * 唤醒后继节点
 */
private void unparkSuccessor(Node node) {
    int ws = node.waitStatus;
    if (ws < 0) // CAS设置当前节点为初始状态
        compareAndSetWaitStatus(node, ws, 0);

    Node s = node.next;
    if (s == null || s.waitStatus > 0) {
        s = null;
        // 从链表尾部向前查找第一个状态不为取消的节点
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    if (s != null) // 唤醒后继节点的线程
        LockSupport.unpark(s.thread);
}

unlock 流程

unlock实现就是代理给内部sync的release方法,实现比较简单

  • 调用DemoLock覆写的tryRelease释放锁
  • 释放锁成功, 唤醒后继节点
public final boolean release(int arg) {
    if (tryRelease(arg)) {
        // 释放锁成功后, 需要唤醒后继节点
        Node h = head;
        if (h != null && h.waitStatus != 0)
            unparkSuccessor(h);
        return true;
    }
    return false;
}