
Commit

Optimize PPO
Optimize the PPO algorithm to make it more stable.
yangtao121 committed Feb 24, 2023
1 parent a70eb80 commit c285182
Showing 2 changed files with 84 additions and 40 deletions.
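The change in PPO.py restructures the update loop: rather than sweeping the whole buffer once for the actor and then separately for the critic, each mini-batch now updates the critic first and then the actor, and the per-batch diagnostics are averaged into a single info dict. Below is a minimal standalone sketch of that loop shape; the train_critic_on / train_actor_on callables and get-batch-by-slice scheme are illustrative assumptions, not AquaML's actual API.

from typing import Callable, Dict, List

def optimize(num_samples: int,
             batch_size: int,
             update_times: int,
             update_critic_times: int,
             update_actor_times: int,
             train_critic_on: Callable[[slice], Dict[str, float]],
             train_actor_on: Callable[[slice], Dict[str, float]]) -> Dict[str, float]:
    """Illustrative loop shape only: each mini-batch updates the critic first, then the actor."""
    info_list: List[Dict[str, float]] = []
    for _ in range(update_times):
        start = 0
        while start < num_samples:
            end = min(start + batch_size, num_samples)
            batch = slice(start, end)
            start = end
            critic_info: Dict[str, float] = {}
            actor_info: Dict[str, float] = {}
            for _ in range(update_critic_times):   # critic update(s) first ...
                critic_info = train_critic_on(batch)
            for _ in range(update_actor_times):    # ... then actor update(s) on the same batch
                actor_info = train_actor_on(batch)
            info_list.append({**critic_info, **actor_info})
    # Average every logged metric over all batches and epochs.
    keys = info_list[0].keys()
    return {k: sum(d[k] for d in info_list) / len(info_list) for k in keys}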
118 changes: 81 additions & 37 deletions AquaML/rlalgo/PPO.py
@@ -215,51 +215,95 @@ def _optimize_(self):
        for idx in self.expand_dims_idx:
            actor_obs[idx] = tf.expand_dims(actor_obs[idx], axis=1)

        info_list = []
        buffer_size = train_actor_input['actor_obs'][0].shape[0]
        critic_buffer_size = self.hyper_parameters.buffer_size
        critic_batch_steps = self.hyper_parameters.batch_size

        for _ in range(self.hyper_parameters.update_times):
            # fusion ppo firstly update critic
            start_index = 0
            end_index = 0
            critic_start_index = 0
            while end_index < buffer_size:
                end_index = min(start_index + self.hyper_parameters.batch_size, buffer_size)
                critic_end_index = min(critic_start_index + critic_batch_steps, critic_buffer_size)
                critic_optimize_info_list = []
                actor_optimize_info_list = []

                batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
                batch_train_critic_input = self.get_batch_data(train_critic_input, critic_start_index, critic_end_index)
                start_index = end_index
                critic_start_index = critic_end_index

                for _ in range(self.hyper_parameters.update_critic_times):
                    critic_optimize_info = self.train_critic(
                        critic_obs=batch_train_critic_input['critic_obs'],
                        target=batch_train_critic_input['target'],
                    )
                    critic_optimize_info_list.append(critic_optimize_info)

                for _ in range(self.hyper_parameters.update_actor_times):
                    actor_optimize_info = self.train_actor(
                        actor_obs=batch_train_actor_input['actor_obs'],
                        advantage=batch_train_actor_input['advantage'],
                        old_log_prob=batch_train_actor_input['old_log_prob'],
                        action=batch_train_actor_input['action'],
                        epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
                        entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
                    )
                    actor_optimize_info_list.append(actor_optimize_info)

                critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
                actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)

                info = {**critic_optimize_info, **actor_optimize_info}
                info_list.append(info)

        info = self.cal_average_batch_dict(info_list)

        return info

# for _ in range(self.hyper_parameters.update_times):
#     # train actor
#     # TODO: wrap this part into a function
#     for _ in range(self.hyper_parameters.update_actor_times):
#         start_index = 0
#         end_index = 0
#         actor_optimize_info_list = []
#         while end_index < self.hyper_parameters.buffer_size:
#             end_index = min(start_index + self.hyper_parameters.batch_size, self.hyper_parameters.buffer_size)
#
#             batch_train_actor_input = self.get_batch_data(train_actor_input, start_index, end_index)
#
#             start_index = end_index
#
#             actor_optimize_info = self.train_actor(
#                 actor_obs=batch_train_actor_input['actor_obs'],
#                 advantage=batch_train_actor_input['advantage'],
#                 old_log_prob=batch_train_actor_input['old_log_prob'],
#                 action=batch_train_actor_input['action'],
#                 epsilon=tf.cast(self.hyper_parameters.epsilon, dtype=tf.float32),
#                 entropy_coefficient=tf.cast(self.hyper_parameters.entropy_coeff, dtype=tf.float32),
#             )
#             actor_optimize_info_list.append(actor_optimize_info)
#         actor_optimize_info = self.cal_average_batch_dict(actor_optimize_info_list)
#
#     # train critic
#     for _ in range(self.hyper_parameters.update_critic_times):
#         start_index = 0
#         end_index = 0
#         critic_optimize_info_list = []
#         for _ in range(self.hyper_parameters.update_critic_times):
#             while end_index < self.hyper_parameters.buffer_size:
#                 end_index = min(start_index + self.hyper_parameters.batch_size,
#                                 self.hyper_parameters.buffer_size)
#
#                 batch_train_critic_input = self.get_batch_data(train_critic_input, start_index, end_index)
#
#                 start_index = end_index
#
#                 critic_optimize_info = self.train_critic(
#                     critic_obs=batch_train_critic_input['critic_obs'],
#                     target=batch_train_critic_input['target'],
#                 )
#                 critic_optimize_info_list.append(critic_optimize_info)
#             critic_optimize_info = self.cal_average_batch_dict(critic_optimize_info_list)
#
# return_dict = {**actor_optimize_info, **critic_optimize_info}
# return return_dict
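train_actor receives epsilon (the PPO clip range) and entropy_coefficient, which correspond to the standard clipped-surrogate objective with an entropy bonus. A TensorFlow sketch of that loss follows; the function name and argument layout are illustrative, and it is an assumption that AquaML's train_actor computes exactly this form.

import tensorflow as tf

def ppo_actor_loss(new_log_prob, old_log_prob, advantage, entropy,
                   epsilon=0.2, entropy_coefficient=0.01):
    """Standard clipped-surrogate PPO loss (illustrative; not necessarily AquaML's exact code)."""
    # Importance ratio between the current policy and the policy that collected the data.
    ratio = tf.exp(new_log_prob - old_log_prob)
    # Clipped surrogate objective: take the pessimistic (minimum) of the two terms.
    unclipped = ratio * advantage
    clipped = tf.clip_by_value(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantage
    surrogate = tf.reduce_mean(tf.minimum(unclipped, clipped))
    # Entropy bonus encourages exploration; negate because the optimizer minimizes a loss.
    return -(surrogate + entropy_coefficient * tf.reduce_mean(entropy))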
6 changes: 3 additions & 3 deletions Tutorial/Tutorial3.py
@@ -120,10 +120,10 @@ def close(self):
     epoch_length=200,
     n_epochs=2000,
     total_steps=4000,
-    batch_size=32,
+    batch_size=128,
     update_times=4,
-    update_actor_times=1,
-    update_critic_times=2,
+    update_actor_times=4,
+    update_critic_times=4,
     gamma=0.99,
     epsilon=0.2,
     lambada=0.95

0 comments on commit c285182
