This library brings reinforcement learning (RL) into the browser environment.

What is TensorFlow.js Agents?

TensorFlow.js Agents is a library for doing reinforcement learning in the browser. It lets you create agents that learn optimal behavior by interacting with an environment.

Core Concepts

  • Agent: the entity that learns and makes decisions
  • Environment: the world the agent interacts with
  • Action: a behavior the agent takes
  • Reward: the environment's feedback on an action
  • Policy: the rules by which the agent chooses its actions
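
In code, these concepts map onto a small, recurring contract between environment and agent. The sketch below shows the shape that the examples in this post follow (the method names are illustrative conventions, not a fixed API):

// An environment exposes reset() and step(action): step() applies the
// action, then reports the new observation, a scalar reward, and a done flag.
// (Illustrative shape only -- the concrete environments below follow it.)
class SketchEnvironment {
  reset() { /* return the initial observation */ }
  step(action) {
    // apply the action, then report what happened
    return { observation: null /* new state */, reward: 0, done: false, info: {} };
  }
}

// An agent implements the policy: it maps observations to actions
// (predict) and learns from the rewards it receives (remember / train).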

Installation and Basic Setup

# Install the core library
npm install @tensorflow/tfjs-core
# Install the Agents library
npm install @tensorflow/tfjs-agents

Or use a CDN:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-agents"></script>
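
Before going further, it is worth a quick sanity check that TensorFlow.js initialized correctly. The snippet below uses only the core @tensorflow/tfjs API (tf.ready and tf.getBackend):

// Wait for TensorFlow.js to pick a backend (typically 'webgl' in the
// browser, falling back to 'cpu') and log it.
tf.ready().then(() => {
  console.log(`TensorFlow.js backend: ${tf.getBackend()}`);
});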

Basic Usage Examples

1. Creating a Simple Reinforcement Learning Environment

import * as tf from '@tensorflow/tfjs';
import { createDQNAgent } from '@tensorflow/tfjs-agents';

// A simple grid-world environment
class GridWorld {
  constructor() {
    this.gridSize = 5;
    this.agentPosition = [0, 0];
    this.goalPosition = [4, 4];
    this.obstacles = [[1, 1], [2, 2], [3, 3]];
    this.actions = ['up', 'down', 'left', 'right']; // 0, 1, 2, 3
  }

  getState() {
    // Convert the position into a state representation
    return tf.tensor2d([this.agentPosition[0], this.agentPosition[1]], [1, 2]);
  }

  step(action) {
    const [x, y] = this.agentPosition;
    let newX = x, newY = y;
    let reward = -0.1; // small per-step penalty to encourage reaching the goal quickly
    let done = false;

    // Apply the action
    switch(action) {
      case 0: newY = Math.max(0, y - 1); break; // up
      case 1: newY = Math.min(this.gridSize - 1, y + 1); break; // down
      case 2: newX = Math.max(0, x - 1); break; // left
      case 3: newX = Math.min(this.gridSize - 1, x + 1); break; // right
    }

    // Check for a collision with an obstacle
    const isObstacle = this.obstacles.some(([ox, oy]) => ox === newX && oy === newY);
    
    if (!isObstacle) {
      this.agentPosition = [newX, newY];
    }

    // Check whether the goal was reached
    if (newX === this.goalPosition[0] && newY === this.goalPosition[1]) {
      reward = 10; // reward for reaching the goal
      done = true;
    }

    return {
      observation: this.getState(),
      reward,
      done,
      info: {}
    };
  }

  reset() {
    this.agentPosition = [0, 0];
    return this.getState();
  }
}
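
Before connecting this environment to an agent, it helps to sanity-check it by stepping through an episode with random actions. A minimal sketch using the GridWorld class above:

// Sanity check: run one episode with uniformly random actions.
const env = new GridWorld();
let obs = env.reset();

for (let t = 0; t < 50; t++) {
  const action = Math.floor(Math.random() * 4); // random up/down/left/right
  const { observation, reward, done } = env.step(action);
  obs.dispose(); // free the previous state tensor
  obs = observation;
  console.log(`step ${t}: action=${action}, reward=${reward}`);
  if (done) {
    console.log('Reached the goal!');
    break;
  }
}
obs.dispose();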

2. Creating and Training a DQN Agent

async function createAndTrainAgent() {
  // Create the environment
  const environment = new GridWorld();
  
  // Define the model architecture
  const model = tf.sequential({
    layers: [
      tf.layers.dense({ units: 24, inputShape: [2], activation: 'relu' }),
      tf.layers.dense({ units: 24, activation: 'relu' }),
      tf.layers.dense({ units: 4, activation: 'linear' }) // one output per action
    ]
  });

  // Compile the model
  model.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'meanSquaredError'
  });

  // Create the DQN agent
  const agent = await createDQNAgent({
    model: model,
    numActions: 4, // four actions: up, down, left, right
    memorySize: 1000,
    batchSize: 32,
    epsilon: 1.0, // initial exploration rate
    epsilonDecay: 0.995,
    epsilonMin: 0.01
  });

  return { agent, environment };
}
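
The epsilon parameters control the exploration/exploitation trade-off: the agent starts by acting randomly (epsilon: 1.0) and gradually shifts toward exploiting its learned Q-values. If you were implementing action selection yourself, epsilon-greedy looks roughly like this (a sketch of what the DQN agent is assumed to do internally):

// Epsilon-greedy: with probability epsilon take a random action (explore),
// otherwise take the action with the highest predicted Q-value (exploit).
function selectAction(model, state, epsilon, numActions) {
  if (Math.random() < epsilon) {
    return Math.floor(Math.random() * numActions);
  }
  return tf.tidy(() => {
    const qValues = model.predict(state); // shape [1, numActions]
    return qValues.argMax(1).dataSync()[0];
  });
}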

3. Training Loop

async function trainAgent(episodes = 100) {
  const { agent, environment } = await createAndTrainAgent();
  const rewards = [];

  for (let episode = 0; episode < episodes; episode++) {
    let state = environment.reset();
    let totalReward = 0;
    let steps = 0;
    let done = false;

    while (!done && steps < 100) { // cap episodes at 100 steps
      // Select an action
      const action = await agent.predict(state);
      
      // Take the action in the environment
      const { observation: nextState, reward, done: episodeDone } = environment.step(action);
      
      // Store the experience in replay memory
      await agent.remember(state, action, reward, nextState, episodeDone);
      
      // Train the agent every few steps
      if (steps % 4 === 0) {
        await agent.train();
      }

      state = nextState;
      totalReward += reward;
      steps++;
      done = episodeDone;
    }

    rewards.push(totalReward);
    
    // Log progress every 10 episodes
    if (episode % 10 === 0) {
      console.log(`Episode ${episode}, Total Reward: ${totalReward.toFixed(2)}, Steps: ${steps}`);
    }

    // Decay the exploration rate
    agent.updateEpsilon();
  }

  return { agent, rewards };
}
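
Raw episode rewards are noisy, so a moving average is a better signal of whether training is converging. A small helper in plain JavaScript:

// Smooth the reward curve with a simple moving average over `window` episodes.
function movingAverage(rewards, window = 10) {
  return rewards.map((_, i) => {
    const start = Math.max(0, i - window + 1);
    const slice = rewards.slice(start, i + 1);
    return slice.reduce((sum, r) => sum + r, 0) / slice.length;
  });
}

// Usage:
// const { rewards } = await trainAgent(200);
// console.log(movingAverage(rewards).slice(-10)); // last 10 smoothed values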

Advanced Examples

4. Training a Game AI

// A simple game environment: obstacle dodging
class ObstacleGame {
  constructor() {
    this.playerY = 5;
    this.obstacles = [{ x: 10, gap: 3 }];
    this.gameWidth = 15;
    this.gameHeight = 10;
    this.score = 0;
  }

  getState() {
    // State: the player's position plus info about the nearest obstacle,
    // with all features normalized to roughly [0, 1]
    const nearestObstacle = this.obstacles[0];
    return tf.tensor2d([
      this.playerY / this.gameHeight,
      nearestObstacle ? nearestObstacle.x / this.gameWidth : 1,
      nearestObstacle ? nearestObstacle.gap / this.gameHeight : 0.3
    ], [1, 3]);
  }

  step(action) {
    // Actions: 0 = up, 1 = down, 2 = stay
    if (action === 0) this.playerY = Math.max(0, this.playerY - 1);
    if (action === 1) this.playerY = Math.min(this.gameHeight - 1, this.playerY + 1);

    // Move obstacles toward the player
    this.obstacles.forEach(obs => obs.x -= 1);
    
    let reward = 0.1; // small reward for surviving
    let done = false;

    // Check for a collision when the obstacle reaches the player's column
    const nearestObstacle = this.obstacles[0];
    if (nearestObstacle && nearestObstacle.x === 0) {
      if (this.playerY < nearestObstacle.gap || this.playerY >= nearestObstacle.gap + 2) {
        reward = -10; // collision penalty
        done = true;
      } else {
        reward = 5; // reward for passing through the gap
        this.score++;
      }
      this.obstacles.shift();
    }

    // Spawn a new obstacle
    if (this.obstacles.length === 0 || this.obstacles[this.obstacles.length - 1].x < 5) {
      this.obstacles.push({
        x: this.gameWidth,
        gap: Math.floor(Math.random() * (this.gameHeight - 3))
      });
    }

    return {
      observation: this.getState(),
      reward,
      done,
      info: { score: this.score }
    };
  }

  reset() {
    this.playerY = 5;
    this.obstacles = [{ x: 10, gap: 3 }];
    this.score = 0;
    return this.getState();
  }
}
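
As with GridWorld, a random-policy baseline tells you whether the agent is actually learning anything. A quick sketch using the ObstacleGame class above:

// Average score of a random policy -- a trained agent should clearly beat it.
function randomBaseline(episodes = 20) {
  let totalScore = 0;
  for (let i = 0; i < episodes; i++) {
    const game = new ObstacleGame();
    let done = false;
    let steps = 0;
    while (!done && steps < 500) { // cap runaway episodes
      const result = game.step(Math.floor(Math.random() * 3));
      result.observation.dispose(); // the state tensor is not needed here
      done = result.done;
      steps++;
    }
    totalScore += game.score;
  }
  return totalScore / episodes;
}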

5. Visualizing the Training Process

function visualizeTraining(environment, agent, canvasId) {
  const canvas = document.getElementById(canvasId);
  const ctx = canvas.getContext('2d');
  const cellSize = 30;

  function draw() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    
    // Draw the grid
    for (let x = 0; x < environment.gridSize; x++) {
      for (let y = 0; y < environment.gridSize; y++) {
        ctx.strokeRect(x * cellSize, y * cellSize, cellSize, cellSize);
        
        // Draw obstacles
        if (environment.obstacles.some(([ox, oy]) => ox === x && oy === y)) {
          ctx.fillStyle = 'red';
          ctx.fillRect(x * cellSize, y * cellSize, cellSize, cellSize);
        }
        
        // Draw the goal
        if (x === environment.goalPosition[0] && y === environment.goalPosition[1]) {
          ctx.fillStyle = 'green';
          ctx.fillRect(x * cellSize, y * cellSize, cellSize, cellSize);
        }
      }
    }
    
    // Draw the agent
    const [agentX, agentY] = environment.agentPosition;
    ctx.fillStyle = 'blue';
    ctx.beginPath();
    ctx.arc(agentX * cellSize + cellSize/2, agentY * cellSize + cellSize/2, cellSize/3, 0, 2 * Math.PI);
    ctx.fill();
  }

  return { draw };
}
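
To watch a trained agent play, drive draw() from a loop that alternates agent steps and animation frames. A sketch that assumes the agent returned by trainAgent above:

// Play one episode visually, advancing one agent step per animation frame.
async function playEpisode(agent, environment, canvasId) {
  const { draw } = visualizeTraining(environment, agent, canvasId);
  let state = environment.reset();
  let done = false;

  while (!done) {
    const action = await agent.predict(state);
    const result = environment.step(action);
    state.dispose();
    state = result.observation;
    done = result.done;
    draw();
    // Wait for the next animation frame so the movement is visible
    await new Promise(resolve => requestAnimationFrame(resolve));
  }
  state.dispose();
}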

Practical Application Scenarios

  1. Game AI: train intelligent NPCs for web games
  2. Recommendation systems: optimize content recommendations based on user behavior
  3. Robot control: simulate robot navigation and decision-making
  4. Resource management: optimize page resource-loading strategies
  5. Adaptive UI: adjust the interface to user habits

Best Practices and Caveats

// 1. Memory management
async function trainWithMemoryManagement() {
  const { agent, environment } = await createAndTrainAgent();
  
  for (let episode = 0; episode < 100; episode++) {
    let state = environment.reset();
    let done = false;
    
    while (!done) {
      // agent.predict is async, so it cannot run inside tf.tidy
      const action = await agent.predict(state);
      
      // Use tf.tidy to automatically dispose of intermediate tensors;
      // only tensors returned from the callback survive
      const stepResult = tf.tidy(() => environment.step(action));
      
      // Manually dispose of tensors that are no longer needed
      state.dispose();
      state = stepResult.observation;
      done = stepResult.done;
    }
    
    state.dispose(); // dispose the final state tensor
  }
}
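
You can confirm that nothing is leaking by checking tf.memory() between episodes; numTensors should stay roughly constant when disposal is done correctly (tf.memory is part of the core API):

// Log tensor counts periodically; steadily growing numTensors means a leak.
function logMemory(episode) {
  const { numTensors, numBytes } = tf.memory();
  console.log(`episode ${episode}: ${numTensors} tensors, ${(numBytes / 1024).toFixed(1)} KiB`);
}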

// 2. Save the model periodically
async function saveModelPeriodically(agent, interval = 100) {
  for (let episode = 0; episode < 1000; episode++) {
    // ... training code ...
    
    if (episode % interval === 0) {
      await agent.model.save('indexeddb://my-trained-agent');
      console.log(`Model saved at episode ${episode}`);
    }
  }
}
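
A saved model can later be restored with tf.loadLayersModel, which understands the same indexeddb:// scheme (the model loading is standard TensorFlow.js; wrapping it back into an agent is assumed to mirror the createDQNAgent call shown earlier):

// Restore the trained network from IndexedDB and rebuild the agent around it.
async function loadTrainedAgent() {
  const model = await tf.loadLayersModel('indexeddb://my-trained-agent');
  model.compile({ optimizer: tf.train.adam(0.001), loss: 'meanSquaredError' });

  // Low epsilon: the restored agent should mostly exploit what it learned.
  return createDQNAgent({ model, numActions: 4, epsilon: 0.05 });
}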

TensorFlow.js Agents provides powerful tools for building sophisticated decision-making systems in the browser. Training may be slower than on a server, but it offers unique advantages: it runs entirely on the client, preserves user privacy, and supports real-time interaction.
