上QQ阅读APP看书,第一时间看更新
7.3 PTAN版本的CartPole解决方案
现在我们来使用PTAN中的类(暂时不使用Ignite)并尝试将所有内容组合在一起,解决我们遇到的第一个环境:CartPole。完整的代码在Chapter07/06_cartpole.py
。此处仅展示与刚刚介绍的内容相关的重要代码。
在开始时,创建了NN(之前在CartPole中使用过的简单的两层前馈NN)、目标NN、ε-greedy动作选择器以及DQNAgent
。然后又创建了经验源和回放缓冲区。仅这几行代码,就完成了数据管道。接下来只需要调用缓冲区的populate()
方法来采样一些训练批。
在每个训练循环开始时,我们都要求缓冲区从经验源获取一个样本并检查片段是否结束。ExperienceSource
类的pop_rewards_steps()
方法返回一个元组列表,包含了自上一次调用该方法后的所有已结束的片段信息。
在训练循环的后半部分,我们将一批ExperienceFirstLast
对象转换成了适合DQN训练的张量、计算了损失并且执行了反向传播。最后,衰减动作选择器的epsilon
值(根据所使用的超参数,epsilon
会在训练的第500步衰减至0),并让目标网络每10次训练迭代进行一次同步。
代码执行后,应该在1000~2000个训练迭代后收敛。