一、引言

之前写过一篇博客:基于trace_id的链路追踪(含Feign、Hystrix、线程池等场景),主要介绍在微服务体系架构中,如何实现分布式系统的链路追踪的博客,其中主要实现了以下几种场景:

  1. Filter实现trace_id拦截
  2. RestTemplate的链路追踪
  3. Feign的链路追踪
  4. Hystrix的链路追踪
  5. Dubbo的链路追踪
  6. Spring异步线程池的链路追踪

其中,还缺失了一种较为常见的场景,那就是Java中常用的线程池实现:ForkJoinPool

尤其Java 8提供的 Stream并行流 采用了 ForkJoinPool 作为默认实现,当我们基于并行流做一些业务操作时,日志的链路追踪往往很容易在这里出现断层的情况。

本文将探讨如何基于trace_id实现ForkJoinPool的链路追踪,以提升系统的可追溯性。

二、ForkJoinPool简介

ForkJoinPool是Java提供的一种线程池实现,特别适用于处理递归分解的任务。它采用了工作窃取(Work-Stealing)算法,通过将任务分解为更小的子任务并将其分配给空闲线程执行,从而实现了任务的并行执行。

三、基于trace_id的链路追踪设计

为了实现基于trace_id的链路追踪,我们可以通过以下步骤进行设计:

  • 为每个请求生成唯一的trace_id,并将其传递给ForkJoinPool中的任务。
  • 在任务开始和结束时,记录相关的trace_id信息。
  • 在任务执行过程中,将trace_id传递给子任务。
  • 使用日志或专门的链路追踪工具,收集和分析trace_id信息,构建请求的链路图。

四、代码实现

1、自定义线程池:MdcForkJoinPool

MdcForkJoinPool

package com.github.jesse.l2cache.util.pool;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;

/**
 * 自定义 {@link ForkJoinPool},扩展MDC内容,以便链路追踪
 *
 * @author chenck
 * @date 2021/5/11 14:48
 */
public class MdcForkJoinPool extends ForkJoinPool {

    /**
     * max #workers - 1
     */
    public static final int MAX_CAP = 0x7fff;

    /**
     * the default parallelism level
     */
    public static final int DEFAULT_PARALLELISM = Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors());

    /**
     * the default thread name prefix
     */
    public static final String DEFAULT_THREAD_NAME_PREFIX = "MdcForkJoinPool";

    /**
     * Sequence number for creating workerNamePrefix.
     */
    private static int poolNumberSequence;

    /**
     * Returns the next sequence number. We don't expect this to
     * ever contend, so use simple builtin sync.
     */
    private static final synchronized int nextPoolId() {
        return ++poolNumberSequence;
    }

    /**
     * Common (static) pool.
     */
    static final MdcForkJoinPool mdcCommon = new MdcForkJoinPool();

    public static MdcForkJoinPool mdcCommonPool() {
        return mdcCommon;
    }

    // constructor

    public MdcForkJoinPool() {
        this(DEFAULT_PARALLELISM, DEFAULT_THREAD_NAME_PREFIX);
    }

    public MdcForkJoinPool(int parallelism) {
        this(parallelism, DEFAULT_THREAD_NAME_PREFIX);
    }

    public MdcForkJoinPool(String threadNamePrefix) {
        this(DEFAULT_PARALLELISM, threadNamePrefix);
    }

    public MdcForkJoinPool(int parallelism, String threadNamePrefix) {
        this(parallelism, new LimitedThreadForkJoinWorkerThreadFactory(parallelism, threadNamePrefix + "-" + nextPoolId()), null, false);
    }

    /**
     * Creates a new MdcForkJoinPool.
     *
     * @param parallelism the parallelism level. For default value, use {@link java.lang.Runtime#availableProcessors}.
     * @param factory     the factory for creating new threads. For default value, use
     *                    {@link #defaultForkJoinWorkerThreadFactory}.
     * @param handler     the handler for internal worker threads that terminate due to unrecoverable errors encountered
     *                    while executing tasks. For default value, use {@code null}.
     * @param asyncMode   if true, establishes local first-in-first-out scheduling mode for forked tasks that are never
     *                    joined. This mode may be more appropriate than default locally stack-based mode in applications
     *                    in which worker threads only process event-style asynchronous tasks. For default value, use
     *                    {@code false}.
     */
    public MdcForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, Thread.UncaughtExceptionHandler handler, boolean asyncMode) {
        super(parallelism, factory, handler, asyncMode);
    }

    // Execution methods

    @Override
    public <T> T invoke(ForkJoinTask<T> task) {
        if (task == null) {
            throw new NullPointerException();
        }
        return super.invoke(new ForkJoinTaskMdcWrapper<T>(task));
    }

    @Override
    public void execute(ForkJoinTask<?> task) {
        if (task == null) {
            throw new NullPointerException();
        }
        super.execute(new ForkJoinTaskMdcWrapper<>(task));
    }

    // AbstractExecutorService methods

    @Override
    public void execute(Runnable task) {
        if (task == null) {
            throw new NullPointerException();
        }
        super.execute(new RunnableMdcWarpper(task));
    }

    @Override
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null) {
            throw new NullPointerException();
        }
        return super.submit(new ForkJoinTaskMdcWrapper<T>(task));
    }

    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        if (task == null) {
            throw new NullPointerException();
        }
        return super.submit(new CallableMdcWrapper(task));
    }

    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        if (task == null) {
            throw new NullPointerException();
        }
        return super.submit(new RunnableMdcWarpper(task), result);
    }

    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        if (task == null) {
            throw new NullPointerException();
        }
        return super.submit(new RunnableMdcWarpper(task));
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) {
        if (tasks == null) {
            throw new NullPointerException();
        }
        Collection<Callable<T>> wrapperTasks = new ArrayList<>();
        for (Callable<T> task : tasks) {
            wrapperTasks.add(new CallableMdcWrapper(task));
        }

        return super.invokeAll(wrapperTasks);
    }

}

2、自定义包装类:透传trace_id

CallableMdcWrapper

package com.github.jesse.l2cache.util.pool;

import org.slf4j.MDC;
import java.util.Map;
import java.util.concurrent.Callable;

/**
 * @author chenck
 * @date 2021/5/11 17:09
 */
public class CallableMdcWrapper<T> implements Callable<T> {

    private static final long serialVersionUID = 1L;

    Callable<T> callable;
    Map<String, String> contextMap;

    public CallableMdcWrapper(Callable<T> callable) {
        this.callable = callable;
        this.contextMap = MDC.getCopyOfContextMap();
    }

    @Override
    public T call() throws Exception {
        Map<String, String> oldContext = MdcUtil.beforeExecution(contextMap);
        try {
            return callable.call();
        } finally {
            MdcUtil.afterExecution(oldContext);
        }
    }
}

RunnableMdcWarpper

package com.github.jesse.l2cache.util.pool;

import org.slf4j.MDC;
import java.util.Map;

/**
 * Runnable 包装 MDC
 *
 * @author chenck
 * @date 2020/9/23 19:37
 */
public class RunnableMdcWarpper implements Runnable {

    private static final long serialVersionUID = 1L;

    Runnable runnable;
    Map<String, String> contextMap;
    Object param;

    public RunnableMdcWarpper(Runnable runnable) {
        this.runnable = runnable;
        this.contextMap = MDC.getCopyOfContextMap();
    }

    public RunnableMdcWarpper(Runnable runnable, Object param) {
        this.runnable = runnable;
        this.contextMap = MDC.getCopyOfContextMap();
        this.param = param;
    }

    @Override
    public void run() {
        Map<String, String> oldContext = MdcUtil.beforeExecution(contextMap);
        try {
            runnable.run();
        } finally {
            MdcUtil.afterExecution(oldContext);
        }
    }

    public Object getParam() {
        return param;
    }
}

ForkJoinTaskMdcWrapper

package com.github.jesse.l2cache.util.pool;

import org.slf4j.MDC;
import java.util.Map;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;

/**
 * @author chenck
 * @date 2021/5/11 16:56
 * @see https://stackoverflow.com/questions/36026402/how-to-use-mdc-with-forkjoinpool
 */
public class ForkJoinTaskMdcWrapper<T> extends ForkJoinTask<T> {

    private static final long serialVersionUID = 1L;

    /**
     * If non-null, overrides the value returned by the underlying task.
     */
    private final AtomicReference<T> override = new AtomicReference<>();

    private ForkJoinTask<T> task;
    private Map<String, String> newContext;

    public ForkJoinTaskMdcWrapper(ForkJoinTask<T> task) {
        this.task = task;
        this.newContext = MDC.getCopyOfContextMap();
    }

    @Override
    public T getRawResult() {
        T result = override.get();
        if (result != null) {
            return result;
        }
        return task.getRawResult();
    }

    @Override
    protected void setRawResult(T value) {
        override.set(value);
    }

    @Override
    protected boolean exec() {
        Map<String, String> oldContext = MdcUtil.beforeExecution(newContext);
        try {
            task.invoke();
            return true;
        } finally {
            MdcUtil.afterExecution(oldContext);
        }
    }
}

3、自定义线程工厂:自定义线程名称前缀+管理阻塞时限制最大线程数

LimitedThreadForkJoinWorkerThread

package com.github.jesse.l2cache.util.pool;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;

/**
 * 自定义ForkJoinWorkerThread,用于限制ForkJoinPool中创建的最大线程数
 *
 * @author chenck
 * @date 2023/5/6 13:49
 */
public class LimitedThreadForkJoinWorkerThread extends ForkJoinWorkerThread {
    protected LimitedThreadForkJoinWorkerThread(ForkJoinPool pool) {
        super(pool);
        setPriority(Thread.NORM_PRIORITY); // 设置线程优先级
        setDaemon(false); // 设置是否为守护线程
    }

    protected LimitedThreadForkJoinWorkerThread(ForkJoinPool pool, String threadName) {
        super(pool);
        setPriority(Thread.NORM_PRIORITY); // 设置线程优先级
        setDaemon(false); // 设置是否为守护线程
        setName(threadName);
    }
}

LimitedThreadForkJoinWorkerThreadFactory

package com.github.jesse.l2cache.util.pool;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * 自定义ForkJoinWorkerThreadFactory,用于限制ForkJoinPool中创建的最大线程数,并复用当前的ForkJoinPool的线程
 *
 * @author chenck
 * @date 2023/5/6 13:48
 */
public class LimitedThreadForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {

    protected static Logger logger = LoggerFactory.getLogger(LimitedThreadForkJoinWorkerThreadFactory.class);

    /**
     * 最大线程数
     */
    private final int maxThreads;

    /**
     * 线程名称前缀
     */
    private String threadNamePrefix;

    /**
     * 当前线程数
     */
    private final AtomicInteger threadCount = new AtomicInteger(0);

    public LimitedThreadForkJoinWorkerThreadFactory(int maxThreads) {
        this.maxThreads = maxThreads;
    }

    public LimitedThreadForkJoinWorkerThreadFactory(int maxThreads, String threadNamePrefix) {
        this.maxThreads = maxThreads;
        this.threadNamePrefix = threadNamePrefix;
    }

    /**
     * 限制了线程数量并复用当前的ForkJoinPool的线程
     */
    @Override
    public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
        int count = threadCount.incrementAndGet();

        // 如果当前线程数量小于等于最大线程数,则创建新线程,并将threadCount+1
        if (count <= maxThreads) {
            if (null == threadNamePrefix || "".equals(threadNamePrefix.trim())) {
                return new LimitedThreadForkJoinWorkerThread(pool);
            } else {
                // 使用自定义线程名称
                return new LimitedThreadForkJoinWorkerThread(pool, threadNamePrefix + "-worker-" + count);
            }
        }

        // 如果当前线程数量超过最大线程数,则不创建新线程,并将threadCount-1
        threadCount.decrementAndGet();
        if (logger.isDebugEnabled()) {
            logger.debug("Exceeded maximum number of threads");
        }
        return null;// 不创建新线程
    }

}

4、工具类

MyManagedBlocker

package com.github.jesse.l2cache.util.pool;

import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;

/**
 * Java 8中的默认并行流使用公共ForkJoinPool,如果提交任务时公共池线程耗尽,会导致任务延迟执行。
 * <p>
 * CPU密集型:如果在ForkJoinPool中填充的任务,执行时间足够短,且CPU的可用能力足够,那么将不会出现上述延迟的问题。(ForkJoinPool的大多数使用场景)
 * I/O密集型:如果在ForkJoinPool中填充的任务,执行时间足够长,且是不受CPU限制的I/O任务,那么任务将延迟执行,并出现瓶颈。
 * 小结:ForkJoinPool 最适合的是CPU密集型的任务,如果存在 I/O,线程间同步,sleep() 等会造成线程长时间阻塞的情况时,最好配合使用 ManagedBlocker。
 * <p>
 * 对I/O阻塞型任务提供一个ManagedBlocker,让ForkJoinPool知道当前任务即将阻塞,因此需要创建新的`备用线程`来执行新提交的任务.
 * <p>
 * 【问题】通过ManagedBlocker来管理阻塞时,最大正在运行的线程数限制为32767,如果不限制新创建的线程数量,可能导致oom。如何控制ForkJoinPool中新创建的最大备用线程数?
 * 【分析】
 * 1、ForkJoinPool.common.commonMaxSpares 表示 tryCompensate 中`备用线程`创建的限制,默认为256
 * 2、上面这个参数,只能针对commonPool进行限制,并且tryCompensate方法不一定能会命中该限制,若未命中该限制,则可能无限制的创建`备用线程`来避免阻塞,最终还是可能出现oom
 * 3、ManagedBlocker将最大正在运行的线程数限制为32767.尝试创建大于最大数目的池导致IllegalArgumentException,只有当池被关闭或内部资源耗尽时,此实现才会拒绝提交的任务(即通过抛出RejectedExecutionException )。
 * 【方案】
 * 在管理阻塞时,通过自定义 {@LimitedThreadForkJoinWorkerThreadFactory} 来限制ForkJoinPool最大可创建的线程数,并复用当前的ForkJoinPool的线程,以此来避免无限制的创建`备用线程`
 *
 * @author chenck
 * @date 2023/5/5 18:30
 */
public class MyManagedBlocker implements ForkJoinPool.ManagedBlocker {
    private Function function;
    private Object key;
    private Object result;
    private boolean done = false;

    public MyManagedBlocker(Object key, Function function) {
        this.key = key;
        this.function = function;
    }


    @Override
    public boolean block() throws InterruptedException {
        result = function.apply(key);
        done = true;
        return false;
    }

    @Override
    public boolean isReleasable() {
        return done;
    }

    public Object getResult() {
        return result;
    }

}

MdcUtil

package com.github.jesse.l2cache.util.pool;

import org.slf4j.MDC;
import java.util.Map;

/**
 * @author chenck
 * @date 2021/5/11 17:00
 */
public class MdcUtil {

    /**
     * Invoked before running a task.
     *
     * @param newMdcContext the new MDC context
     * @return the old MDC context
     */
    public static Map<String, String> beforeExecution(Map<String, String> newMdcContext) {
        Map<String, String> oldMdcContext = MDC.getCopyOfContextMap();
        if (newMdcContext == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(newMdcContext);
        }
        return oldMdcContext;
    }

    /**
     * Invoked after running a task.
     *
     * @param oldMdcContext the old MDC context
     */
    public static void afterExecution(Map<String, String> oldMdcContext) {
        if (oldMdcContext == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(oldMdcContext);
        }
    }
}

五、小结

基于trace_id的链路追踪是提升分布式系统可追溯性的关键技术之一。

通过在任务中传递和记录trace_id信息,并结合日志和监控系统,开发人员可以更好地了解请求的流转路径和系统性能状况,从而快速定位和解决问题。

在实际应用中,需要根据具体的业务场景和性能要求,灵活选择追踪策略和工具,以实现最佳的性能和可追溯性的平衡。

参考文献:

  • Oracle官方文档:https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/ForkJoinPool.html
  • OpenTracing官方文档:https://opentracing.io/
Logo

一起探索未来云端世界的核心,云原生技术专区带您领略创新、高效和可扩展的云计算解决方案,引领您在数字化时代的成功之路。

更多推荐