一、为什么需要分布式调度

前四篇实现的都是单进程内的工作流引擎。但现实中的AI应用往往是异构的:

| 组件 | 最佳语言 | 原因 |
| --- | --- | --- |
| 调度器/网关 | Go | 高并发、低延迟、轻量级goroutine |
| LLM Agent | Python/Java | 丰富的ML/AI生态库 |
| 数据处理 | Java | Spring生态,企业级数据处理 |
| 向量检索 | Go/C++ | 高性能计算 |

合理的选型是Go做调度,Java做执行——各取所长。


二、架构总览

                    +------------------+
                    |   Redis Queue    |
                    |  (任务缓冲)       |
                    +--------+---------+
                             |
              +--------------+--------------+
              |                             |
     +--------v---------+         +---------v--------+
     |  Go Scheduler    |  gRPC   |  Java Worker 1   |
     |  (Master)        |<------->|  (Executor)      |
     |  +------------+  |         |  +------------+  |
     |  |DAG Engine  |  |         |  |LLM Agent   |  |
     |  |Health Check|  |         |  |Text Process|  |
     |  |Load Balance|  |         |  |Embedding   |  |
     |  +------------+  |         |  +------------+  |
     +------------------+         +------------------+
              |                             |
              |gRPC                         |gRPC
              |                             |
     +--------v---------+         +---------v--------+
     |  Go Worker 1     |         |  Java Worker 2   |
     |  (Vector Search) |         |  (Data Pipeline) |
     +------------------+         +------------------+

三、Protobuf服务定义

这是Go和Java之间的契约——双方严格遵守同一份proto:

// Schema shared by the Go scheduler and the Java worker.
syntax = "proto3";

package aiworkflow;

// Go import path of the generated code (the Go scheduler imports it as `pb`).
option go_package = "github.com/yourorg/aiworkflow/proto";
// Java package and outer class of the generated messages; the Java worker
// imports com.demo.aiworkflow.proto.WorkflowProto.*.
option java_package = "com.demo.aiworkflow.proto";
option java_outer_classname = "WorkflowProto";

// WorkflowService is the workflow scheduling service.
service WorkflowService {
  // Register is called by a worker to register itself with the scheduler.
  rpc Register(RegisterRequest) returns (RegisterResponse);

  // AssignTask is a bidirectional stream meant for the scheduler to hand tasks
  // to a worker and for the worker to report results.
  // NOTE(review): as declared, the client streams TaskAssignment and the server
  // streams TaskResult. The Go scheduler in this document is the server, yet it
  // reads TaskResult fields (status, duration_ms) from Recv(), which suggests
  // the two message types are swapped here — confirm the intended direction.
  rpc AssignTask(stream TaskAssignment) returns (stream TaskResult);

  // Heartbeat is called by a worker to report that it is alive and how loaded it is.
  rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse);
}

// RegisterRequest is sent by a worker to announce itself to the scheduler.
message RegisterRequest {
  // Key under which the scheduler stores the worker.
  string worker_id = 1;
  string host = 2;
  int32 port = 3;
  // Capability tags. The Go scheduler in this document compares each tag with
  // TaskAssignment.task_type by exact string equality (or the wildcard "*"),
  // so tags should use the task_type vocabulary,
  // e.g. ["llm_call","embedding","search"].
  repeated string capabilities = 4;
  // The scheduler only selects a worker whose active task count is below this value.
  int32 max_concurrency = 5;
}

// RegisterResponse is the scheduler's answer to a registration attempt.
message RegisterResponse {
  // False means the registration was rejected; the Java worker in this
  // document throws in that case.
  bool accepted = 1;
  // Human-readable detail accompanying the decision.
  string message = 2;
  // The Go scheduler in this document returns 10. The Java worker does not read
  // this field and heartbeats on a fixed 10-second period.
  int32 heartbeat_interval_sec = 3;
}

// TaskAssignment describes one unit of work to be routed to a worker.
message TaskAssignment {
  string task_id = 1;
  string task_type = 2;            // "llm_call", "embedding", "search", "text_process"
  string workflow_id = 3;          // ID of the workflow this task belongs to
  int32 priority = 4;              // priority, 1-10
  map<string, string> params = 5;  // task parameters
  int64 timeout_ms = 6;            // timeout in milliseconds
  int64 created_at = 7;            // NOTE(review): unit and epoch are not specified — confirm
}

// TaskResult reports the outcome of a single task.
message TaskResult {
  string task_id = 1;
  // Status is the state of the task. PENDING is the zero value, so a result
  // whose status was never set reads as PENDING.
  enum Status {
    PENDING = 0;
    RUNNING = 1;
    SUCCESS = 2;
    FAILED = 3;
    TIMEOUT = 4;
  }
  Status status = 2;
  string output_json = 3;          // task output, serialized as JSON
  string error_message = 4;        // set by the Java worker when status is FAILED
  int64 duration_ms = 5;           // execution time in milliseconds
  map<string, string> metadata = 6;
}

// HeartbeatRequest carries a worker's periodic liveness and load report.
message HeartbeatRequest {
  string worker_id = 1;
  // Number of tasks the worker is running; the scheduler uses it for load balancing.
  int32 active_tasks = 2;
  // NOTE(review): the Java worker in this document fills this with the system
  // load average rather than a percentage — confirm the intended unit.
  double cpu_usage = 3;
  int64 memory_mb = 4;
  // The Java worker in this document sends System.currentTimeMillis().
  int64 timestamp = 5;
}

// HeartbeatResponse acknowledges a heartbeat.
message HeartbeatResponse {
  // The Java worker in this document logs a heartbeat failure when this is false.
  bool ok = 1;
}

四、Go Scheduler实现

package main

import (
    "context"
    "fmt"
    "log"
    "net"
    "sync"
    "time"

    pb "github.com/yourorg/aiworkflow/proto"
    "google.golang.org/grpc"
)

// WorkerInfo is the scheduler's record of one registered worker.
//
// ID, Host, Port, Capabilities and MaxConcurrency are copied from the
// registration request. ActiveTasks, LastHeartbeat and Healthy are refreshed
// by heartbeats; Healthy is also cleared by the periodic health check.
type WorkerInfo struct {
    ID, Host                    string
    Port                        int32
    Capabilities                []string
    MaxConcurrency, ActiveTasks int32
    LastHeartbeat               time.Time
    Healthy                     bool
}

// Scheduler is the task scheduler. It embeds the generated
// UnimplementedWorkflowServiceServer and keeps the registry of known workers.
type Scheduler struct {
    pb.UnimplementedWorkflowServiceServer
    mu      sync.RWMutex           // guards workers
    workers map[string]*WorkerInfo // workerID -> info
    // NOTE(review): EnqueueTask and startDequeue later in this document use
    // s.redis, which is not declared here — presumably a Redis client field is
    // expected; confirm.
}

// NewScheduler returns a Scheduler whose worker registry is empty and ready
// for use.
func NewScheduler() *Scheduler {
    registry := make(map[string]*WorkerInfo)
    return &Scheduler{workers: registry}
}

// Register handles a worker's registration request.
//
// A request with an empty worker_id or a non-positive max_concurrency is
// rejected with Accepted=false and an explanatory message (nil error): an
// empty ID would make different workers overwrite one another under the same
// map key, and selectWorker never picks a worker whose ActiveTasks is not
// below MaxConcurrency, so such a worker could never receive work.
//
// Otherwise the worker is stored as healthy, replacing any earlier entry with
// the same ID, and the response carries a heartbeat interval of 10 seconds.
func (s *Scheduler) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
    switch {
    case req.WorkerId == "":
        return &pb.RegisterResponse{Accepted: false, Message: "注册被拒绝: worker_id 不能为空"}, nil
    case req.MaxConcurrency <= 0:
        return &pb.RegisterResponse{Accepted: false, Message: "注册被拒绝: max_concurrency 必须大于0"}, nil
    }

    info := &WorkerInfo{
        ID:             req.WorkerId,
        Host:           req.Host,
        Port:           req.Port,
        Capabilities:   req.Capabilities,
        MaxConcurrency: req.MaxConcurrency,
        LastHeartbeat:  time.Now(),
        Healthy:        true,
    }

    // Hold the lock only for the map write; logging does not need it.
    s.mu.Lock()
    s.workers[req.WorkerId] = info
    s.mu.Unlock()

    log.Printf("Worker注册成功: %s (%s:%d) 能力: %v",
        req.WorkerId, req.Host, req.Port, req.Capabilities)
    return &pb.RegisterResponse{
        Accepted:             true,
        Message:              "注册成功",
        HeartbeatIntervalSec: 10,
    }, nil
}

// Heartbeat records a worker's load report, refreshes its last-heartbeat time
// and marks it healthy.
//
// It returns an error when the worker ID has not been registered. No response
// is returned alongside the error: a gRPC handler's response message is
// dropped whenever the handler returns a non-nil error, so a populated
// response there would never reach the caller.
func (s *Scheduler) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) {
    s.mu.Lock()
    defer s.mu.Unlock()

    worker, ok := s.workers[req.WorkerId]
    if !ok {
        return nil, fmt.Errorf("未知Worker: %s", req.WorkerId)
    }

    worker.ActiveTasks = req.ActiveTasks
    worker.LastHeartbeat = time.Now()
    worker.Healthy = true

    return &pb.HeartbeatResponse{Ok: true}, nil
}

// startHealthCheck runs until ctx is cancelled. Every 15 seconds it scans the
// registry and marks as unhealthy any worker whose last heartbeat is more than
// 30 seconds old.
//
// The timeout is logged only when a worker goes from healthy to unhealthy, so
// a worker that stays silent does not produce a log line on every tick.
// Entries are never removed from the registry here.
func (s *Scheduler) startHealthCheck(ctx context.Context) {
    const (
        checkInterval    = 15 * time.Second
        heartbeatTimeout = 30 * time.Second
    )

    ticker := time.NewTicker(checkInterval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            s.mu.Lock()
            now := time.Now() // one snapshot so every worker is judged against the same instant
            for id, w := range s.workers {
                if w.Healthy && now.Sub(w.LastHeartbeat) > heartbeatTimeout {
                    w.Healthy = false
                    log.Printf("Worker %s 心跳超时,标记为不健康", id)
                }
            }
            s.mu.Unlock()
        }
    }
}

// AssignTask serves the bidirectional task stream.
//
// A background goroutine drains messages from the stream and logs each one as
// a finished task; it stops at the first receive error. The handler itself
// blocks until the stream's context is done and returns that context's error.
// This implementation never sends anything on the stream.
func (s *Scheduler) AssignTask(stream pb.WorkflowService_AssignTaskServer) error {
    drain := func() {
        for {
            res, recvErr := stream.Recv()
            if recvErr != nil {
                log.Printf("Worker结果接收结束: %v", recvErr)
                return
            }
            log.Printf("任务 %s 完成: status=%s, 耗时=%dms",
                res.TaskId, res.Status, res.DurationMs)
        }
    }
    go drain()

    streamCtx := stream.Context()
    <-streamCtx.Done()
    return streamCtx.Err()
}

// SubmitTask is the task-submission entry point exposed to callers.
//
// It selects a worker for the task's type, logs the assignment and returns a
// result in PENDING state; the task is not transmitted to the worker here.
// When no worker can be selected, the selection error is wrapped and returned.
func (s *Scheduler) SubmitTask(task *pb.TaskAssignment) (*pb.TaskResult, error) {
    target, selErr := s.selectWorker(task.TaskType)
    if selErr != nil {
        return nil, fmt.Errorf("无可用Worker: %w", selErr)
    }

    log.Printf("任务 %s (%s) 分配给 Worker %s", task.TaskId, task.TaskType, target.ID)

    pending := &pb.TaskResult{
        TaskId: task.TaskId,
        Status: pb.TaskResult_PENDING,
    }
    return pending, nil
}

// selectWorker picks a worker for taskType by capability and load, choosing
// the eligible worker with the fewest active tasks.
//
// A worker is eligible when it is healthy, advertises taskType (or "*") and
// has ActiveTasks below MaxConcurrency. The first eligible worker seeds the
// comparison, so there is no fixed upper bound on the load a selectable worker
// may report. Ties go to whichever eligible worker the map iteration yields
// first.
//
// It returns an error when no worker is eligible.
func (s *Scheduler) selectWorker(taskType string) (*WorkerInfo, error) {
    s.mu.RLock()
    defer s.mu.RUnlock()

    var best *WorkerInfo
    for _, w := range s.workers {
        if !w.Healthy || !hasCapability(w.Capabilities, taskType) {
            continue
        }
        if w.ActiveTasks >= w.MaxConcurrency {
            continue // already at capacity
        }
        if best == nil || w.ActiveTasks < best.ActiveTasks {
            best = w
        }
    }

    if best == nil {
        return nil, fmt.Errorf("没有匹配 %s 的健康Worker", taskType)
    }
    return best, nil
}

// hasCapability reports whether caps contains target or the wildcard "*".
// Matching is by exact string equality.
func hasCapability(caps []string, target string) bool {
    for i := range caps {
        switch caps[i] {
        case target, "*":
            return true
        }
    }
    return false
}

// main listens on TCP port 9090, starts the background health check and
// serves the WorkflowService over gRPC. It exits fatally if listening or
// serving fails.
func main() {
    listener, listenErr := net.Listen("tcp", ":9090")
    if listenErr != nil {
        log.Fatalf("监听失败: %v", listenErr)
    }

    sched := NewScheduler()
    // The health check gets a background context, so it is never cancelled
    // while the process is alive.
    go sched.startHealthCheck(context.Background())

    server := grpc.NewServer()
    pb.RegisterWorkflowServiceServer(server, sched)

    log.Println("Go Scheduler 启动在 :9090")
    serveErr := server.Serve(listener)
    if serveErr != nil {
        log.Fatalf("服务失败: %v", serveErr)
    }
}

五、Java Worker实现

package com.demo.aiworkflow.worker;

import com.demo.aiworkflow.proto.WorkflowProto.*;
import com.demo.aiworkflow.proto.WorkflowServiceGrpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;

import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Java AI worker: registers with the scheduler, sends periodic heartbeats and
 * contains the execution logic for LLM calls, embedding generation and text
 * processing.
 *
 * <p>NOTE(review): nothing in this class calls {@code executeTask} or submits
 * work to {@code taskExecutor}; the task stream opened in {@link #start()} is
 * not yet wired to them.
 */
public class AIWorker {

    private final String workerId;
    private final ManagedChannel channel;
    private final WorkflowServiceGrpc.WorkflowServiceStub asyncStub;
    private final WorkflowServiceGrpc.WorkflowServiceBlockingStub blockingStub;

    private final ExecutorService taskExecutor;
    private final ScheduledExecutorService heartbeatExecutor;
    private final AtomicInteger activeTaskCount;
    private volatile boolean running;

    private static final String SCHEDULER_HOST = "localhost";
    private static final int SCHEDULER_PORT = 9090;

    /**
     * Creates a worker and opens a plaintext channel to the scheduler.
     *
     * @param workerId identifier sent with registration and every heartbeat
     */
    public AIWorker(String workerId) {
        this.workerId = workerId;
        this.channel = ManagedChannelBuilder
            .forAddress(SCHEDULER_HOST, SCHEDULER_PORT)
            .usePlaintext()
            .build();
        this.asyncStub = WorkflowServiceGrpc.newStub(channel);
        this.blockingStub = WorkflowServiceGrpc.newBlockingStub(channel);

        this.taskExecutor = Executors.newVirtualThreadPerTaskExecutor();
        this.heartbeatExecutor = Executors.newSingleThreadScheduledExecutor();
        this.activeTaskCount = new AtomicInteger(0);
        this.running = true;
    }

    /**
     * Starts the worker: registers, starts heartbeats, opens the task stream and
     * then blocks until {@link #shutdown()} is called or the thread is interrupted.
     *
     * <p>Whatever the exit path (rejected registration, interruption or normal
     * stop), the executors and the channel are released before this method returns.
     *
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public void start() throws InterruptedException {
        try {
            register();
            startHeartbeat();

            // Open the bidirectional task stream. The returned observer is the
            // sending half of the stream; this class does not send on it yet.
            StreamObserver<TaskAssignment> requestObserver =
                asyncStub.assignTask(new StreamObserver<TaskResult>() {
                    @Override
                    public void onNext(TaskResult result) {
                        // Acknowledgement message from the scheduler.
                    }

                    @Override
                    public void onError(Throwable t) {
                        System.err.println("任务流错误: " + t.getMessage());
                    }

                    @Override
                    public void onCompleted() {
                        System.out.println("任务流关闭");
                    }
                });

            while (running) {
                Thread.sleep(1000);
            }
        } finally {
            shutdown();
        }
    }

    /**
     * Stops the worker: ends the wait loop in {@link #start()}, stops the
     * heartbeat and task executors and closes the gRPC channel. Calling it more
     * than once is harmless.
     */
    public void shutdown() {
        running = false;
        heartbeatExecutor.shutdownNow();
        taskExecutor.shutdownNow();
        channel.shutdownNow();
    }

    /**
     * Registers this worker with the scheduler.
     *
     * <p>The advertised capabilities are the task types handled by
     * {@code executeTask}. The scheduler matches capabilities against task types
     * by exact string, so the LLM capability must be spelled "llm_call".
     *
     * @throws RuntimeException if the scheduler rejects the registration
     */
    private void register() {
        RegisterRequest request = RegisterRequest.newBuilder()
            .setWorkerId(workerId)
            .setHost("localhost")
            .setPort(8080)
            .addAllCapabilities(java.util.List.of("llm_call", "embedding", "text_process"))
            .setMaxConcurrency(10)
            .build();

        RegisterResponse response = blockingStub.register(request);
        if (response.getAccepted()) {
            System.out.println("Worker注册成功: " + response.getMessage());
        } else {
            throw new RuntimeException("Worker注册被拒绝: " + response.getMessage());
        }
    }

    /**
     * Schedules a heartbeat every 10 seconds after an initial 5-second delay.
     * Each call has a 5-second deadline. Failures are logged and never
     * propagate, so one failed heartbeat does not cancel the schedule.
     */
    private void startHeartbeat() {
        heartbeatExecutor.scheduleAtFixedRate(() -> {
            try {
                HeartbeatRequest request = HeartbeatRequest.newBuilder()
                    .setWorkerId(workerId)
                    .setActiveTasks(activeTaskCount.get())
                    .setCpuUsage(getCpuUsage())
                    .setMemoryMb(getMemoryMB())
                    .setTimestamp(System.currentTimeMillis())
                    .build();

                HeartbeatResponse response = blockingStub
                    .withDeadlineAfter(5, TimeUnit.SECONDS)
                    .heartbeat(request);

                if (!response.getOk()) {
                    System.err.println("心跳失败");
                }
            } catch (Exception e) {
                System.err.println("心跳异常: " + e.getMessage());
            }
        }, 5, 10, TimeUnit.SECONDS);
    }

    /**
     * Runs one AI task and converts the outcome into a {@code TaskResult}.
     *
     * <p>The active-task counter is incremented for the duration of the call.
     * Any exception, including an unknown task type, yields a FAILED result
     * rather than propagating.
     *
     * @param task the task to run
     * @return a SUCCESS result carrying the output JSON, or a FAILED result
     *         carrying an error message; both include the elapsed milliseconds
     */
    private TaskResult executeTask(TaskAssignment task) {
        activeTaskCount.incrementAndGet();
        long startTime = System.currentTimeMillis();

        try {
            String output = switch (task.getTaskType()) {
                case "llm_call"     -> callLLM(task.getParamsMap());
                case "embedding"    -> generateEmbedding(task.getParamsMap());
                case "text_process" -> processText(task.getParamsMap());
                default -> throw new IllegalArgumentException(
                    "未知任务类型: " + task.getTaskType());
            };

            long duration = System.currentTimeMillis() - startTime;
            return TaskResult.newBuilder()
                .setTaskId(task.getTaskId())
                .setStatus(TaskResult.Status.SUCCESS)
                .setOutputJson(output)
                .setDurationMs(duration)
                .build();

        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Keep the interrupt visible to whoever owns this thread.
                Thread.currentThread().interrupt();
            }
            // Protobuf builders reject null strings, and getMessage() may be null.
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            long duration = System.currentTimeMillis() - startTime;
            return TaskResult.newBuilder()
                .setTaskId(task.getTaskId())
                .setStatus(TaskResult.Status.FAILED)
                .setErrorMessage(message)
                .setDurationMs(duration)
                .build();
        } finally {
            activeTaskCount.decrementAndGet();
        }
    }

    // ===== Simulated AI task implementations =====

    /**
     * Simulates an LLM call: logs the model and the first 50 characters of the
     * prompt, sleeps 500-2000 ms and returns a canned JSON answer.
     */
    private String callLLM(java.util.Map<String, String> params) throws InterruptedException {
        String model = params.getOrDefault("model", "gpt-4o-mini");
        String prompt = params.getOrDefault("prompt", "");
        System.out.printf("[LLM] 调用 %s: %.50s...%n", model, prompt);
        Thread.sleep(ThreadLocalRandom.current().nextLong(500, 2000));
        return "{\"response\": \"这是LLM的回答\", \"tokens\": 150}";
    }

    /**
     * Simulates embedding generation: logs the text length, sleeps 200 ms and
     * returns a canned JSON result.
     */
    private String generateEmbedding(java.util.Map<String, String> params)
            throws InterruptedException {
        String text = params.getOrDefault("text", "");
        System.out.printf("[Embedding] 文本长度: %d%n", text.length());
        Thread.sleep(200);
        return "{\"dimensions\": 1536}";
    }

    /**
     * Simulates text processing: logs the requested operation (default
     * "summarize"), sleeps 300 ms and returns a canned JSON result.
     */
    private String processText(java.util.Map<String, String> params)
            throws InterruptedException {
        String operation = params.getOrDefault("operation", "summarize");
        System.out.printf("[Text] 操作: %s%n", operation);
        Thread.sleep(300);
        return "{\"result\": \"处理后的文本...\"}";
    }

    // ===== System metrics =====

    /**
     * Returns the system load average for the last minute, or 0 when the
     * platform cannot provide it (the JDK reports that case as a negative value).
     * Note that this is a load average, not a CPU percentage.
     */
    private double getCpuUsage() {
        double load = java.lang.management.ManagementFactory
            .getOperatingSystemMXBean().getSystemLoadAverage();
        return Math.max(0.0, load);
    }

    /** Returns the JVM heap currently in use, in mebibytes. */
    private long getMemoryMB() {
        Runtime rt = Runtime.getRuntime();
        return (rt.totalMemory() - rt.freeMemory()) / (1024 * 1024);
    }

    /** Starts a worker whose ID is "java-worker-" plus 8 characters of a random UUID. */
    public static void main(String[] args) throws Exception {
        String workerId = "java-worker-" + UUID.randomUUID()
            .toString().substring(0, 8);
        AIWorker worker = new AIWorker(workerId);
        worker.start();
    }
}

六、Redis任务队列(削峰填谷)

当Worker全部忙碌时,新的任务不应直接失败,而应进入Redis队列等待:

// EnqueueTask serializes task and pushes it onto the head of the Redis list
// "workflow:queue:<task_type>".
//
// It returns an error when the task cannot be marshaled or when the push
// fails, so an unserializable task is never queued as an empty payload.
func (s *Scheduler) EnqueueTask(ctx context.Context, task *pb.TaskAssignment) error {
    data, err := proto.Marshal(task)
    if err != nil {
        return fmt.Errorf("序列化任务 %s 失败: %w", task.TaskId, err)
    }
    return s.redis.LPush(ctx, "workflow:queue:"+task.TaskType, data).Err()
}

// startDequeue is the background consumer: it pops queued tasks from Redis and
// submits them, until ctx is cancelled.
//
// One queue per task type declared in the proto is watched; when several are
// non-empty, BRPop serves the earliest listed key first. A task that cannot be
// submitted (no worker available) is pushed back onto its queue and retried
// after a short pause. A payload that cannot be decoded is logged and dropped,
// since re-queueing it would only fail again.
func (s *Scheduler) startDequeue(ctx context.Context) {
    const retryDelay = 100 * time.Millisecond

    // pause waits for retryDelay and reports false once ctx is cancelled.
    pause := func() bool {
        timer := time.NewTimer(retryDelay)
        defer timer.Stop()
        select {
        case <-ctx.Done():
            return false
        case <-timer.C:
            return true
        }
    }

    for {
        result, err := s.redis.BRPop(ctx, 5*time.Second,
            "workflow:queue:llm_call",
            "workflow:queue:embedding",
            "workflow:queue:text_process",
            "workflow:queue:search",
        ).Result()
        if err != nil {
            if ctx.Err() != nil {
                return // cancelled: stop instead of spinning on the same error
            }
            // Either the 5-second pop timed out or Redis failed; pause briefly
            // so a persistent failure does not become a busy loop.
            if !pause() {
                return
            }
            continue
        }
        if len(result) < 2 {
            continue // expected [queue key, payload]
        }

        var task pb.TaskAssignment
        if err := proto.Unmarshal([]byte(result[1]), &task); err != nil {
            log.Printf("丢弃无法解析的任务 (队列 %s): %v", result[0], err)
            continue
        }

        if _, err := s.SubmitTask(&task); err != nil {
            // No worker available: put the task back on its queue.
            if pushErr := s.redis.LPush(ctx, result[0], result[1]).Err(); pushErr != nil {
                log.Printf("任务 %s 重新入队失败: %v", task.TaskId, pushErr)
            }
            if !pause() {
                return
            }
        }
    }
}

七、完整调用链路

业务服务 --> Go Scheduler
                |
                +-- 1. 选Worker (能力匹配 + 最少连接)
                +-- 2. Worker忙? -> Redis队列缓冲
                +-- 3. gRPC流下发任务
                        |
                        v
              Java Worker
                +-- 4. 执行AI任务 (LLM/Embedding/Text)
                +-- 5. gRPC流回传结果
                +-- 6. 心跳上报负载
                        |
                        v
              Go Scheduler
                +-- 7. 判断工作流下一步
                +-- 8. 触发下一阶段任务

八、工程化增强

8.1 gRPC拦截器(重试+超时)

// retryInterceptor returns a unary client interceptor that makes up to three
// attempts per call, each under its own 5-second timeout, with an exponential
// backoff (100ms, then 200ms) between attempts.
//
// Retrying stops as soon as the caller's context is done. The error of the
// last attempt is wrapped with %w so callers can still inspect it with
// errors.Is / errors.As. No backoff is taken after the final attempt.
func retryInterceptor() grpc.UnaryClientInterceptor {
    const (
        maxAttempts    = 3
        attemptTimeout = 5 * time.Second
        baseBackoff    = 100 * time.Millisecond
    )

    return func(ctx context.Context, method string, req, reply interface{},
        cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
        var lastErr error
        for attempt := 0; attempt < maxAttempts; attempt++ {
            // Every attempt derives its timeout from the caller's ctx, not from
            // the previous attempt's already-cancelled context.
            attemptCtx, cancel := context.WithTimeout(ctx, attemptTimeout)
            lastErr = invoker(attemptCtx, method, req, reply, cc, opts...)
            cancel()
            if lastErr == nil {
                return nil
            }
            if attempt == maxAttempts-1 {
                break
            }

            timer := time.NewTimer(baseBackoff << attempt)
            select {
            case <-ctx.Done():
                timer.Stop()
                return fmt.Errorf("gRPC调用失败(调用方已取消): %s: %w", method, lastErr)
            case <-timer.C:
            }
        }
        return fmt.Errorf("gRPC调用失败(已重试3次): %s: %w", method, lastErr)
    }
}

8.2 Prometheus指标

// Scheduler metrics.
//
// NOTE(review): nothing in this snippet registers these collectors; they have
// to be registered (e.g. with prometheus.MustRegister) before they are
// exposed — confirm that happens elsewhere.
var (
    // taskDuration is a histogram named workflow_task_duration_seconds,
    // labelled by task_type and status.
    taskDuration = prometheus.NewHistogramVec(
        prometheus.HistogramOpts{
            Name:    "workflow_task_duration_seconds",
            Help:    "Task execution time in seconds, by task type and final status.",
            Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 30},
        },
        []string{"task_type", "status"},
    )

    // workerCount is a gauge named workflow_workers_healthy. It has no labels;
    // it stays a GaugeVec so existing call sites keep working.
    workerCount = prometheus.NewGaugeVec(
        prometheus.GaugeOpts{
            Name: "workflow_workers_healthy",
            Help: "Number of workers currently considered healthy.",
        },
        []string{},
    )
)

九、总结

| 组件 | 语言 | 核心能力 |
| --- | --- | --- |
| Scheduler | Go | 任务调度、负载均衡、健康检查、gRPC服务端 |
| Worker | Java | AI任务执行(LLM/Embedding)、gRPC客户端 |
| 队列 | Redis | 削峰填谷、任务缓冲 |
| 协议 | Protobuf+gRPC | 跨语言类型安全通信 |

更多推荐