1 简介

CountDownLatch可以使一个或多个线程等待其他线程各自执行完毕后再执行。
CountDownLatch 是多线程控制的一种工具,它被称为 门阀、 计数器或者 闭锁。这个工具经常用来用来协调多个线程之间的同步,或者说起到线程之间的通信(而不是用作互斥的作用)。
CountDownLatch是一个同步工具类,它使用给定的 count初始化, await()方法会一直阻塞,直到计数器的值变为零(由于 countDown()方法被调用导致的),这时会释放所有等待的线程,且之后再调用 await()方法会直接返回,不会阻塞。 CountDownLatch是一个 一次性的类,计数器不能被重置,这一点与 CyclicBarrier不同。另一个不同点是: CountDownLatch是让所有线程 等待计数器的值变为零再继续执行;而 CyclicBarrier是要 等待指定个数的线程到达 Barrier 的位置再一起继续执行。
CountDownLatch 定义了一个计数器,和一个阻塞队列, 当计数器的值递减为0之前,阻塞队列里面的线程处于挂起状态,当计数器递减到0时会唤醒阻塞队列所有线程,这里的计数器是一个标志,可以表示一个任务一个线程,也可以表示一个倒计时器,CountDownLatch可以解决那些一个或者多个线程在执行之前必须依赖于某些必要的前提业务先执行的场景。
在这里插入图片描述

2 方法

2.1 构造方法 CountDownLatch(int count)

计数器的初始值为count,也就是说countDown()方法至少被调用count次等待的线程才会被唤醒。如果count为负数或抛出异常IllegalArgumentException。

2.2 countDown()

如果当前计数器的值大于零,则将其减一,如果新的计数器值等于零,则释放所有等待的线程。
计数器变为零后就不会再减了。
如果当前计数器为零,则什么都不做。
此方法不会阻塞。

2.3 long getCount()

获取当前计数器的值。

2.4 await()

导致当前线程等待,直到计数器的值变为零,除非线程被中断。如果计数器已经为零了,则立即返回。以下两种情况会抛出InterruptedException并清空中断标志:

在调用wait()方法前,当前线程的中断状态已经为 true 了
在等待的过程中被中断了
模拟两种中断情况

// 在调用`wait()`方法前,当前线程的中断状态已经为 true 了
public static void test1() throws InterruptedException {
	CountDownLatch cdl = new CountDownLatch(1);
	Thread.currentThread().interrupt();
	cdl.await();
}

// 在等待的过程中被中断了
public static void test2() throws InterruptedException {
	CountDownLatch cdl = new CountDownLatch(1);
	Thread t1 = new Thread(() -> {
		try {
			cdl.await();
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
	}, "t1");
	t1.start();
	Thread.sleep(500);
	t1.interrupt();
}

boolean await(long timeout, TimeUnit unit)
此方法与await()的不同点:

此方法至多会等待指定的时间,超时后会自动唤醒,若 timeout 小于等于零,则不会等待
次方法有 boolean 类型的返回值:若计数器变为零了,则返回 true;若指定的等待时间过去了,则返回 false
等待指定时间

public static void test3() throws InterruptedException {
	CountDownLatch cdl = new CountDownLatch(1);
	log.info("开始 await");
	boolean b = cdl.await(2, TimeUnit.SECONDS);
	log.info("结束 await, 返回值: {}", b);
}

计数器在等待过程中变为零

public static void test4() throws InterruptedException {
	CountDownLatch cdl = new CountDownLatch(1);
	Thread t1 = new Thread(() -> {
		try {
			log.info("开始 await");
			// 至多等待 2 秒
			boolean b = cdl.await(2, TimeUnit.SECONDS);
			log.info("结束 await, 返回值: {}", b);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
	}, "t1");
	t1.start();

	Thread.sleep(1000);
	cdl.countDown();
}

计数器在调用await()方法前就变为 0 了

public static void test5() throws InterruptedException {
	CountDownLatch cdl = new CountDownLatch(1);
	cdl.countDown();
	log.info("开始 await");
	boolean b = cdl.await(2, TimeUnit.SECONDS);
	log.info("结束 await, 返回值: {}", b);
}

3 使用场景

3.1 多线程优化报表统计

3.1.1 功能现状

运营系统有统计报表、业务为统计每日的用户新增数量、订单数量、商品的总销量、总销售额…等多项指标统一展示出来,因为数据量比较大,统计指标涉及到的业务范围也比较多,所以这个统计报表的页面一直加载很慢,所以需要对统计报表这块性能需进行优化。

3.1.2 问题分析

统计报表页面涉及到的统计指标数据比较多,每个指标需要单独的去查询统计数据库数据,单个指标只要几秒钟,但是页面的指标有10多个,所以整体下来页面渲染需要将近一分钟。

3.1.3 解决方案

任务时间长是因为统计指标多,而且指标是串行的方式去进行统计的,我们只需要考虑把这些指标从串行化的执行方式改成并行的执行方式,那么整个页面的时间的渲染时间就会大大的缩短, 如何让多个线程同步的执行任务,我们这里考虑使用多线程,每个查询任务单独创建一个线程去执行,这样每个统计指标就可以并行的处理了。

3.1.4 要求

因为主线程需要每个线程的统计结果进行聚合,然后返回给前端渲染,所以这里需要提供一种机制让主线程等所有的子线程都执行完之后再对每个线程统计的指标进行聚合。 这里我们使用CountDownLatch 来完成此功能。

3.1.5 模拟代码

1、分别统计4个指标用户新增数量、订单数量、商品的总销量、总销售额;

2、假设每个指标执行时间为3秒。如果是串行化的统计方式那么总执行时间会为12秒。

3、我们这里使用多线程并行,开启4个子线程分别进行统计

4、主线程等待4个子线程都执行完毕之后,返回结果给前端。

  //用于聚合所有的统计指标
    private static Map map=new HashMap();
    //创建计数器,这里需要统计4个指标
    private static CountDownLatch countDownLatch=new CountDownLatch(4);
​
    public static void main(String[] args) {
​
        //记录开始时间
        long startTime=System.currentTimeMillis();
​
        Thread countUserThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在统计新增用户数量");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("userNumber",1);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计新增用户数量完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
        Thread countOrderThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在统计订单数量");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("countOrder",2);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计订单数量完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
​
        Thread countGoodsThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在商品销量");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("countGoods",3);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计商品销量完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
​
        Thread countmoneyThread=new Thread(new Runnable() {
            public void run() {
                try {
                    System.out.println("正在总销售额");
                    Thread.sleep(3000);//任务执行需要3秒
                    map.put("countmoney",4);//保存结果值
                    countDownLatch.countDown();//标记已经完成一个任务
                    System.out.println("统计销售额完毕");
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
​
            }
        });
        //启动子线程执行任务
        countUserThread.start();
        countGoodsThread.start();
        countOrderThread.start();
        countmoneyThread.start();
​
        try {
            //主线程等待所有统计指标执行完毕
            countDownLatch.await();
            long endTime=System.currentTimeMillis();//记录结束时间
            System.out.println("------统计指标全部完成--------");
            System.out.println("统计结果为:"+map.toString());
            System.out.println("任务总执行时间为"+(endTime-startTime)/1000+"秒");
​
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
​
​
    }

在这里插入图片描述

4 实现原理

4.1 创建计数器

当我们调用CountDownLatch countDownLatch=new CountDownLatch(4) 时候,此时会创建一个AQS的同步队列,并把创建CountDownLatch 传进来的计数器赋值给AQS队列的 state,所以state的值也代表CountDownLatch所剩余的计数次数。
CountDownLatch 使用起来比较简单,但是却非常有用,现在你可以在你的工具箱中加上 CountDownLatch 这个工具类了。下面我们就来深入认识一下 CountDownLatch。
CountDownLatch 的底层是由 AbstractQueuedSynchronizer 支持,而 AQS 的数据结构的核心就是两个队列,一个是 同步队列(sync queue),一个是条件队列(condition queue)。

public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);//创建同步队列,并设置初始计数器值
 }

4.2 阻塞线程

当我们调用countDownLatch.wait()的时候,会创建一个节点,加入到AQS阻塞队列,并同时把当前线程挂起。

public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
}

判断计数器是技术完毕,未完毕则把当前线程加入阻塞队列

public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        //锁重入次数大于0 则新建节点加入阻塞队列,挂起当前线程
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
 }

构建阻塞队列的双向链表,挂起当前线程

private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        //新建节点加入阻塞队列
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                //获得当前节点pre节点
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);//返回锁的state
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                //重组双向链表,清空无效节点,挂起当前线程
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }

4.3 计数器递减

当我们调用countDownLatch.down()方法的时候,会对计数器进行减1操作,AQS内部是通过释放锁的方式,对state进行减1操作,当state=0的时候证明计数器已经递减完毕,此时会将AQS阻塞队列里的节点线程全部唤醒。

public void countDown() {
        //递减锁重入次数,当state=0时唤醒所有阻塞线程
        sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
        //递减锁的重入次数
        if (tryReleaseShared(arg)) {
            doReleaseShared();//唤醒队列所有阻塞的节点
            return true;
        }
        return false;
    }
 private void doReleaseShared() {
        //唤醒所有阻塞队列里面的线程
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {//节点是否在等待唤醒状态
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))//修改状态为初始
                        continue;
                    unparkSuccessor(h);//成功则唤醒线程
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
 }

5 Thread.join()和CountDownLatch的区别

Thread.join()是Thread类的一个方法,Thread.join()的实现是依靠Object的wait()和notifyAll()来完成的,而CountDownLatch是JUC包中的一个工具类。

当我们使用ExecutorService ,就不能使用join,必须使用CountDownLatch比如:

ExecutorService service = Executors.newFixedThreadPool(5);
final CountDownLatch latch = new CountDownLatch(5);
for(int x = 0; x < 5; x++) {
service.submit(new Runnable() {
public void run() {
// do something
latch.countDown();
}
});
}
latch.await();
调用join方法需要等待thread执行完毕才能继续向下执行,而CountDownLatch只需要检查计数器的值为零就可以继续向下执行,相比之下,CountDownLatch更加灵活一些,可以实现一些更加复杂的业务场景。

6 示例

当我们调用CountDownLatch的countDown()方法时,N就会减1,CountDownLatch的await()方法 会阻塞当前线程,直到N变成零。
CountDownLatch 方法
await() 阻塞当前线程,直到计数器为零为止;
await(long timeout, TimeUnit unit) await()的重载方法,可以指定阻塞时长;
countDown() 计数器减1,如果计数达到零,释放所有等待的线程。
getCount() 返回当前计数

以下代码均来源于源码的注释
Driver、Worker
下面是两个类,其中一组 Worker 线程使用了两个CountDownLatch:

第一个是启动信号,阻止任何 Worker 继续,直到 Driver 让他们继续
第二个是完成信号,它允许 Driver 等待所有的 Worker 完成

class Driver { // ...
    public static void main(String[] args) throws InterruptedException {
        int N = 5;
        CountDownLatch startSignal = new CountDownLatch(1);
        CountDownLatch doneSignal = new CountDownLatch(N);

        for (int i = 0; i < N; ++i) // create and start threads
            new Thread(new Worker(startSignal, doneSignal)).start();

        System.out.println("doSomethingElse");          // don't let run yet
        startSignal.countDown();                        // let all threads proceed
        System.out.println("doSomethingElse");
        doneSignal.await();                             // wait for all to finish
        System.out.println("all worker completed");
    }
}

class Worker implements Runnable {
    private final CountDownLatch startSignal;
    private final CountDownLatch doneSignal;
    Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
        this.startSignal = startSignal;
        this.doneSignal = doneSignal;
    }
    public void run() {
        try {
            startSignal.await();
            doWork();
            doneSignal.countDown();
        } catch (InterruptedException ex) {} // return;
    }

    void doWork() {
        System.out.println("doWork");
    }
}

class Driver2 { // ...
    public static void main(String[] args) throws InterruptedException {
        int N = 3;
        CountDownLatch doneSignal = new CountDownLatch(N);
        Executor e = Executors.newFixedThreadPool(N);
		
		// i 代表是问题的第几部分
        for (int i = 0; i < N; ++i) // create and start threads
            e.execute(new WorkerRunnable(doneSignal, i));

        doneSignal.await();           // wait for all to finish
        System.out.println("all task completed");
    }
}

class WorkerRunnable implements Runnable {
    private final CountDownLatch doneSignal;
    private final int i;
    WorkerRunnable(CountDownLatch doneSignal, int i) {
        this.doneSignal = doneSignal;
        this.i = i;
    }
    public void run() {
        doWork(i);
        doneSignal.countDown();
    }

    void doWork(int i) {
        System.out.println("task " + i);
    }
}

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐