-
Notifications
You must be signed in to change notification settings - Fork 50
Go‐MPC
完整版请点击传送门
- 学习 GO-MPC 论文
- 根据 github 源码分析实现的原理及结构
- train.py
-
model = ALGOS[args.algo]()
创建模型; - 其中
ALGOS[args.algo]
代表算法名; - 如果继续训练通过
model = ALGOS[args.algo].load()
进行加载load()
方法 通过继承ActorCriticRLModel
实现; - 通过
model.learn(n_timesteps, **kwargs)
来训练模型;
-
- ppo2mpc.py
-
__init__()
中进行初始化 并通过self.setup_model()
创建模型; -
learn()
方法进行训练;
-
- 定义命令行参数并解析
- 在此程序中 通过对命令行输入的解析 来确定调用的算法库
- 从 yaml 文件加载超参数
- 根据所选算法创建 RL 算法对象
- 设置算法需要的 learning rate 和 schedules 这个代码提供了一个使用Stable Baselines库训练强化学习代理的脚本 代码导入了各种模块并定义了用于配置训练过程的命令行参数
以下是代码的主要组成部分的解析:
-
导入必要的模块并设置环境:
-
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
禁用了 GPU 的使用 - 导入了各种模块,包括 Stable Baselines 和其他自定义模块
- 配置设置和环境的设置
-
-
命令行参数解析:
- 脚本使用
argparse.ArgumentParser
来定义和解析命令行参数 - 可以提供参数来指定强化学习算法、环境、日志设置、训练代理文件、超参数等
- 脚本使用
-
超参数设置和自定义:
- 脚本根据选择的强化学习算法和环境从一个 YAML 文件中加载超参数
- 超参数可以通过命令行参数进行覆盖
-
训练设置和配置:
- 脚本根据选择的强化学习算法和环境设置训练过程
- 配置了时间步数、评估设置、日志目录、保存检查点和其他训练参数
-
环境设置和初始化:
- 脚本根据选择的强化学习算法和环境 ID 创建训练和评估环境
- 还设置了环境的包装器或预处理器
-
回调函数和日志记录:
- 脚本定义了各种回调函数 如保存检查点和评估 在训练过程中执行
- 设置了 TensorBoard 日志记录并保存了超参数和训练模型
-
执行训练过程:
- 在所有配置和设置步骤之后 脚本根据指定的时间步数开始训练循环
- 训练进度根据配置的设置定期记录和保存
这个YAML文件中包含了不同环境的强化学习算法的超参数配置 每个环境都有其对应的配置 以下是每个环境的超参数含义:
-
gym-collision-avoidance
: 使用 MlpLstmPolicy 算法 进行碰撞避免训练 -
atari
: 使用 CnnPolicy 算法 在 Atari 游戏上进行训练 -
Pendulum-v0
: 使用 MlpPolicy 算法 对 Pendulum-v0 环境进行训练 -
CartPole-v1
: 使用 MlpPolicy 算法 对 CartPole-v1 环境进行训练 -
CartPoleBulletEnv-v1
: 使用 MlpPolicy 算法 对 CartPoleBulletEnv-v1 环境进行训练 -
CartPoleContinuousBulletEnv-v0
: 使用 MlpPolicy 算法 对 CartPoleContinuousBulletEnv-v0 环境进行训练 -
MountainCar-v0
: 使用 MlpPolicy 算法 对 MountainCar-v0 环境进行训练 -
MountainCarContinuous-v0
: 使用 MlpPolicy 算法 对 MountainCarContinuous-v0 环境进行训练 -
Acrobot-v1
: 使用 MlpPolicy 算法 对 Acrobot-v1 环境进行训练 -
BipedalWalker-v3
: 使用 MlpPolicy 算法 对 BipedalWalker-v3 环境进行训练 -
BipedalWalkerHardcore-v3
: 使用 MlpPolicy 算法 对 BipedalWalkerHardcore-v3 环境进行训练 -
LunarLander-v2
: 使用 MlpPolicy 算法 对 LunarLander-v2 环境进行训练 -
LunarLanderContinuous-v2
: 使用 MlpPolicy 算法 对 LunarLanderContinuous-v2 环境进行训练 -
Walker2DBulletEnv-v0
: 使用 MlpPolicy 算法 对 Walker2DBulletEnv-v0 环境进行训练 -
HalfCheetahBulletEnv-v0
: 使用 MlpPolicy 算法 对 HalfCheetahBulletEnv-v0 环境进行训练 -
HalfCheetah-v2
: 使用 MlpPolicy 算法 对 HalfCheetah-v2 环境进行训练 -
AntBulletEnv-v0
: 使用 CustomMlpPolicy 算法 对 AntBulletEnv-v0 环境进行训练 -
HopperBulletEnv-v0
: 使用 MlpPolicy 算法 对 HopperBulletEnv-v0 环境进行训练 -
ReacherBulletEnv-v0
: 使用 MlpPolicy 算法 对 ReacherBulletEnv-v0 环境进行训练 -
MinitaurBulletEnv-v0
: 使用 MlpPolicy 算法 对 MinitaurBulletEnv-v0 环境进行训练 -
MinitaurBulletDuckEnv-v0
: 使用 MlpPolicy 算法 对 MinitaurBulletDuckEnv-v0 环境进行训练 -
HumanoidBulletEnv-v0
: 使用 MlpPolicy 算法 对 HumanoidBulletEnv-v0 环境进行训练 -
InvertedDoublePendulumBulletEnv-v0
: 使用 MlpPolicy 算法 对 InvertedDoublePendulumBulletEnv-v0 环境进行训练 -
InvertedPendulumSwingupBulletEnv-v0
: 使用 MlpPolicy 算法 对 InvertedPendulumSwingupBulletEnv-v0 环境进行训练 -
MiniGrid-DoorKey-5x5-v0
: 使用 MlpPolicy 算法 对 MiniGrid-DoorKey-5x5-v0 环境进行训练 -
MiniGrid-FourRooms-v0
: 使用 MlpPolicy 算法 对 MiniGrid-FourRooms-v0 环境进行训练
每个环境都指定了超参数 例如训练步数(n_timesteps)、策略(policy)、学习率(learning_rate)等 这些超参数将用于训练强化学习模型
- 创建了一个 PPO2MPC 类 函数和类构成如下
-
代码定义了一个名为
PPO2MPC
的类 它是ActorCriticRLModel
的子类(Stable Baselines 中用于 Actor-Critic 模型的抽象基类)PPO2MPC
类实现了 PPO 算法 并使用MPCRunner
类提供对模型预测控制(MPC)的额外支持 -
以下是代码的一些关键组成部分:
-
__init__
方法定义了PPO2MPC
类的构造函数 并初始化了各种参数 如学习率、剪切范围、熵系数等 它还设置了用于训练模型的 TensorFlow 图和会话 -
setup_model
方法构建了 PPO 模型的计算图、创建了用于输入数据(观测、动作、优势等)的 TensorFlow 占位符、定义了用于训练模型的损失函数和优化器、创建一个TensorFlow计算图、建立训练操作和监督训练操作、根据策略是否为递归策略,设置每个步骤的批次大小和训练的批次大小 -
Runner
和MPCRunner
类用于通过与环境交互收集经验 它们通过在环境中执行当前策略生成轨迹 并将收集的数据(观测、动作、奖励等)存储用于训练 -
training_step
方法执行 PPO 算法的单个更新步骤 它接收 Runner 收集的数据批次 计算替代损失 并执行梯度下降更新模型的参数 -
learn
方法是PPO算法的主要训练循环 它重复调用training_step
方法以使用收集的数据执行多个更新步骤
-
- 主要思想是通过限制每次更新的步长 使得策略的更新不会对旧策略产生影响 这样可以保证在更新策略网络时 新的策略和旧的策之间的差距不会太大
- 在训练过程中使用了一个价值网络来评估每个状态的价值函数 从而更准确地计算出回报函数
- 还采用了重要性采样的技术 使得算法可以利用以前的经验来更新策略网络
- PPO 算法是 Actor-Critic 算法的变体 Actor 和 Critic 都通过神经网络来实现(ActorCritic类) (而在本类中 通过 ActorCriticRLModel 来实现) Actor 网络用于学习和输出动作的概率 Critic 网络用来评估状态值(也就是评估这个动作的好坏)
- 算法原理参考
- 在此项目中的具体原理实现
- 目标函数
- 损失函数 ->
- 奖励机制
- state
- action
- 标签
- 数据