From 4694dc5c1edcccc3e787f9984d7f530e8fa993d9 Mon Sep 17 00:00:00 2001
From: sumwailiu <798465811@qq.com>
Date: Wed, 8 Jan 2025 13:09:36 +0800
Subject: [PATCH] reward computation fix

---
 main.py  | 12 ++++++------
 utils.py |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/main.py b/main.py
index 2e5a63b..9d76b3e 100644
--- a/main.py
+++ b/main.py
@@ -91,7 +91,7 @@
 parser.add_argument('--global-kl-beta', type=float, default=0, metavar='βg', help='Global KL weight (0 to disable)')
 parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats')
 parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)')
-parser.add_argument('--model_learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate')
+parser.add_argument('--model_learning-rate', type=float, default=6e-4, metavar='α', help='Learning rate')
 parser.add_argument('--actor_learning-rate', type=float, default=8e-5, metavar='α', help='Learning rate')
 parser.add_argument('--value_learning-rate', type=float, default=8e-5, metavar='α', help='Learning rate')
 parser.add_argument(
@@ -374,10 +374,10 @@ def update_belief_and_act(
         )
         if args.worldmodel_LogProbLoss:
             reward_dist = Normal(bottle(reward_model, (beliefs, posterior_states)), 1)
-            reward_loss = -reward_dist.log_prob(rewards[:-1]).mean(dim=(0, 1))
+            reward_loss = -reward_dist.log_prob(rewards[1:]).mean(dim=(0, 1))
         else:
             reward_loss = F.mse_loss(
-                bottle(reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none'
+                bottle(reward_model, (beliefs, posterior_states)), rewards[1:], reduction='none'
             ).mean(dim=(0, 1))
         # transition loss
         div = kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2)
@@ -479,7 +479,7 @@ def update_belief_and_act(
         imged_reward = bottle(reward_model, (imged_beliefs, imged_prior_states))
         value_pred = bottle(value_model, (imged_beliefs, imged_prior_states))
         returns = lambda_return(
-            imged_reward, value_pred, bootstrap=value_pred[-1], discount=args.discount, lambda_=args.disclam
+            imged_reward[:-1], value_pred[:-1], bootstrap=value_pred[-1], discount=args.discount, lambda_=args.disclam
         )
         actor_loss = -torch.mean(returns)
         # Update model parameters
@@ -494,7 +494,7 @@ def update_belief_and_act(
         value_prior_states = imged_prior_states.detach()
         target_return = returns.detach()
         value_dist = Normal(
-            bottle(value_model, (value_beliefs, value_prior_states)), 1
+            bottle(value_model, (value_beliefs, value_prior_states))[:-1], 1
         )  # detach the input tensor from the transition network.
         value_loss = -value_dist.log_prob(target_return).mean(dim=(0, 1))
         # Update model parameters
@@ -681,7 +681,7 @@ def update_belief_and_act(
         )
         if args.checkpoint_experience:
             torch.save(
-                D, os.path.join(results_dir, 'experience.pth')
+                D, os.path.join(results_dir, 'experience.pth'), pickle_protocol=5
             )  # Warning: will fail with MemoryError with large memory sizes
diff --git a/utils.py b/utils.py
index df7b3bf..12796f2 100644
--- a/utils.py
+++ b/utils.py
@@ -98,8 +98,8 @@ def imagine_ahead(prev_state, prev_belief, policy, transition_model, planning_ho
     # Return new hidden states
     # imagined_traj = [beliefs, prior_states, prior_means, prior_std_devs]
     imagined_traj = [
-        torch.stack(beliefs[1:], dim=0),
-        torch.stack(prior_states[1:], dim=0),
+        torch.stack(beliefs, dim=0),
+        torch.stack(prior_states, dim=0),
         torch.stack(prior_means[1:], dim=0),
         torch.stack(prior_std_devs[1:], dim=0),
     ]
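
Note on the lambda_return call changed above: with this patch, imagine_ahead returns the full imagined belief/state sequences (presumably H + 1 entries, including the starting posterior state), so imged_reward[:-1] and value_pred[:-1] cover the first H imagined steps while value_pred[-1] supplies only the bootstrap, and the value targets are trimmed with the matching [:-1]. The sketch below illustrates the Dreamer-style λ-return recursion that this indexing assumes; it is an illustration only, not the repository's utils.lambda_return, and the shapes H and B are made up for the example.

    # Sketch of a Dreamer-style lambda-return (assumed recursion; the real
    # utils.lambda_return may differ in details).
    import torch

    def lambda_return_sketch(rewards, values, bootstrap, discount=0.99, lambda_=0.95):
        # rewards[t], values[t] belong to imagined step t (t = 0 .. H-1);
        # bootstrap is the value estimate v(s_H) of the final imagined state.
        next_values = torch.cat([values[1:], bootstrap.unsqueeze(0)], dim=0)
        last = bootstrap
        outputs = []
        for t in reversed(range(rewards.shape[0])):
            last = rewards[t] + discount * ((1 - lambda_) * next_values[t] + lambda_ * last)
            outputs.append(last)
        return torch.stack(outputs[::-1], dim=0)

    # Illustrative shapes only: planning horizon H, batch size B.
    H, B = 15, 50
    imged_reward = torch.randn(H + 1, B)
    value_pred = torch.randn(H + 1, B)
    returns = lambda_return_sketch(imged_reward[:-1], value_pred[:-1], bootstrap=value_pred[-1])
    print(returns.shape)  # torch.Size([15, 50])

Under this reading, the final imagined state contributes only through the bootstrap term, which is why the last reward and value predictions are dropped from the λ-return inputs and from the value regression targets.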