一、AI Agent的异步挑战

一个典型的AI Agent请求会触发多个异步操作:

用户请求
  +--> LLM推理(3秒)
  +--> 向量检索(200ms)
  +--> 数据库查询(100ms)
  +--> 外部API调用(1秒)
        |
     结果聚合

如果用同步阻塞调用,总耗时 = 3s + 200ms + 100ms + 1s = 4.3s。
如果用CompletableFuture并行,总耗时 ≈ max(3s, 200ms, 100ms, 1s) = 3s。这里的前提是四个操作互不依赖。如果LLM推理需要以检索结果作为输入(如后文的示例流水线),这些步骤就只能串行,并行只适用于相互独立的分支。

本文用Java的CompletableFuture构建一个类型安全的异步AI工作流编排框架。


二、核心抽象

package com.demo.workflow.core;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * 工作流阶段:接收输入,异步产生输出
 *
 * @param <I> 输入类型
 * @param <O> 输出类型
 */
@FunctionalInterface
public interface WorkflowStage<I, O> {
    CompletableFuture<O> execute(I input);

    /**
     * 链式组合:当前阶段 -> 下一阶段
     */
    default <T> WorkflowStage<I, T> then(WorkflowStage<O, T> next) {
        return input -> execute(input).thenCompose(next::execute);
    }

    /**
     * 并行组合:当前阶段与另一个阶段同时执行,合并结果
     */
    default <T, R> WorkflowStage<I, R> parallel(
            WorkflowStage<I, T> other,
            BiFunction<O, T, R> combiner) {
        return input -> {
            CompletableFuture<O> f1 = execute(input);
            CompletableFuture<T> f2 = other.execute(input);
            return f1.thenCombine(f2, combiner);
        };
    }

    /**
     * 错误恢复:当当前阶段失败时,用fallback处理
     */
    default WorkflowStage<I, O> onError(Function<Throwable, O> fallback) {
        return input -> execute(input).exceptionally(fallback);
    }

    /**
     * 静态工厂:从同步函数创建Stage
     */
    static <I, O> WorkflowStage<I, O> fromSync(Function<I, O> syncFn) {
        return input -> CompletableFuture.supplyAsync(() -> syncFn.apply(input));
    }
}

三、工作流编排器

package com.demo.workflow.core;

import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Consumer;

/**
 * Workflow orchestrator: runs one composed {@link WorkflowStage} per {@link #execute} call.
 * It enforces a timeout, can restore a result from a {@link CheckpointManager}, and notifies
 * registered event listeners.
 */
public class WorkflowOrchestrator<I, O> {

    private final WorkflowStage<I, O> stage;
    // Executor supplied via Builder#executor; may be null. NOTE(review): execute() does not hand it
    // to the stage (stages schedule their own work), so it currently has no effect.
    private final Executor executor;
    private final Duration defaultTimeout;
    private final List<Consumer<StageEvent>> eventListeners;
    // Optional; when null, no checkpoint is consulted.
    private final CheckpointManager checkpointManager;

    private WorkflowOrchestrator(Builder<I, O> builder) {
        this.stage = builder.stage;
        // Deliberately no default: an executor created here would never be used or shut down.
        this.executor = builder.executor;
        this.defaultTimeout = builder.timeout;
        this.eventListeners = List.copyOf(builder.eventListeners);
        this.checkpointManager = builder.checkpointManager;
    }

    /**
     * Runs the workflow for {@code input}. If the checkpoint manager holds a value for
     * {@code executionId}, that value is returned instead of running the stage.
     *
     * <p>Stage failures are not thrown. They are reported through {@link WorkflowResult#status()}
     * as {@code "SUCCESS"}, {@code "TIMEOUT"} or {@code "FAILED"}, and a matching event is emitted.
     * A timeout stops waiting, but it does not cancel work already running inside the stage.
     *
     * <p>NOTE(review): this method only reads checkpoints; nothing here saves one. A restored value
     * is not type-checked against {@code O}.
     *
     * @param input       workflow input
     * @param executionId identifier used for checkpoint lookup and in emitted events
     * @return the outcome, with {@code data} set only on success
     */
    public WorkflowResult<O> execute(I input, String executionId) {
        Instant startTime = Instant.now();
        // NOTE(review): nothing appends to this list yet, so results carry an empty trace list.
        List<StageTrace> traces = new ArrayList<>();

        try {
            O result = startOrRestore(input, executionId)
                .orTimeout(defaultTimeout.toMillis(), TimeUnit.MILLISECONDS)
                .get();

            WorkflowResult<O> workflowResult = new WorkflowResult<>(
                executionId, "SUCCESS", result, traces,
                Duration.between(startTime, Instant.now()));
            emitEvent(new StageEvent(executionId, "workflow", "COMPLETED", null,
                workflowResult.duration().toMillis()));
            return workflowResult;

        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller's thread.
            Thread.currentThread().interrupt();
            return failure(executionId, "FAILED", e, traces, startTime);

        } catch (ExecutionException e) {
            // orTimeout() fails the future with a TimeoutException, which get() delivers wrapped
            // in an ExecutionException; get() itself never throws TimeoutException here.
            Throwable cause = unwrap(e);
            String status = cause instanceof TimeoutException ? "TIMEOUT" : "FAILED";
            return failure(executionId, status, cause, traces, startTime);

        } catch (Exception e) {
            // Synchronous failures: checkpoint lookup or stage.execute threw before a future existed.
            return failure(executionId, "FAILED", e, traces, startTime);
        }
    }

    /** Returns a completed future from the checkpoint if one exists, otherwise starts the stage. */
    private CompletableFuture<O> startOrRestore(I input, String executionId) {
        O restored = checkpointManager != null
            ? checkpointManager.load(executionId) : null;
        if (restored != null) {
            emitEvent(new StageEvent(executionId, "checkpoint", "RESTORED", null, null));
            return CompletableFuture.completedFuture(restored);
        }
        return stage.execute(input);
    }

    /** Emits a {@code status} event and builds the matching data-less result. */
    private WorkflowResult<O> failure(String executionId, String status, Throwable cause,
                                      List<StageTrace> traces, Instant startTime) {
        emitEvent(new StageEvent(executionId, "workflow", status, describe(cause), null));
        return new WorkflowResult<>(executionId, status, null, traces,
            Duration.between(startTime, Instant.now()));
    }

    /** Strips ExecutionException/CompletionException wrappers to reach the real failure. */
    private static Throwable unwrap(Throwable t) {
        Throwable cur = t;
        while ((cur instanceof ExecutionException || cur instanceof CompletionException)
                && cur.getCause() != null) {
            cur = cur.getCause();
        }
        return cur;
    }

    /** Message of {@code t}, falling back to its class name (e.g. TimeoutException has none). */
    private static String describe(Throwable t) {
        String msg = t.getMessage();
        return msg != null ? msg : t.getClass().getName();
    }

    private void emitEvent(StageEvent event) {
        for (Consumer<StageEvent> listener : eventListeners) {
            try {
                listener.accept(event);
            } catch (Exception ignored) {
                // A failing listener must not affect the workflow itself.
            }
        }
    }

    // ========== Builder ==========

    /**
     * Starts building an orchestrator for {@code stage}.
     *
     * @throws NullPointerException if {@code stage} is null
     */
    public static <I, O> Builder<I, O> builder(WorkflowStage<I, O> stage) {
        return new Builder<>(stage);
    }

    /** Builder for {@link WorkflowOrchestrator}; the timeout defaults to 30 seconds. */
    public static class Builder<I, O> {
        private final WorkflowStage<I, O> stage;
        private Executor executor;
        private Duration timeout = Duration.ofSeconds(30);
        private final List<Consumer<StageEvent>> eventListeners = new ArrayList<>();
        private CheckpointManager checkpointManager;

        private Builder(WorkflowStage<I, O> stage) {
            this.stage = Objects.requireNonNull(stage, "stage");
        }

        /** Stores an executor. NOTE(review): currently not used by {@code execute}. */
        public Builder<I, O> executor(Executor executor) {
            this.executor = executor;
            return this;
        }

        /** Overall time limit for a single {@code execute} call. */
        public Builder<I, O> timeout(Duration timeout) {
            this.timeout = Objects.requireNonNull(timeout, "timeout");
            return this;
        }

        /** Registers a listener; exceptions it throws are ignored by the orchestrator. */
        public Builder<I, O> onEvent(Consumer<StageEvent> listener) {
            this.eventListeners.add(Objects.requireNonNull(listener, "listener"));
            return this;
        }

        /** Enables checkpoint lookup before each execution. */
        public Builder<I, O> withCheckpoints(CheckpointManager cm) {
            this.checkpointManager = cm;
            return this;
        }

        public WorkflowOrchestrator<I, O> build() {
            return new WorkflowOrchestrator<>(this);
        }
    }

    // ========== Result types ==========

    /** Outcome of one execution; {@code data} is null unless {@code status} is "SUCCESS". */
    public record WorkflowResult<T>(
        String executionId,
        String status,
        T data,
        List<StageTrace> traces,
        Duration duration
    ) {}

    public record StageTrace(String stageName, Duration duration, String status, String error) {}
    public record StageEvent(String executionId, String stage, String event, String error, Long durationMs) {}
}

四、检查点管理器——支持失败恢复

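AI Agent调用LLM是昂贵的操作(时间和token成本),如果工作流在步骤3失败,应该能从步骤3重试而不重跑步骤1和2。需要注意以下几点:
- 下面的检查点管理器按executionId只保存一个值。
- 上文的编排器只在执行前读取检查点,从不写入。
- 从文件读回的检查点会被反序列化为通用的Map/List结构,而不是原始类型。

要真正做到按阶段恢复,需要在每个阶段完成后分别保存其输出(例如以"executionId:阶段名"为键)。实现如下: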

package com.demo.workflow.core;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.*;
import java.nio.file.*;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * File-backed checkpoint manager. Each checkpoint is kept in an in-memory cache and persisted as
 * {@code <checkpointDir>/<executionId>.json} using Jackson.
 *
 * <p>NOTE(review): {@link #load} reads files back as {@code Object}. A checkpoint that comes from
 * disk (e.g. after a restart) is therefore Jackson's generic Map/List/String/Number form, not the
 * type that was saved; only cache hits return the original object.
 */
public class FileCheckpointManager implements CheckpointManager {

    private final Path checkpointDir;
    private final ObjectMapper objectMapper;
    private final Map<String, Object> cache;

    /**
     * Creates the manager and creates the checkpoint directory if it is missing.
     *
     * @param checkpointDir directory holding the {@code .json} checkpoint files
     * @throws RuntimeException if the directory cannot be created
     */
    public FileCheckpointManager(String checkpointDir) {
        // Absolute + normalized so checkpointFile() can compare parent directories reliably.
        this.checkpointDir = Path.of(checkpointDir).toAbsolutePath().normalize();
        this.objectMapper = new ObjectMapper();
        this.cache = new ConcurrentHashMap<>();
        try {
            Files.createDirectories(this.checkpointDir);
        } catch (IOException e) {
            throw new RuntimeException("无法创建检查点目录", e);
        }
    }

    /**
     * Caches {@code data} and writes it to disk.
     *
     * <p>Disk failures are reported on stderr and otherwise ignored; the cached copy remains.
     * A {@code null} value removes any existing checkpoint, because {@code null} is what
     * {@link #load} returns for "no checkpoint".
     *
     * @throws IllegalArgumentException if {@code executionId} would resolve outside the directory
     */
    @Override
    public <T> void save(String executionId, T data) {
        Path file = checkpointFile(executionId);
        if (data == null) {
            delete(executionId);
            return;
        }
        cache.put(executionId, data);
        Path tmp = null;
        try {
            // Write a temp file in the same directory, then move it into place, so a crash
            // mid-write never leaves a truncated checkpoint under the real name.
            tmp = Files.createTempFile(checkpointDir, "checkpoint-", ".tmp");
            objectMapper.writeValue(tmp.toFile(), data);
            replaceWith(tmp, file);
        } catch (IOException e) {
            System.err.println("检查点保存失败: " + e.getMessage());
        } finally {
            if (tmp != null) {
                deleteQuietly(tmp); // no-op after a successful move
            }
        }
    }

    /**
     * Returns the cached checkpoint if there is one, otherwise reads it from disk.
     *
     * @return the checkpoint, or {@code null} if none exists or the file cannot be read
     * @throws IllegalArgumentException if {@code executionId} would resolve outside the directory
     */
    @Override
    @SuppressWarnings("unchecked") // T is chosen by the caller and cannot be verified (see class note)
    public <T> T load(String executionId) {
        Path file = checkpointFile(executionId);
        Object cached = cache.get(executionId);
        if (cached != null) {
            return (T) cached;
        }
        if (!Files.exists(file)) {
            return null;
        }
        try {
            String content = Files.readString(file);
            return (T) objectMapper.readValue(content, Object.class);
        } catch (IOException e) {
            System.err.println("检查点读取失败: " + e.getMessage());
            return null;
        }
    }

    /** Removes the checkpoint from the cache and from disk; disk failures go to stderr. */
    @Override
    public void delete(String executionId) {
        Path file = checkpointFile(executionId);
        cache.remove(executionId);
        try {
            Files.deleteIfExists(file);
        } catch (IOException e) {
            System.err.println("检查点删除失败: " + e.getMessage());
        }
    }

    /**
     * Resolves the checkpoint file for {@code executionId}. Ids that would escape or nest below
     * the checkpoint directory (e.g. "../x", "a/b", absolute paths) are rejected.
     */
    private Path checkpointFile(String executionId) {
        if (executionId == null) {
            throw new NullPointerException("executionId");
        }
        Path file = checkpointDir.resolve(executionId + ".json").normalize();
        if (!checkpointDir.equals(file.getParent())) {
            throw new IllegalArgumentException("非法的executionId: " + executionId);
        }
        return file;
    }

    /** Moves {@code source} over {@code target}, atomically when the file system supports it. */
    private static void replaceWith(Path source, Path target) throws IOException {
        try {
            Files.move(source, target,
                StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.ATOMIC_MOVE);
        } catch (AtomicMoveNotSupportedException e) {
            Files.move(source, target, StandardCopyOption.REPLACE_EXISTING);
        }
    }

    /** Best-effort delete used for temp-file cleanup. */
    private static void deleteQuietly(Path path) {
        try {
            Files.deleteIfExists(path);
        } catch (IOException ignored) {
            // A leftover temp file is harmless; it never shadows a real checkpoint name.
        }
    }
}

/**
 * Storage contract for workflow checkpoints, keyed by execution id.
 */
interface CheckpointManager {

    /**
     * Looks up the checkpoint for an execution.
     *
     * @param executionId the execution the checkpoint belongs to
     * @param <T>         expected value type, chosen by the caller (not verifiable due to erasure)
     * @return the stored value, or {@code null} when there is no checkpoint
     */
    <T> T load(String executionId);

    /**
     * Stores {@code data} as the checkpoint of {@code executionId}.
     *
     * @param executionId the execution the checkpoint belongs to
     * @param data        value to store
     */
    <T> void save(String executionId, T data);

    /**
     * Discards the checkpoint of {@code executionId}, if any.
     *
     * @param executionId the execution whose checkpoint is removed
     */
    void delete(String executionId);
}

五、实战:多模型评估流水线

构建一个并行调用3个LLM模型、汇总评分的Agent评估流水线:

package com.demo.workflow.example;

import com.demo.workflow.core.*;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.*;

public class ModelEvaluationPipeline {

    /**
     * Demo entry point. It builds the retrieval -> parallel model evaluation -> scoring pipeline,
     * runs it once under a {@link WorkflowOrchestrator}, and prints the report.
     */
    public static void main(String[] args) {
        // The simulated model calls block (Thread.sleep), so they run on a dedicated executor
        // instead of supplyAsync's default common pool. The common pool's size is tied to the
        // CPU count, so on small machines it can serialize the three calls. Closing the executor
        // waits for any calls still in flight (e.g. after a timeout).
        try (ExecutorService modelPool = Executors.newVirtualThreadPerTaskExecutor()) {
            runOnce(modelPool);
        }
    }

    /** Configures the orchestrator, executes request "eval-001" and prints the result. */
    private static void runOnce(Executor modelPool) {
        WorkflowOrchestrator<EvalRequest, EvaluationReport> orchestrator =
            WorkflowOrchestrator.<EvalRequest, EvaluationReport>builder(buildPipeline(modelPool))
                .timeout(Duration.ofSeconds(60))
                .withCheckpoints(
                    new FileCheckpointManager("/tmp/workflow-checkpoints"))
                .onEvent(event -> System.out.printf(
                    "[事件] %s/%s: %s (耗时%dms)\n",
                    event.executionId(), event.stage(),
                    event.event(), event.durationMs()))
                .build();

        EvalRequest request = new EvalRequest(
            "Spring Boot如何实现自动配置?", null);
        WorkflowOrchestrator.WorkflowResult<EvaluationReport> result =
            orchestrator.execute(request, "eval-001");

        printReport(result);
    }

    /** Composes retrieval -> parallel model evaluation (on {@code modelPool}) -> scoring. */
    private static WorkflowStage<EvalRequest, EvaluationReport> buildPipeline(Executor modelPool) {
        // Stage 1: context retrieval (simulated).
        WorkflowStage<EvalRequest, EvalContext> retrievalStage =
            WorkflowStage.fromSync(req -> {
                System.out.println("[检索] 查询相关文档...");
                return new EvalContext(req.question(),
                    List.of("文档A: Spring Boot自动配置原理",
                           "文档B: 条件注解@Conditional详解"));
            });

        // Stage 2: evaluate all models in parallel.
        WorkflowStage<EvalContext, Map<String, ModelResult>> evalStage =
            ctx -> evaluateAll(ctx, modelPool);

        // Stage 3: score and rank.
        WorkflowStage<Map<String, ModelResult>, EvaluationReport> scoringStage =
            WorkflowStage.fromSync(ModelEvaluationPipeline::score);

        return retrievalStage
            .then(evalStage)
            .then(scoringStage);
    }

    /** Calls every model concurrently on {@code pool}; results keep the request order. */
    private static CompletableFuture<Map<String, ModelResult>> evaluateAll(
            EvalContext ctx, Executor pool) {
        List<ModelRequest> reqs = List.of(
            new ModelRequest(ctx.question(), "gpt-4o", ctx.documents()),
            new ModelRequest(ctx.question(), "claude-3.5-sonnet", ctx.documents()),
            new ModelRequest(ctx.question(), "deepseek-v3", ctx.documents())
        );

        List<CompletableFuture<ModelResult>> futures = reqs.stream()
            .map(r -> CompletableFuture.supplyAsync(() -> callModel(r), pool))
            .toList();

        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
            .thenApply(ignored -> {
                // All futures are complete here, so join() does not block.
                Map<String, ModelResult> results = new LinkedHashMap<>();
                for (int i = 0; i < reqs.size(); i++) {
                    results.put(reqs.get(i).model(), futures.get(i).join());
                }
                return results;
            });
    }

    /** Computes a weighted score per model and sorts descending by score. */
    private static EvaluationReport score(Map<String, ModelResult> results) {
        System.out.println("[评分] 计算各模型得分...");
        List<ModelScore> scores = new ArrayList<>();
        for (var entry : results.entrySet()) {
            ModelResult mr = entry.getValue();
            // NOTE(review): the terms are not normalized. With callModel's simulated ranges
            // (cost 0.001-0.01, latency >= 500ms), the cost term (~300-3000) dwarfs quality
            // (<0.4) and latency (<=0.06), so the ranking is effectively by cost alone.
            double total = mr.quality() * 0.4
                         + (1.0 / Math.max(mr.latencyMs(), 1)) * 100 * 0.3
                         + (1.0 / Math.max(mr.cost(), 0.001)) * 10 * 0.3;
            scores.add(new ModelScore(entry.getKey(), total, mr));
        }
        scores.sort(Comparator.comparingDouble(ModelScore::score).reversed());
        return new EvaluationReport(scores);
    }

    /** Prints status, duration and, on success, the per-model scores. */
    private static void printReport(WorkflowOrchestrator.WorkflowResult<EvaluationReport> result) {
        System.out.println("\n=== 评估报告 ===");
        System.out.println("状态: " + result.status());
        System.out.println("耗时: " + result.duration().toMillis() + "ms");
        if (result.data() != null) {
            for (ModelScore score : result.data().scores()) {
                System.out.printf("  %s: %.2f分 (质量:%.1f, 延迟:%dms, 成本:%.4f)\n",
                    score.model(), score.score(),
                    score.result().quality(),
                    score.result().latencyMs(),
                    score.result().cost());
            }
        }
    }

    // ===== Simulation =====

    /** Simulates a model call: sleeps 500-2000ms, then returns random quality/cost. */
    static ModelResult callModel(ModelRequest req) {
        long start = System.currentTimeMillis();
        try {
            Thread.sleep(ThreadLocalRandom.current().nextLong(500, 2000));
        } catch (InterruptedException e) {
            // Restore the flag so the owning executor/caller can observe the interruption.
            Thread.currentThread().interrupt();
        }
        long latency = System.currentTimeMillis() - start;

        return new ModelResult(
            ThreadLocalRandom.current().nextDouble(0.5, 1.0),
            latency,
            ThreadLocalRandom.current().nextDouble(0.001, 0.01)
        );
    }

    // ===== Data types =====

    record EvalRequest(String question, List<String> documents) {}
    record EvalContext(String question, List<String> documents) {}
    record ModelRequest(String question, String model, List<String> documents) {}
    record ModelResult(double quality, long latencyMs, double cost) {}
    record ModelScore(String model, double score, ModelResult result) {}
    record EvaluationReport(List<ModelScore> scores) {}
}

六、生产环境增强

6.1 Micrometer指标集成

import io.micrometer.core.instrument.*;

/**
 * Records workflow executions into a Micrometer {@link MeterRegistry}.
 */
public class WorkflowMetrics {
    private final MeterRegistry registry;

    /**
     * @param registry registry that receives the timer and counter
     */
    public WorkflowMetrics(MeterRegistry registry) {
        this.registry = registry;
    }

    /**
     * Records one finished execution as a timer sample and a counter increment. Both meters are
     * tagged with {@code workflow} and {@code status}.
     *
     * <p>NOTE(review): a Micrometer Timer already tracks a count, so "workflow.execution.count"
     * duplicates it; it is kept so existing dashboards keep working.
     *
     * @param workflowName value of the {@code workflow} tag
     * @param status       value of the {@code status} tag (e.g. SUCCESS/TIMEOUT/FAILED)
     * @param duration     wall-clock duration of the execution
     */
    // Fully qualified because this snippet's imports only bring in io.micrometer.core.instrument.*;
    // an unqualified Duration would not compile.
    public void recordExecution(String workflowName, String status, java.time.Duration duration) {
        Timer.builder("workflow.execution.duration")
            .tag("workflow", workflowName)
            .tag("status", status)
            .register(registry)
            .record(duration);

        Counter.builder("workflow.execution.count")
            .tag("workflow", workflowName)
            .tag("status", status)
            .register(registry)
            .increment();
    }
}

6.2 自定义线程池(Spring Boot集成)

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Spring configuration that provides the bounded thread pool for workflow tasks.
 */
@Configuration
public class WorkflowConfig {

    /**
     * Bean {@code workflowExecutor}: 4 core threads, at most 16 threads, and a 200-slot queue.
     *
     * <p>As with any ThreadPoolExecutor, threads beyond the core 4 are only added once the queue
     * is full. When both the queue and the pool are saturated, CallerRunsPolicy runs the task on
     * the submitting thread, which slows producers down (back-pressure).
     */
    @Bean("workflowExecutor")
    public Executor workflowExecutor() {
        ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
        pool.setThreadNamePrefix("wf-");
        pool.setCorePoolSize(4);
        pool.setMaxPoolSize(16);
        pool.setQueueCapacity(200);

        ThreadPoolExecutor.CallerRunsPolicy backPressure = new ThreadPoolExecutor.CallerRunsPolicy();
        pool.setRejectedExecutionHandler(backPressure);

        pool.initialize();
        return pool;
    }
}

七、总结

| 能力 | 实现 | 价值 |
| --- | --- | --- |
| 类型安全组合 | WorkflowStage<I,O> 泛型接口 | 编译期保证类型正确 |
| 并行执行 | thenCombine + CompletableFuture.allOf | 多模型/多数据源并行,减少延迟 |
| 错误恢复 | exceptionally + onError | 单点故障不中断整体流程 |
| 检查点恢复 | FileCheckpointManager | 避免重复执行昂贵的LLM调用 |
| 可观测性 | 事件监听器 + Micrometer | 全链路追踪与监控 |
Logo

小龙虾开发者社区是 CSDN 旗下专注 OpenClaw 生态的官方阵地,聚焦技能开发、插件实践与部署教程,为开发者提供可直接落地的方案、工具与交流平台,助力高效构建与落地 AI 应用

更多推荐