TensorFlow.js 中最前沿和强大的功能之一——Agents
TensorFlow.js Agents 是一个库,用于在浏览器中实现强化学习。它允许你创建能够通过与环境交互来学习最优行为的智能体(Agents)。
·
这个功能将强化学习(Reinforcement Learning)带到了浏览器环境中。
什么是 TensorFlow.js Agents?
TensorFlow.js Agents 是一个库,用于在浏览器中实现强化学习。它允许你创建能够通过与环境交互来学习最优行为的智能体(Agents)。
核心概念
- Agent(智能体):学习的实体,做出决策
- Environment(环境):智能体交互的世界
- Action(动作):智能体采取的行为
- Reward(奖励):环境对动作的反馈
- Policy(策略):智能体决定动作的规则
安装和基础设置
# 安装核心库(代码示例中均从 @tensorflow/tfjs 导入,应安装完整包而非仅 tfjs-core)
npm install @tensorflow/tfjs
# 安装 Agents 库
npm install @tensorflow/tfjs-agents
或者使用 CDN:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-agents"></script>
基础用法示例
1. 创建一个简单的强化学习环境
import * as tf from '@tensorflow/tfjs';
import { createDQNAgent, getCartPole } from '@tensorflow/tfjs-agents';
// 创建一个简单的网格世界环境
// A minimal grid-world environment for reinforcement-learning demos.
// The agent starts at `start`, must reach `goal`, and is blocked by `obstacles`.
// Actions are integer-encoded: 0 = up, 1 = down, 2 = left, 3 = right.
class GridWorld {
  /**
   * @param {object} [config]
   * @param {number} [config.gridSize=5] - Width/height of the square grid.
   * @param {number[]} [config.start=[0,0]] - Initial agent cell [x, y].
   * @param {number[]} [config.goal=[4,4]] - Target cell [x, y].
   * @param {number[][]} [config.obstacles] - Impassable cells as [x, y] pairs.
   */
  constructor({
    gridSize = 5,
    start = [0, 0],
    goal = [4, 4],
    obstacles = [[1, 1], [2, 2], [3, 3]]
  } = {}) {
    this.gridSize = gridSize;
    this.startPosition = [...start];
    this.agentPosition = [...start];
    this.goalPosition = goal;
    this.obstacles = obstacles;
    this.actions = ['up', 'down', 'left', 'right']; // indices 0, 1, 2, 3
  }

  /**
   * Encode the current agent position as a [1, 2] tensor.
   * @returns {tf.Tensor2D} The observation tensor.
   */
  getState() {
    return tf.tensor2d([this.agentPosition[0], this.agentPosition[1]], [1, 2]);
  }

  /**
   * Apply one action and advance the environment by one step.
   * @param {number} action - 0 = up, 1 = down, 2 = left, 3 = right.
   * @returns {{observation: tf.Tensor2D, reward: number, done: boolean, info: object}}
   */
  step(action) {
    const [x, y] = this.agentPosition;
    let newX = x;
    let newY = y;
    let reward = -0.1; // small per-step penalty encourages reaching the goal quickly
    let done = false;

    // Resolve the attempted move, clamped to the grid bounds.
    switch (action) {
      case 0: newY = Math.max(0, y - 1); break; // up
      case 1: newY = Math.min(this.gridSize - 1, y + 1); break; // down
      case 2: newX = Math.max(0, x - 1); break; // left
      case 3: newX = Math.min(this.gridSize - 1, x + 1); break; // right
    }

    // Moves into obstacle cells are ignored (the agent stays put).
    const isObstacle = this.obstacles.some(([ox, oy]) => ox === newX && oy === newY);
    if (!isObstacle) {
      this.agentPosition = [newX, newY];
    }

    // Bug fix: check the agent's ACTUAL position, not the attempted one, so a
    // blocked move can never be mistaken for reaching the goal.
    if (this.agentPosition[0] === this.goalPosition[0] &&
        this.agentPosition[1] === this.goalPosition[1]) {
      reward = 10; // terminal reward for reaching the goal
      done = true;
    }

    return {
      observation: this.getState(),
      reward,
      done,
      info: {}
    };
  }

  /** Reset the agent to the start cell and return the initial observation. */
  reset() {
    this.agentPosition = [...this.startPosition];
    return this.getState();
  }
}
2. 创建和训练 DQN Agent
/**
 * Build the grid-world environment and a DQN agent wired to a small
 * fully-connected Q-network (2 inputs -> 24 -> 24 -> 4 Q-values).
 * @returns {Promise<{agent: object, environment: GridWorld}>}
 */
async function createAndTrainAgent() {
  const environment = new GridWorld();

  // Q-network: maps an (x, y) state to one Q-value per action.
  const qNetwork = tf.sequential({
    layers: [
      tf.layers.dense({ units: 24, inputShape: [2], activation: 'relu' }),
      tf.layers.dense({ units: 24, activation: 'relu' }),
      tf.layers.dense({ units: 4, activation: 'linear' }) // one output per action
    ]
  });
  qNetwork.compile({
    optimizer: tf.train.adam(0.001),
    loss: 'meanSquaredError'
  });

  // Epsilon-greedy DQN with a small experience-replay buffer.
  const agent = await createDQNAgent({
    model: qNetwork,
    numActions: 4, // up / down / left / right
    memorySize: 1000,
    batchSize: 32,
    epsilon: 1.0, // start fully exploratory
    epsilonDecay: 0.995,
    epsilonMin: 0.01
  });

  return { agent, environment };
}
3. 训练循环
/**
 * Run the DQN training loop.
 * Each episode is capped at 100 steps; the network trains on a replay
 * minibatch every 4th step and the exploration rate decays once per episode.
 * @param {number} [episodes=100] - Number of episodes to run.
 * @returns {Promise<{agent: object, rewards: number[]}>} The agent plus per-episode returns.
 */
async function trainAgent(episodes = 100) {
  const { agent, environment } = await createAndTrainAgent();
  const rewards = [];

  for (let ep = 0; ep < episodes; ep++) {
    let observation = environment.reset();
    let episodeReturn = 0;
    let stepCount = 0;

    // Cap each episode at 100 steps so a lost agent cannot stall training.
    for (;;) {
      const action = await agent.predict(observation);
      const transition = environment.step(action);

      // Store the transition, then train every 4th step.
      await agent.remember(observation, action, transition.reward, transition.observation, transition.done);
      if (stepCount % 4 === 0) {
        await agent.train();
      }

      observation = transition.observation;
      episodeReturn += transition.reward;
      stepCount += 1;
      if (transition.done || stepCount >= 100) {
        break;
      }
    }

    rewards.push(episodeReturn);
    // Report progress every 10 episodes.
    if (ep % 10 === 0) {
      console.log(`Episode ${ep}, Total Reward: ${episodeReturn.toFixed(2)}, Steps: ${stepCount}`);
    }
    agent.updateEpsilon(); // decay exploration once per episode
  }

  return { agent, rewards };
}
高级应用示例
4. 游戏 AI 训练
// 简单的游戏环境 - 避障游戏
// Side-scrolling dodge game: the player moves vertically while obstacle
// columns scroll left; each column has a 2-cell gap the player must line up with.
class ObstacleGame {
  constructor() {
    this.playerY = 5;
    this.obstacles = [{ x: 10, gap: 3 }];
    this.gameWidth = 15;
    this.gameHeight = 10;
    this.score = 0;
  }

  /**
   * Observation: [playerY, nearest obstacle x (normalized), gap position (normalized)].
   * Falls back to (1, 0.3) when no obstacle exists.
   * @returns {tf.Tensor2D} A [1, 3] state tensor.
   */
  getState() {
    const next = this.obstacles[0];
    const obstacleX = next ? next.x / this.gameWidth : 1;
    const gapY = next ? next.gap / this.gameHeight : 0.3;
    return tf.tensor2d([this.playerY, obstacleX, gapY], [1, 3]);
  }

  /**
   * Advance the game by one tick.
   * @param {number} action - 0 = move up, 1 = move down, anything else = stay.
   * @returns {{observation: tf.Tensor2D, reward: number, done: boolean, info: {score: number}}}
   */
  step(action) {
    if (action === 0) {
      this.playerY = Math.max(0, this.playerY - 1);
    } else if (action === 1) {
      this.playerY = Math.min(this.gameHeight - 1, this.playerY + 1);
    }

    // Scroll every obstacle one cell to the left.
    for (const obs of this.obstacles) {
      obs.x -= 1;
    }

    let reward = 0.1; // small reward for surviving the tick
    let done = false;

    // When the nearest obstacle reaches the player column, resolve it.
    const next = this.obstacles[0];
    if (next && next.x === 0) {
      const insideGap = this.playerY >= next.gap && this.playerY < next.gap + 2;
      if (insideGap) {
        reward = 5; // slipped through the gap
        this.score += 1;
      } else {
        reward = -10; // crashed into the obstacle
        done = true;
      }
      this.obstacles.shift();
    }

    // Keep a fresh obstacle queued once the last one is close enough.
    const last = this.obstacles[this.obstacles.length - 1];
    if (!last || last.x < 5) {
      this.obstacles.push({
        x: this.gameWidth,
        gap: Math.floor(Math.random() * (this.gameHeight - 3))
      });
    }

    return {
      observation: this.getState(),
      reward,
      done,
      info: { score: this.score }
    };
  }

  /** Restore the initial player position, obstacle, and score. */
  reset() {
    this.playerY = 5;
    this.obstacles = [{ x: 10, gap: 3 }];
    this.score = 0;
    return this.getState();
  }
}
5. 可视化训练过程
/**
 * Create a canvas renderer for a GridWorld-style environment.
 * Obstacles are drawn red, the goal green, and the agent as a blue disc.
 * @param {object} environment - Must expose gridSize, obstacles, goalPosition, agentPosition.
 * @param {object} agent - Currently unused; kept for API symmetry with the training helpers.
 * @param {string} canvasId - DOM id of the target <canvas> element.
 * @returns {{draw: function(): void}} A draw() function that repaints the scene.
 */
function visualizeTraining(environment, agent, canvasId) {
  const canvas = document.getElementById(canvasId);
  const ctx = canvas.getContext('2d');
  const cellSize = 30;

  const isObstacleCell = (x, y) =>
    environment.obstacles.some(([ox, oy]) => ox === x && oy === y);

  function draw() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);

    // Grid cells, with obstacle/goal highlighting.
    for (let col = 0; col < environment.gridSize; col++) {
      for (let row = 0; row < environment.gridSize; row++) {
        ctx.strokeRect(col * cellSize, row * cellSize, cellSize, cellSize);
        if (isObstacleCell(col, row)) {
          ctx.fillStyle = 'red';
          ctx.fillRect(col * cellSize, row * cellSize, cellSize, cellSize);
        }
        if (col === environment.goalPosition[0] && row === environment.goalPosition[1]) {
          ctx.fillStyle = 'green';
          ctx.fillRect(col * cellSize, row * cellSize, cellSize, cellSize);
        }
      }
    }

    // Agent as a filled circle centered in its cell.
    const [agentX, agentY] = environment.agentPosition;
    ctx.fillStyle = 'blue';
    ctx.beginPath();
    ctx.arc(agentX * cellSize + cellSize / 2, agentY * cellSize + cellSize / 2, cellSize / 3, 0, 2 * Math.PI);
    ctx.fill();
  }

  return { draw };
}
实际应用场景
- 游戏 AI:训练网页游戏中的智能 NPC
- 推荐系统:根据用户行为优化内容推荐
- 机器人控制:模拟机器人导航和决策
- 资源管理:优化网页资源加载策略
- 自适应 UI:根据用户习惯调整界面
最佳实践和注意事项
// 1. 内存管理
// 1. Memory management: dispose every observation tensor so long training
// runs do not leak GPU/CPU memory.
async function trainWithMemoryManagement() {
  const { agent, environment } = await createAndTrainAgent();

  for (let episode = 0; episode < 100; episode++) {
    let state = environment.reset();
    let done = false;

    while (!done) {
      // tf.tidy reclaims intermediate tensors created while picking the action;
      // tensors inside the returned object survive the tidy scope.
      const result = tf.tidy(() => {
        const action = agent.predict(state);
        const stepResult = environment.step(action);
        return { action, stepResult };
      });

      // The previous observation is no longer needed - release it explicitly.
      state.dispose();
      state = result.stepResult.observation;
      done = result.stepResult.done;
    }

    // Bug fix: the terminal observation used to leak one tensor per episode;
    // dispose it before starting the next episode.
    state.dispose();
  }
}
// 2. 定期保存模型
async function saveModelPeriodically(agent, interval = 100) {
for (let episode = 0; episode < 1000; episode++) {
// ... 训练代码 ...
if (episode % interval === 0) {
await agent.model.save('indexeddb://my-trained-agent');
console.log(`Model saved at episode ${episode}`);
}
}
}
TensorFlow.js Agents 为在浏览器中实现复杂的决策智能系统提供了强大的工具。虽然训练过程可能比服务器端慢,但它提供了独特的优势:完全在客户端运行、保护用户隐私、实时交互等。
更多推荐
所有评论(0)