DQN
DQN
一、解读
背景:强化学习 (Reinforcement Learning, RL) 与挑战
- 基本设定: RL的基本设定,无需赘述
- 挑战:
- RL在高维输入情况下的不足: [source: 5] 直接从像视觉这样的高维感官输入学习控制策略是 RL 的长期挑战。传统方法通常依赖手工设计的特征的好坏
- 于是我们引入深度学习的神经网络,利用它特征提取的能力来改善RL在高位输入情况下的不足,然而,在RL与DL结合的过程中有如下挑战:
- 稀疏/延迟奖励: [source: 13, 14] 与监督学习不同,RL 的奖励信号可能是稀疏的(不常出现)、有噪声的,并且可能是延迟的(执行一个动作后很久才看到奖励结果)。
- 数据相关性: [source: 15] RL 中智能体经历的状态序列通常是高度相关的,这违反了许多深度学习算法(如 SGD)样本独立性的假设。
- 非平稳分布: [source: 16] 随着智能体学习新的行为,它遇到的数据分布会发生变化,而深度学习方法通常假设数据来自一个固定的分布。
核心方法:深度 Q 网络 (Deep Q-Network, DQN)
为了克服上述挑战,作者提出了 DQN 算法,主要包含以下几个关键部分:
- Q-Learning 基础:
- 动作价值函数 (Q-function): [source: 43] $Q^*(s, a)$ 表示在状态(或状态序列) s 下,执行动作 a,然后遵循最优策略所能获得的最大期望未来折扣奖励。
- 贝尔曼方程 (Bellman Equation): [source: 44, 46] 这是 Q-learning 的核心,它表明最优 Q 值可以通过当前奖励 r 和下一状态 s’ 的最大可能 Q 值来递归定义:
- 函数近似: [source: 48, 49, 50] 由于状态空间(尤其是基于图像序列的状态)极其巨大,无法为每个状态动作对存储 Q 值。因此,使用一个函数(这里是神经网络,称为 Q-network)来近似 Q^*(s,a),记为 Q(s,a;$\theta$),其中 $\theta$ 是网络参数。
- 训练: [source: 51, 52, 55] 通过最小化损失函数(通常是预测 Q 值与目标 Q 值之间的均方误差)来训练网络。目标 Q 值 $y_i = r + \gamma \max_{a’} Q(s’, a’; \theta_{i-1})$,其中 $\theta_{i-1}$ 是前一次迭代的网络参数(目标网络)。使用随机梯度下降 (SGD) 更新参数 $\theta_i$。
- 关键创新点:
- 卷积神经网络 (CNN): 作者没有使用手工特征,而是直接将预处理后的游戏屏幕像素(一个包含最近几帧信息的图像堆栈)作为输入,利用 CNN 强大的特征提取能力来自动学习与游戏状态和价值相关的表示。
- 相比于Q- learning:
- 其实DQN就是在Qlearning的基础上,把计算 ”$max_{a’}$” 的方法从使用Qtable,变成使用深度神经网络这个函数来计算$max_{a’}$。
- 我们知道,解决连续型问题,如果表格不能表示,就用函数,而最好的函数就是深度神经网络。
- 输入预处理: 原始 Atari 帧先转换为灰度图,缩放到 $110 \times 84$,再裁剪出 $84 \times 84$ 的区域。将最近的 4 帧处理后堆叠起来,形成 $84 \times 84 \times 4$ 的输入,以捕捉运动信息。
- 网络架构: 输入层后接两个卷积层(带 ReLU激活函数)和一个全连接层(带 ReLU),最后是一个全连接的线性输出层。 [source: 117, 118] 这个输出层为每个可能的动作都提供一个输出单元,直接预测该状态下每个动作的 Q 值。这种架构的优点是只需一次前向传播即可获得所有动作的 Q 值。 [source: 119]
- 相比于Q- learning:
- 经验回放 (Experience Replay): 解决数据相关性和非平稳分布问题的关键。
- 机制: 将智能体在交互过程中产生的经历(状态 s_t、动作 a_t、奖励 r_t、下一状态 $s_{t+1}$,表示为 $e_t = (s_t, a_t, r_t, s_{t+1}))$存储在一个固定大小的回放记忆库 $\mathcal{D}$ 中。
- 训练: 在训练时,不是使用当前的经历,而是从记忆库中随机采样一个小批量(minibatch)的经历来进行 Q-learning 更新。
- 优点:
- [source: 96] 提高数据效率: 每个经历样本可能被用于多次参数更新。
- [source: 99] 打破相关性: 随机采样打破了连续样本间的强相关性,降低了更新的方差。
- [source: 104] 稳定学习过程: 通过从大量过去的经历中采样,平滑了训练数据的分布,避免了在线学习中可能出现的参数震荡或发散问题。 [source: 103]
- 目标网络 (Target Network - 隐含在公式中): [source: 51, 52] 在计算目标 Q 值 $y_i$ 时,使用的是一个较旧的网络参数副本 $\theta_{i-1}$(或者在后来的 DQN 改进版中,是一个定期同步参数的目标网络 $\theta^-)$,而不是当前正在更新的参数 $\theta_i$。这有助于进一步稳定训练,防止目标值和预测值同时快速变化导致的不稳定。
- 卷积神经网络 (CNN): 作者没有使用手工特征,而是直接将预处理后的游戏屏幕像素(一个包含最近几帧信息的图像堆栈)作为输入,利用 CNN 强大的特征提取能力来自动学习与游戏状态和价值相关的表示。
实验设置与结果
- 环境与任务: [source: 3, 129] 使用 Arcade Learning Environment (ALE) 中的 7 款 Atari 2600 游戏进行测试(如 Pong, Breakout, Space Invaders 等)。
- 通用性: [source: 130] 关键在于,作者对所有 7 款游戏使用了完全相同的网络架构、学习算法和超参数设置,证明了方法的通用性。
- 训练细节:
- 奖励裁剪: [source: 131, 132, 133] 为了使不同游戏(得分范围差异很大)的训练更稳定,训练时将正奖励裁剪为 +1,负奖励裁剪为 -1,0 奖励不变。
- 优化器与探索: [source: 135] 使用 RMSProp 优化器,小批量大小为 32。采用 ε-greedy 探索策略,ε 从 1 线性退火到 0.1(在前 100 万帧),之后固定为 0.1。
- 训练时长与记忆库: [source: 136] 总共训练 1000 万帧,经验回放库大小为 100 万帧。
- 跳帧 (Frame Skipping): [source: 137, 138] 智能体每 k 帧观察一次并选择动作,中间跳过的帧重复执行上一个动作(大多数游戏 k=4)。这可以加快训练速度。 [source: 139]
- 评估与结果:
- 稳定性: [source: 145, 148, 150, 151] 实验表明,虽然单局平均得分(评估指标)波动较大,但预测的平均 Q 值在训练过程中稳定上升,显示了学习过程的稳定性,且未出现发散。
- 可视化: [source: 160, 161, 162, 163] 通过可视化 Seaquest 游戏中的 Q 值变化,展示了网络确实学到了与游戏事件(如敌人出现、发射鱼雷、击中敌人)相关的价值信息。
- 性能比较 (Table 1): [source: 177, 180, 181]
- DQN 在所有 7 款游戏上的平均得分显著优于之前的 RL 方法(如 Sarsa,使用了手工特征)。 [source: 166]
- [source: 189] 在 6 款游戏中,DQN 的平均得分和最高得分均优于基于进化算法的 HNeat 方法(该方法通常需要针对特定游戏设计或利用游戏内部信息)。 [source: 183, 184]
- [source: 4, 190] 在 Breakout、Enduro 和 Pong 这 3 款游戏中,DQN 的表现超过了人类专家玩家。
- [source: 191] 在需要更长远规划的游戏(如 Q*bert, Seaquest)上,DQN 与人类水平尚有差距。
结论与意义
[source: 192] 这篇论文成功地将深度学习(特别是 CNN)与强化学习(特别是 Q-learning)结合起来,创建了一个能够直接从原始像素输入学习复杂控制策略的模型(DQN)。
[source: 193] 结合经验回放和目标网络(虽然论文中未显式称为目标网络,但计算目标 Q 值时使用旧参数起到了类似作用)的技术,有效解决了在高维输入下训练 RL 智能体的稳定性和效率问题。
[source: 194] DQN 在多款 Atari 游戏上取得了当时最先进的性能,并且具有很好的通用性,无需针对特定游戏调整架构或超参数。这项工作开创了深度强化学习这一研究方向,为后续更复杂的决策制定任务(如机器人控制、更复杂的游戏等)打下了坚实的基础。
总结
这篇论文的核心贡献在于提出了 DQN 算法,它巧妙地利用 CNN 处理高维视觉输入,并通过经验回放机制解决了 RL 与深度学习结合时面临的数据相关性和非平稳性挑战,最终在一个具有挑战性的基准(Atari 游戏)上取得了突破性的成果,证明了端到端学习控制策略的可行性和巨大潜力。
二、疑问
1)关于经验回放
问题1:经验回放不会导致过拟合吗?因为它重复使用相同的经验
就像数字识别,从10张图片里挑选,不管怎么打乱顺序,得到的训练成果都是不变的吧?
答:经验回放通常不导致过拟合,反而是有益于减少过拟合的:
强化学习与静态数据集的监督学习有本质的不同:
1. 静态数据集 vs. 动态环境交互:
- 数字识别 (静态数据集): 训练数据 (图片和标签) 的每个数据点 (图片) 是独立的,与前后的数据点没有直接的因果关系。打乱顺序只是改变了训练算法看到数据的顺序,但数据本身和数据之间的关系没有改变。
- DQN (动态环境交互): 训练数据 (经验) 是智能体与环境交互 动态生成 的。经验是一个 序列,包含了状态 (s)、动作 (a)、奖励 (r)、和下一个状态 (s’)。 经验之间存在 时序相关性 和 因果关系。 智能体在某个状态采取动作会影响环境的下一个状态和奖励,进而影响后续的经验序列。 顺序在这里至关重要,因为它反映了智能体在环境中探索和学习的 轨迹 (trajectory)。
问题2:随机采样打破了经验序列的顺序,会不会导致使用的经验都是无逻辑的而产生误导?
- 关键:马尔可夫性可以解释这个疑惑
- 类比:
- 原始序列学习 (不回放): 就像你只按顺序阅读一篇文章,每次只关注当前句子,而忽略之前读过的句子。 你可能会记住一些句子,但可能难以形成对文章整体内容的理解。
- 经验回放学习: 随机抽取出一些句子进行学习。 虽然句子之间的顺序被打乱了,但你仍然可以学习每个句子的含义、语法结构和词汇。 通过学习大量的随机句子,你最终可以提高你的语言能力,即使你没有记住原始文章的顺序。
- 补充:这样打乱顺序可以打破相关性,提高泛化性: (经验回放最重要的优势)
- 正如之前解释的,连续的经验是高度相关的。 如果直接使用这些相关经验进行训练,网络容易 过拟合到当前轨迹的特征,而无法泛化到更广泛的状态空间。
- 随机采样 batch 可以 打破这种相关性,使得网络能够从更分散、更独立的经验中学习,提高泛化能力。 它迫使网络学习更鲁棒的特征,而不是仅仅记住特定轨迹的模式。
This post is licensed under CC BY 4.0 by the author.