一、前言

本篇的介绍对象是 CountDownLatch,它同样是基于 AQS 之上扩展的一款多线程场景下的工具类,它可以使一个或多个线程等待其他线程各自执行完毕后再执行。

对于 CountDownLatch 理解,我们可以将单次拆开为 CountDownLatchCountDown 表示倒计时,Latch 表示门闩,当倒计时结束后门闩解除,门就开了。

二、使用场景

要完成一项复杂的任务,任务被划分为子任务1和子任务2,3,4...,为了提高执行任务的效率,采用多线程去完成。

由于子任务1的执行条件依赖于 子任务2,3,4...,需要先执行子任务2,3,4...获取到相应的结果才能执行子任务1,这是 CountDownLatch 就派上用场了。

三、工作原理

给定 CountDownLatch 一个倒计时数,每个线程都能访问 CountDownLatch 实例。当一批线程要协作完成任务,线程 A 可以调用 CountDownLatchawait() 进行等待阻塞。其他线程则做其他业务,当业务执行完成后调用 CountDownLatchcountDown() 减掉倒计时。最后倒计时减到 0 时,阻塞的线程 A 就会被唤醒执行后续的业务。

由于是 CountDownLatch 是基于 AQS 扩展的,因此引用 AQS 模型图可方便我们理解:

图中,state 用于保存倒计时数,Node 节点用于封装等待阻塞的线程。

四、源码解析

我们先通过案例了解 CountDownLatch 基本使用。

  • 案例

我们将 CountDownLatch 当作餐馆服务员,线程比作客人。当客人来到餐馆吃饭时,餐馆服务员负责记录餐桌、客人吃饭的情况。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public class CountDownLatchTest {

public static void main(String[] args) throws InterruptedException {
// (1)
CountDownLatch countDownLatch = new CountDownLatch(5);
System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 上菜");

for (int i = 1; i <= 5; i++) {
new Thread(() -> {
try {
System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 开始吃饭");
Double time = Math.random() * 3 + 1;
TimeUnit.SECONDS.sleep(time.intValue());
System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 吃饭结束,走人");
// (2) 减去倒计时
countDownLatch.countDown();
} catch (InterruptedException e) {
e.printStackTrace();
}
}, "t" + i).start();
}

// (3) 等待阻塞,当倒计时为 0 就放行
System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 等待客人结账");
countDownLatch.await();
System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 客人都走了,开始收摊");
}
}

执行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
2023-03-15T11:40:08.536 -> main 上菜
2023-03-15T11:40:08.542 -> t1 开始吃饭
2023-03-15T11:40:08.542 -> t2 开始吃饭
2023-03-15T11:40:08.542 -> main 等待客人结账
2023-03-15T11:40:08.542 -> t3 开始吃饭
2023-03-15T11:40:08.542 -> t4 开始吃饭
2023-03-15T11:40:08.542 -> t5 开始吃饭
2023-03-15T11:40:09.543 -> t2 吃饭结束,走人
2023-03-15T11:40:10.543 -> t4 吃饭结束,走人
2023-03-15T11:40:11.542 -> t3 吃饭结束,走人
2023-03-15T11:40:11.542 -> t1 吃饭结束,走人
2023-03-15T11:40:11.542 -> t5 吃饭结束,走人
2023-03-15T11:40:11.542 -> main 客人都走了,开始收摊

当服务员上菜给客人后,需要等待(await())所有客人吃完饭结账后才能收摊,客人吃完饭需要通知服务员吃完饭结账(countDown())。

  • 源码分析

我们按照例子中的代码执行顺序分析。

首先查看 (1) 处代码,即创建 CountDownLatch 实例,进入构造方法中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public class CountDownLatch {

private static final class Sync extends AbstractQueuedSynchronizer {

Sync(int count) {
setState(count);
}

int getCount() {
return getState();
}

// (4) 尝试获取资源
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

// (5) 尝试释放资源
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}

private final Sync sync;

public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

// ...省略...

}

在构造方法内部创建了 Sync 实例,而 Sync 是一个静态的内部类, 它继承 AbstractQueuedSynchronizer 类,因此 Sync 拥有了 AQS 的能力,CountDownLatch 的所有操作都是通过 Sync 实例完成的。

调用构造方法传入的 count 值(倒计时数)被传入到 Sync 的构造方法中,其内部调用 setState(count) 方法,该方法来自 AQS,被保存到 AQSstate 中。

此时,AQS 的模型图如下:

回到案例代码中,main 线程创建好 CountDownLatch 实例后, 接着执行 for 循环,其方法体中创建新的线程执行其他业务,都是异步操作。我们顺着当前线程直接来到 (3) 处,即 countDownLatch.await(),跳进源码:

1
2
3
4
5
6
public class CountDownLatch {

public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
}

await() 方法底层通过 Sync 实例调用了 acquireSharedInterruptibly(1) 方法,该方法来自 AQS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

// ...省略...

public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// (6)
if (tryAcquireShared(arg) < 0)
// (7)
doAcquireSharedInterruptibly(arg);
}

}

进入该方法:先判断 main 线程是否被中断,并没有,然后执行 (6) 处代码,即 tryAcquireShared(arg),尝试获取资源权限(判断倒计时是否为 0)。该方法是一个抽象方法,最终通过子类来实现,即上文提到的 Sync 类来实现,跳回 (4) 处:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
public class CountDownLatch {

private static final class Sync extends AbstractQueuedSynchronizer {

Sync(int count) {
setState(count);
}

int getCount() {
return getState();
}

// (4)
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}

// ...省略...

}

// ...省略...
}

tryAcquireShared() 方法中判断 state 值(倒计时)是否为 0 ,是则返回 1,否则返回 -1。

从上文案例的执行结果可以看出,main 线程在线程阻塞之后,其他线程才陆续执行完毕,因此 state 值不可能为 0,最终方法返回 -1,然后执行 (7) 处代码,即 doAcquireSharedInterruptibly(arg) 方法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

// ...省略...

private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
// (8) 线程被封装到 Node 节点中
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
// (9) 获取前驱节点
final Node p = node.predecessor();
if (p == head) {
// (10) 再一次尝试获取资源
int r = tryAcquireShared(arg);
if (r >= 0) {
// (11) 设置头结点
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// (12) 获取资源失败,修改前驱节点的 state 状态
if (shouldParkAfterFailedAcquire(p, node) &&
// (13) 底层调用 LockSupport.lock() 挂起当前线程
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

// ...省略...
}

该方法在 《AQS 源码详解》 文章中详细解说过,源码上已简单注释说明,此处不多赘述。

最终,main 线程执行到 parkAndCheckInterrupt() 方法中被挂起等待。

此时,AQS 的模型图如下:

我们切换到其他线程视角,案例中 t2 线程先执行完业务调用了 countDown() 方法:

1
2
3
4
5
6
7
8
9
public class CountDownLatch {

// ...省略...

public void countDown() {
sync.releaseShared(1);
}

}

countDown() 方法底层调用 releaseShared(1),该方法来自 AQS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

// ...省略...

public final boolean releaseShared(int arg) {
// (14)
if (tryReleaseShared(arg)) {
// (15)
doReleaseShared();
return true;
}
return false;
}
}

线程 t2 来到 releaseShared(1) 方法中先执行 (14) 处代码,即 tryReleaseShared(arg) 代码,该方法是个抽象方法,通过子类 Sync 来实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public class CountDownLatch {

private static final class Sync extends AbstractQueuedSynchronizer {

Sync(int count) {
setState(count);
}

int getCount() {
return getState();
}

// (5) 尝试释放资源
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}

// ...省略...

}

// ...省略...

}

进到 tryReleaseShared(arg) 方法中,开启一个无限循环:

  1. 获取 state 值,当前值为 5。
  2. 判断 state 值,如果 为 0 返回 false,否则计算 state 新值(state 旧值 -1),此时新值为 4。
  3. 通过 CAS 方式将新值赋给 state
  4. 如果 state 新值为 0 返回 true,否则返回 false。

t2 线程执行方法最终返回值为 false,线程也跟着结束。

此时,AQS 的模型图如下:

其他条线程的执行步骤与 t2 线程都一样,我们直接跳到最后的 t5 线程视角。当 t5 线程执行 tryReleaseShared(arg)state 值改为 0 后,方法返回 true,开始执行 (15) 处代码,即 doReleaseShared()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

// ...省略...

private void doReleaseShared() {

for (;;) {
Node h = head;
// (16)
if (h != null && h != tail) {
int ws = h.waitStatus;
// (17) Node.SIGNAL:-1
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue;
// (18)
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue;
}
// (19)
if (h == head)
break;
}
}

// ...省略...
}

该方法用于修改 CLH 队列中头结点的 waitStatus 值以及唤醒头结点的后继节点中的线程。 开启一个无限循环:

  1. 获取 CLH 的头结点
  2. 判断头结点(dummy)是否为空,同时头结点是否与尾节点相同。由 AQS 模型图可知,(16) 处的判断是成立的,随后 t5 线程进到 if 方法体中。
  3. 判断头结点(dummy)的 waitStatus 状态,当前状态值为 -1,(17) 处判断成立,将头结点的 waitStatus 通过 CAS 方式还原为 0。
  4. 修改成功后执行 (18) 处代码,即 unparkSuccessor(h),该方法用于查询头结点的后继节点 node1,并通过 LockSupport.unpark(thread) 唤醒节点中的线程(main 线程)。由于该方法在 《AQS 源码详解》 已讲解,此处不多赘述。
  5. t5 线程最后来到 (19) 处,判断成立退出无限循环。

这样 t5 线程释放锁完毕,结束线程,我们转回被唤醒的 main 线程视角:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable {

// ...省略...

private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
// (8) 线程被封装到 Node 节点中
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
// (9) 获取前驱节点
final Node p = node.predecessor();
if (p == head) {
// (10) 再一次尝试获取资源
int r = tryAcquireShared(arg);
if (r >= 0) {
// (11) 设置头结点
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// (12) 获取资源失败,修改前驱节点的 state 状态
if (shouldParkAfterFailedAcquire(p, node) &&
// (13) 底层调用 LockSupport.lock() 挂起当前线程
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}

// ...省略...
}

main 线程在执行 (13) 处代码被挂起等待,该方法是在一个无限循环中进行的,当 main 线程被 t5 线程唤醒后开始执行下一轮循环任务:

  1. 获取前驱节点,即 dummy 节点,判断是否头结点,由 AQS 模型图可知,判断成立。
  2. 调用 tryAcquireShared(arg),上文已介绍,由于 state 值被减为 0, 最终该方法返回值为 1。
  3. 之后执行 (11) 处代码,即 setHeadAndPropagate(node, r),该方法用于将 node1 节点设置为新的头结点,移除节点中的线程
  4. 旧的头结点与当前节点解除关系

最终, AQS 的模型图如下:

五、参考资料

CAS 原理新讲

LockSupport 工具介绍

AQS 源码详解