From c152b771390c81002384eb8eb4a52a3400b8db0c Mon Sep 17 00:00:00 2001 From: hutanmihai Date: Tue, 16 Jan 2024 02:14:35 +0200 Subject: [PATCH] Improve tests --- src/test.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/test.py b/src/test.py index a8458f3..4deb292 100644 --- a/src/test.py +++ b/src/test.py @@ -1,5 +1,6 @@ from argparse import ArgumentParser +import torch from gymnasium import Env, make from src.agent import Agent @@ -7,9 +8,10 @@ from src.utils.runners import init_new_episode -def run_test(env: Env, agent: Agent, episodes_to_run: int = 10): - agent.set_epsilon(0) +def run_test(env: Env, agent: Agent, episodes_to_run: int = 100): + agent.set_epsilon(0.1) agent.policy_net.eval() + rewards_history = [] for episode in range(episodes_to_run): total_reward = 0 @@ -23,8 +25,12 @@ def run_test(env: Env, agent: Agent, episodes_to_run: int = 10): state = next_state total_reward += reward + rewards_history.append(total_reward) + print(f"Episode {episode + 1} finished with reward {total_reward}!") + print(f"Average reward over {episodes_to_run} episodes: {sum(rewards_history) / episodes_to_run}") + if __name__ == "__main__": parser = ArgumentParser() @@ -40,25 +46,12 @@ def run_test(env: Env, agent: Agent, episodes_to_run: int = 10): raise ValueError("Please specify either --dqn or --ddqn flag!") # To take the best results in one episode models - if algorithm == "ddqn": - policy_net_path = f"models/policy_net_{algorithm}_solo.pth" - target_net_path = f"models/target_net_{algorithm}_solo.pth" - else: - policy_net_path = f"models/policy_net_{algorithm}_solo.pth" - target_net_path = None + policy_net_path = f"models/policy_net_{algorithm}_solo.pth" # To take the best average results models in 100 episodes - # if algorithm == "ddqn": - # policy_net_path = f"models/policy_net_{algorithm}_avg.pth" - # target_net_path = f"models/target_net_{algorithm}_avg.pth" - # else: - # policy_net_path = f"models/policy_net_{algorithm}_avg.pth" - # target_net_path = None + # policy_net_path = f"models/policy_net_{algorithm}_avg.pth" env: Env = make("ALE/Skiing-v5") agent = Agent(action_space=env.action_space, algorithm=algorithm) - if algorithm == "ddqn": - agent.load(policy_net_path, target_net_path) - else: - agent.load(policy_net_path) + agent.policy_net.load_state_dict(torch.load(policy_net_path, map_location=torch.device("cpu"))) run_test(env, agent)