Go+Java混合架构:分布式AI工作流调度系统
·
一、为什么需要分布式调度
前四篇实现的都是单进程内的工作流引擎。但现实中的AI应用往往是异构的:
| 组件 | 最佳语言 | 原因 |
|---|---|---|
| 调度器/网关 | Go | 高并发、低延迟、轻量级goroutine |
| LLM Agent | Python/Java | 丰富的ML/AI生态库 |
| 数据处理 | Java | Spring生态,企业级数据处理 |
| 向量检索 | Go/C++ | 高性能计算 |
合理的选型是Go做调度,Java做执行——各取所长。
二、架构总览
                 +------------------+
                 |   Redis Queue    |
                 |    (任务缓冲)    |
                 +--------+---------+
                          |
           +--------------+--------------+
           |                             |
  +--------v--------+           +--------v--------+
  |  Go Scheduler   |   gRPC    |  Java Worker 1  |
  |    (Master)     |<--------->|   (Executor)    |
  | +-------------+ |           | +-------------+ |
  | |DAG Engine   | |           | |LLM Agent    | |
  | |Health Check | |           | |Text Process | |
  | |Load Balance | |           | |Embedding    | |
  | +-------------+ |           | +-------------+ |
  +--------+--------+           +--------+--------+
           |                             |
           |gRPC                         |gRPC
           |                             |
  +--------v--------+           +--------v--------+
  |   Go Worker 1   |           |  Java Worker 2  |
  | (Vector Search) |           | (Data Pipeline) |
  +-----------------+           +-----------------+
三、Protobuf服务定义
这是Go和Java之间的契约——双方严格遵守同一份proto:
syntax = "proto3";
package aiworkflow;
option go_package = "github.com/yourorg/aiworkflow/proto";
option java_package = "com.demo.aiworkflow.proto";
option java_outer_classname = "WorkflowProto";
// WorkflowService is the workflow scheduling service. The Go scheduler
// implements the server side; workers call it as clients.
service WorkflowService {
// Register: a worker announces itself to the scheduler.
rpc Register(RegisterRequest) returns (RegisterResponse);
// AssignTask: bidirectional stream meant for the scheduler to hand tasks to a
// worker and for the worker to report results.
// NOTE(review): as declared, the client (worker) sends TaskAssignment and the
// server (scheduler) sends TaskResult. The Go server in this article reads
// TaskResult fields (Status, DurationMs) from stream.Recv(), which implies the
// opposite direction — the two message types look swapped; confirm before
// generating code.
rpc AssignTask(stream TaskAssignment) returns (stream TaskResult);
// Heartbeat: periodic liveness and load report sent by a worker.
rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse);
}
// RegisterRequest is sent once by a worker to join the scheduler's registry.
message RegisterRequest {
// Identifier chosen by the worker; the scheduler uses it as the registry key.
string worker_id = 1;
// Host and port the worker reports for itself.
string host = 2;
int32 port = 3;
// Capability tags, e.g. ["llm","embedding","search"].
// NOTE(review): the Go scheduler matches a tag against TaskAssignment.task_type
// by exact string equality (or the wildcard "*"), so tags presumably need to be
// task_type values such as "llm_call" — confirm the intended vocabulary.
repeated string capabilities = 4;
// Upper bound on concurrently running tasks; the scheduler skips a worker whose
// active task count has reached this value.
int32 max_concurrency = 5;
}
// RegisterResponse is the scheduler's answer to a registration attempt.
message RegisterResponse {
// True when the worker was added to the registry.
bool accepted = 1;
// Human-readable detail about the outcome.
string message = 2;
// Heartbeat period, in seconds, that the scheduler asks the worker to use.
int32 heartbeat_interval_sec = 3;
}
// TaskAssignment describes one unit of work to be executed by a worker.
message TaskAssignment {
string task_id = 1;
// Task kind: "llm_call", "embedding", "search", "text_process".
string task_type = 2;
// ID of the workflow this task belongs to.
string workflow_id = 3;
// Priority, 1-10.
int32 priority = 4;
// Free-form task parameters.
map<string, string> params = 5;
// Timeout in milliseconds.
int64 timeout_ms = 6;
// Creation time. NOTE(review): unit/epoch is not specified here — confirm.
int64 created_at = 7;
}
// TaskResult reports the state or outcome of a task identified by task_id.
message TaskResult {
string task_id = 1;
// Lifecycle state of the task. PENDING is the proto3 zero value, so a result
// with no status set reads as PENDING.
enum Status {
PENDING = 0;
RUNNING = 1;
SUCCESS = 2;
FAILED = 3;
TIMEOUT = 4;
}
Status status = 2;
// Task output, serialized as JSON.
string output_json = 3;
// Error description; the Java worker sets it only on FAILED results.
string error_message = 4;
// Execution time in milliseconds.
int64 duration_ms = 5;
// Additional free-form key/value data.
map<string, string> metadata = 6;
}
// HeartbeatRequest is the periodic liveness and load report from a worker.
message HeartbeatRequest {
string worker_id = 1;
// Number of tasks currently running on the worker; the scheduler copies this
// into its load-balancing state.
int32 active_tasks = 2;
// NOTE(review): the Java worker in this article fills this with the system
// load average, not a CPU percentage — confirm the intended unit.
double cpu_usage = 3;
// Memory in use, in MiB.
int64 memory_mb = 4;
// Send time; the Java worker uses milliseconds since the Unix epoch.
int64 timestamp = 5;
}
// HeartbeatResponse acknowledges a heartbeat.
message HeartbeatResponse {
// True when the scheduler recognised the worker and recorded the heartbeat.
bool ok = 1;
}
四、Go Scheduler实现
package main
import (
"context"
"fmt"
"log"
"net"
"sync"
"time"
pb "github.com/yourorg/aiworkflow/proto"
"google.golang.org/grpc"
)
// WorkerInfo is the scheduler's in-memory record of one registered worker.
// Instances are stored in Scheduler.workers and mutated under Scheduler.mu.
type WorkerInfo struct {
// ID is the worker's self-reported identifier and its key in the registry.
ID string
// Host and Port are the address the worker reported at registration.
Host string
Port int32
// Capabilities are the tags matched against a task's type during selection.
Capabilities []string
// MaxConcurrency is the task limit reported at registration.
MaxConcurrency int32
// ActiveTasks is overwritten with the worker-reported count on every heartbeat.
ActiveTasks int32
// LastHeartbeat is set at registration and refreshed by each heartbeat.
LastHeartbeat time.Time
// Healthy is cleared by the health check on heartbeat timeout and set again
// by the next heartbeat.
Healthy bool
}
// Scheduler is the task scheduler and the gRPC WorkflowService server.
type Scheduler struct {
// Embedded so that RPCs this type does not implement still satisfy the
// generated server interface.
pb.UnimplementedWorkflowServiceServer
// mu guards workers and the fields of every WorkerInfo it points to.
mu sync.RWMutex
// workers maps workerID -> info.
workers map[string]*WorkerInfo // workerID -> info
}
// NewScheduler returns a Scheduler whose worker registry is empty and ready
// for use.
func NewScheduler() *Scheduler {
	sched := new(Scheduler)
	sched.workers = map[string]*WorkerInfo{}
	return sched
}
// Register handles a worker's registration RPC and records the worker in the
// registry as healthy, with its heartbeat clock started at the current time.
//
// A request without a worker_id is refused (Accepted=false) instead of being
// stored under the empty key, where anonymous workers would silently overwrite
// each other. Registering an ID that already exists replaces the old record.
// The RPC error is always nil; refusal is reported through the response.
func (s *Scheduler) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
	if req.GetWorkerId() == "" {
		return &pb.RegisterResponse{
			Accepted: false,
			Message:  "worker_id不能为空",
		}, nil
	}

	info := &WorkerInfo{
		ID:             req.WorkerId,
		Host:           req.Host,
		Port:           req.Port,
		Capabilities:   req.Capabilities,
		MaxConcurrency: req.MaxConcurrency,
		LastHeartbeat:  time.Now(),
		Healthy:        true,
	}

	// Hold the lock only for the map write; logging does not need it.
	s.mu.Lock()
	s.workers[req.WorkerId] = info
	s.mu.Unlock()

	log.Printf("Worker注册成功: %s (%s:%d) 能力: %v",
		req.WorkerId, req.Host, req.Port, req.Capabilities)

	return &pb.RegisterResponse{
		Accepted:             true,
		Message:              "注册成功",
		HeartbeatIntervalSec: 10,
	}, nil
}
// Heartbeat records a worker's periodic report: it stores the reported active
// task count, refreshes the last-heartbeat time and marks the worker healthy
// again.
//
// It returns an error when the worker ID is not in the registry. No response
// is returned alongside the error, because gRPC discards the response of a
// failed unary call and a populated value would never reach the worker.
func (s *Scheduler) Heartbeat(ctx context.Context, req *pb.HeartbeatRequest) (*pb.HeartbeatResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	worker, ok := s.workers[req.WorkerId]
	if !ok {
		return nil, fmt.Errorf("未知Worker: %s", req.WorkerId)
	}

	worker.ActiveTasks = req.ActiveTasks
	worker.LastHeartbeat = time.Now()
	worker.Healthy = true
	return &pb.HeartbeatResponse{Ok: true}, nil
}
// startHealthCheck runs the heartbeat watchdog until ctx is cancelled.
//
// Every 15 seconds it scans the registry and marks a worker unhealthy when its
// last heartbeat is more than 30 seconds old. Only workers that are currently
// healthy are considered, so each timeout is logged once at the transition
// rather than again on every tick while the worker stays silent. A later
// heartbeat marks the worker healthy again.
func (s *Scheduler) startHealthCheck(ctx context.Context) {
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			s.mu.Lock()
			now := time.Now()
			for id, w := range s.workers {
				if w.Healthy && now.Sub(w.LastHeartbeat) > 30*time.Second {
					w.Healthy = false
					log.Printf("Worker %s 心跳超时,标记为不健康", id)
				}
			}
			s.mu.Unlock()
		}
	}
}
// AssignTask serves the bidirectional task stream: it logs every result
// message the worker sends until the stream ends.
//
// Receiving happens directly in the handler rather than in a detached
// goroutine, so the RPC finishes as soon as Recv fails. Previously the handler
// waited only for context cancellation and therefore stayed blocked after the
// worker had cleanly closed its sending side.
//
// It returns the stream context's error when the stream was cancelled or timed
// out, and nil when the worker simply finished sending (context still live).
func (s *Scheduler) AssignTask(stream pb.WorkflowService_AssignTaskServer) error {
	for {
		result, err := stream.Recv()
		if err != nil {
			log.Printf("Worker结果接收结束: %v", err)
			return stream.Context().Err()
		}
		log.Printf("任务 %s 完成: status=%s, 耗时=%dms",
			result.TaskId, result.Status, result.DurationMs)
	}
}
// SubmitTask is the task submission entry point. It picks a worker for the
// task's type, logs the assignment and returns a PENDING result carrying the
// task ID. It does not itself send the task to the chosen worker.
//
// It returns a wrapped error when no suitable worker is available.
func (s *Scheduler) SubmitTask(task *pb.TaskAssignment) (*pb.TaskResult, error) {
	target, selErr := s.selectWorker(task.TaskType)
	if selErr != nil {
		return nil, fmt.Errorf("无可用Worker: %w", selErr)
	}

	log.Printf("任务 %s (%s) 分配给 Worker %s", task.TaskId, task.TaskType, target.ID)

	pending := new(pb.TaskResult)
	pending.TaskId = task.TaskId
	pending.Status = pb.TaskResult_PENDING
	return pending, nil
}
// selectWorker picks the worker that should run a task of the given type,
// using capability matching plus a least-active-tasks policy.
//
// A worker is eligible when it is healthy, advertises the task type (or "*")
// and is below its own MaxConcurrency. Among eligible workers the one with the
// fewest active tasks wins. The first eligible worker seeds the comparison, so
// there is no fixed upper bound on the load a selectable worker may carry.
//
// It returns an error when no worker is eligible.
func (s *Scheduler) selectWorker(taskType string) (*WorkerInfo, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var best *WorkerInfo
	for _, w := range s.workers {
		if !w.Healthy || !hasCapability(w.Capabilities, taskType) {
			continue
		}
		if w.ActiveTasks >= w.MaxConcurrency {
			continue // already at capacity
		}
		if best == nil || w.ActiveTasks < best.ActiveTasks {
			best = w
		}
	}

	if best == nil {
		return nil, fmt.Errorf("没有匹配 %s 的健康Worker", taskType)
	}
	return best, nil
}
// hasCapability reports whether caps contains target or the wildcard "*".
// Matching is by exact string equality.
func hasCapability(caps []string, target string) bool {
	for i := 0; i < len(caps); i++ {
		switch caps[i] {
		case target, "*":
			return true
		}
	}
	return false
}
// main starts the scheduler: it launches the background health check and
// serves the gRPC WorkflowService on TCP port 9090 until Serve fails.
func main() {
	scheduler := NewScheduler()

	listener, listenErr := net.Listen("tcp", ":9090")
	if listenErr != nil {
		log.Fatalf("监听失败: %v", listenErr)
	}

	// The health check uses a background context, so it runs for the life of
	// the process.
	go scheduler.startHealthCheck(context.Background())

	server := grpc.NewServer()
	pb.RegisterWorkflowServiceServer(server, scheduler)

	log.Println("Go Scheduler 启动在 :9090")
	serveErr := server.Serve(listener)
	if serveErr != nil {
		log.Fatalf("服务失败: %v", serveErr)
	}
}
五、Java Worker实现
package com.demo.aiworkflow.worker;
import com.demo.aiworkflow.proto.WorkflowProto.*;
import com.demo.aiworkflow.proto.WorkflowServiceGrpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;
import java.util.UUID;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Java AI worker: registers with the scheduler, reports heartbeats and
 * executes AI tasks (LLM calls, embeddings, text processing).
 *
 * <p>NOTE(review): nothing in this class calls {@link #executeTask}. With the
 * AssignTask RPC as declared, this client receives TaskResult messages and
 * sends TaskAssignment messages, so no assignment ever reaches the worker —
 * confirm the intended stream direction in the proto.
 */
public class AIWorker {
    private static final String SCHEDULER_HOST = "localhost";
    private static final int SCHEDULER_PORT = 9090;

    private final String workerId;
    private final ManagedChannel channel;
    private final WorkflowServiceGrpc.WorkflowServiceStub asyncStub;
    private final WorkflowServiceGrpc.WorkflowServiceBlockingStub blockingStub;
    private final ExecutorService taskExecutor;
    private final ScheduledExecutorService heartbeatExecutor;
    private final AtomicInteger activeTaskCount;
    private volatile boolean running;

    /**
     * Creates a worker connected (plaintext) to the scheduler at
     * {@code localhost:9090}.
     *
     * @param workerId identifier reported to the scheduler
     */
    public AIWorker(String workerId) {
        this.workerId = workerId;
        this.channel = ManagedChannelBuilder
                .forAddress(SCHEDULER_HOST, SCHEDULER_PORT)
                .usePlaintext()
                .build();
        this.asyncStub = WorkflowServiceGrpc.newStub(channel);
        this.blockingStub = WorkflowServiceGrpc.newBlockingStub(channel);
        this.taskExecutor = Executors.newVirtualThreadPerTaskExecutor();
        this.heartbeatExecutor = Executors.newSingleThreadScheduledExecutor();
        this.activeTaskCount = new AtomicInteger(0);
        this.running = true;
    }

    /**
     * Starts the worker: registers, starts the heartbeat, opens the task
     * stream and then blocks while {@code running} is true.
     *
     * <p>When this method exits, normally or by interruption, the heartbeat
     * executor, the task executor and the channel are shut down so that no
     * background thread keeps reporting for a worker that has stopped.
     *
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public void start() throws InterruptedException {
        try {
            register();
            startHeartbeat();

            // Open the bidirectional stream.
            StreamObserver<TaskAssignment> requestObserver =
                    asyncStub.assignTask(new StreamObserver<TaskResult>() {
                        @Override
                        public void onNext(TaskResult result) {
                            // Acknowledgement message from the scheduler.
                        }

                        @Override
                        public void onError(Throwable t) {
                            System.err.println("任务流错误: " + t.getMessage());
                        }

                        @Override
                        public void onCompleted() {
                            System.out.println("任务流关闭");
                        }
                    });

            while (running) {
                Thread.sleep(1000);
            }
        } finally {
            shutdown();
        }
    }

    /** Releases the executors and the gRPC channel. */
    private void shutdown() {
        heartbeatExecutor.shutdownNow();
        taskExecutor.shutdown();
        channel.shutdown();
    }

    /**
     * Registers this worker with the scheduler.
     *
     * <p>The capability tags are the task type strings handled by
     * {@link #executeTask}, because the scheduler matches a tag against the
     * task type by exact equality.
     *
     * @throws RuntimeException if the scheduler refuses the registration
     */
    private void register() {
        RegisterRequest request = RegisterRequest.newBuilder()
                .setWorkerId(workerId)
                .setHost("localhost")
                .setPort(8080)
                .addAllCapabilities(java.util.List.of("llm_call", "embedding", "text_process"))
                .setMaxConcurrency(10)
                .build();
        RegisterResponse response = blockingStub.register(request);
        if (response.getAccepted()) {
            System.out.println("Worker注册成功: " + response.getMessage());
        } else {
            throw new RuntimeException("Worker注册被拒绝: " + response.getMessage());
        }
    }

    /**
     * Schedules a heartbeat every 10 seconds, starting after 5 seconds. Each
     * call has a 5 second deadline. Failures are printed and swallowed so that
     * one failed heartbeat does not cancel the periodic schedule.
     */
    private void startHeartbeat() {
        heartbeatExecutor.scheduleAtFixedRate(() -> {
            try {
                HeartbeatRequest request = HeartbeatRequest.newBuilder()
                        .setWorkerId(workerId)
                        .setActiveTasks(activeTaskCount.get())
                        .setCpuUsage(getCpuUsage())
                        .setMemoryMb(getMemoryMB())
                        .setTimestamp(System.currentTimeMillis())
                        .build();
                HeartbeatResponse response = blockingStub
                        .withDeadlineAfter(5, TimeUnit.SECONDS)
                        .heartbeat(request);
                if (!response.getOk()) {
                    System.err.println("心跳失败");
                }
            } catch (Exception e) {
                System.err.println("心跳异常: " + e.getMessage());
            }
        }, 5, 10, TimeUnit.SECONDS);
    }

    /**
     * Executes one AI task and converts the outcome into a TaskResult.
     *
     * <p>The active task counter is incremented for the duration of the call.
     * Any exception becomes a FAILED result rather than propagating.
     *
     * @param task the assignment to run
     * @return a SUCCESS result with the output, or a FAILED result with an
     *         error message; both carry the elapsed time in milliseconds
     */
    private TaskResult executeTask(TaskAssignment task) {
        activeTaskCount.incrementAndGet();
        long startTime = System.currentTimeMillis();
        try {
            String output = switch (task.getTaskType()) {
                case "llm_call" -> callLLM(task.getParamsMap());
                case "embedding" -> generateEmbedding(task.getParamsMap());
                case "text_process" -> processText(task.getParamsMap());
                default -> throw new IllegalArgumentException(
                        "未知任务类型: " + task.getTaskType());
            };
            long duration = System.currentTimeMillis() - startTime;
            return TaskResult.newBuilder()
                    .setTaskId(task.getTaskId())
                    .setStatus(TaskResult.Status.SUCCESS)
                    .setOutputJson(output)
                    .setDurationMs(duration)
                    .build();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Keep the interrupt visible to the thread's owner.
                Thread.currentThread().interrupt();
            }
            long duration = System.currentTimeMillis() - startTime;
            // Protobuf builders reject null, and getMessage() may be null.
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            return TaskResult.newBuilder()
                    .setTaskId(task.getTaskId())
                    .setStatus(TaskResult.Status.FAILED)
                    .setErrorMessage(message)
                    .setDurationMs(duration)
                    .build();
        } finally {
            activeTaskCount.decrementAndGet();
        }
    }

    // ===== Simulated AI task implementations =====

    /** Simulates an LLM call: sleeps 500-2000 ms and returns a fixed JSON reply. */
    private String callLLM(java.util.Map<String, String> params) throws InterruptedException {
        String model = params.getOrDefault("model", "gpt-4o-mini");
        String prompt = params.getOrDefault("prompt", "");
        System.out.printf("[LLM] 调用 %s: %.50s...%n", model, prompt);
        Thread.sleep(ThreadLocalRandom.current().nextLong(500, 2000));
        return "{\"response\": \"这是LLM的回答\", \"tokens\": 150}";
    }

    /** Simulates embedding generation: sleeps 200 ms and returns fixed JSON. */
    private String generateEmbedding(java.util.Map<String, String> params)
            throws InterruptedException {
        String text = params.getOrDefault("text", "");
        System.out.printf("[Embedding] 文本长度: %d%n", text.length());
        Thread.sleep(200);
        return "{\"dimensions\": 1536}";
    }

    /** Simulates text processing: sleeps 300 ms and returns fixed JSON. */
    private String processText(java.util.Map<String, String> params)
            throws InterruptedException {
        String operation = params.getOrDefault("operation", "summarize");
        System.out.printf("[Text] 操作: %s%n", operation);
        Thread.sleep(300);
        return "{\"result\": \"处理后的文本...\"}";
    }

    // ===== System metrics =====

    /**
     * Returns the system load average from the operating system MXBean. This
     * is a load figure, not a CPU percentage.
     */
    private double getCpuUsage() {
        return java.lang.management.ManagementFactory
                .getOperatingSystemMXBean().getSystemLoadAverage();
    }

    /** Returns the JVM heap currently in use (total minus free), in MiB. */
    private long getMemoryMB() {
        Runtime rt = Runtime.getRuntime();
        return (rt.totalMemory() - rt.freeMemory()) / (1024 * 1024);
    }

    /** Starts a worker whose ID is "java-worker-" plus 8 characters of a random UUID. */
    public static void main(String[] args) throws Exception {
        String workerId = "java-worker-" + UUID.randomUUID()
                .toString().substring(0, 8);
        AIWorker worker = new AIWorker(workerId);
        worker.start();
    }
}
六、Redis任务队列(削峰填谷)
当Worker全部忙碌时,新的任务不应直接失败,而应进入Redis队列等待:
// EnqueueTask serializes the task and pushes it onto the Redis list for its
// task type ("workflow:queue:<task_type>").
//
// It returns an error when the task cannot be serialized, so that a bad
// payload is never written to the queue, or when the Redis push fails.
func (s *Scheduler) EnqueueTask(ctx context.Context, task *pb.TaskAssignment) error {
	data, err := proto.Marshal(task)
	if err != nil {
		return fmt.Errorf("序列化任务 %s 失败: %w", task.TaskId, err)
	}
	return s.redis.LPush(ctx, "workflow:queue:"+task.TaskType, data).Err()
}
// startDequeue is the background consumer: it pops tasks from the per-type
// Redis queues and submits them until ctx is cancelled.
//
// Each blocking pop waits up to 5 seconds. A payload that cannot be decoded is
// logged and dropped, since re-queueing it would fail the same way forever.
// When submission fails the task is pushed back onto the queue it came from
// and the loop pauses for 100 ms.
func (s *Scheduler) startDequeue(ctx context.Context) {
	for {
		// Without this check a cancelled context makes every pop fail
		// immediately and the loop spins forever.
		if ctx.Err() != nil {
			return
		}

		result, err := s.redis.BRPop(ctx, 5*time.Second,
			"workflow:queue:llm_call",
			"workflow:queue:embedding",
			"workflow:queue:text_process",
		).Result()
		if err != nil {
			continue
		}

		// result[0] is the queue name, result[1] the serialized task.
		var task pb.TaskAssignment
		if err := proto.Unmarshal([]byte(result[1]), &task); err != nil {
			log.Printf("丢弃无法解析的任务 (队列 %s): %v", result[0], err)
			continue
		}

		if _, err := s.SubmitTask(&task); err != nil {
			// No worker available: put the task back on its queue.
			if pushErr := s.redis.LPush(ctx, result[0], result[1]).Err(); pushErr != nil {
				log.Printf("任务 %s 重新入队失败: %v", task.TaskId, pushErr)
			}
			select {
			case <-ctx.Done():
				return
			case <-time.After(100 * time.Millisecond):
			}
		}
	}
}
七、完整调用链路
业务服务 --> Go Scheduler
                |
                +-- 1. 选Worker (能力匹配 + 最少连接)
                +-- 2. Worker忙? -> Redis队列缓冲
                +-- 3. gRPC流下发任务
                |
                v
            Java Worker
                +-- 4. 执行AI任务 (LLM/Embedding/Text)
                +-- 5. gRPC流回传结果
                +-- 6. 心跳上报负载
                |
                v
            Go Scheduler
                +-- 7. 判断工作流下一步
                +-- 8. 触发下一阶段任务
八、工程化增强
8.1 gRPC拦截器(重试+超时)
// retryInterceptor returns a unary client interceptor that makes up to three
// attempts per call, each with its own 5 second timeout, backing off 100 ms
// and then 200 ms between attempts.
//
// The returned error wraps the error of the last attempt, so the underlying
// cause stays available to errors.Is / errors.As. There is no wait after the
// final attempt, and a caller context that ends during a backoff stops the
// retries at once and returns the last attempt's error unchanged.
func retryInterceptor() grpc.UnaryClientInterceptor {
	const maxAttempts = 3
	return func(ctx context.Context, method string, req, reply interface{},
		cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		var lastErr error
		for i := 0; i < maxAttempts; i++ {
			// Derive each attempt's deadline from the caller's context so
			// the timeouts do not compound across attempts.
			attemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
			lastErr = invoker(attemptCtx, method, req, reply, cc, opts...)
			cancel()
			if lastErr == nil {
				return nil
			}
			if i == maxAttempts-1 {
				break // nothing left to wait for
			}
			select {
			case <-ctx.Done():
				return lastErr
			case <-time.After(time.Duration(100*(1<<i)) * time.Millisecond):
			}
		}
		return fmt.Errorf("gRPC调用失败(已重试3次): %s: %w", method, lastErr)
	}
}
8.2 Prometheus指标
// Prometheus collectors for the scheduler.
// NOTE(review): neither collector sets a Help string and no registration call
// is visible in this snippet — confirm both are handled elsewhere.
var (
// taskDuration is a histogram named workflow_task_duration_seconds, labelled
// by task_type and status, with buckets from 0.1 to 30.
taskDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "workflow_task_duration_seconds",
Buckets: []float64{0.1, 0.5, 1, 2, 5, 10, 30},
},
[]string{"task_type", "status"},
)
// workerCount is a gauge named workflow_workers_healthy; it declares no
// labels, so the vector holds a single series.
workerCount = prometheus.NewGaugeVec(
prometheus.GaugeOpts{Name: "workflow_workers_healthy"},
[]string{},
)
)
九、总结
| 组件 | 语言 | 核心能力 |
|---|---|---|
| Scheduler | Go | 任务调度、负载均衡、健康检查、gRPC服务端 |
| Worker | Java | AI任务执行(LLM/Embedding)、gRPC客户端 |
| 队列 | Redis | 削峰填谷、任务缓冲 |
| 协议 | Protobuf+gRPC | 跨语言类型安全通信 |
更多推荐
所有评论(0)