Skip to content

Commit

Permalink
Improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hutanmihai committed Jan 16, 2024
1 parent 4330caa commit c152b77
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions src/test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from argparse import ArgumentParser

import torch
from gymnasium import Env, make

from src.agent import Agent
from src.main import step
from src.utils.runners import init_new_episode


def run_test(env: Env, agent: Agent, episodes_to_run: int = 10):
agent.set_epsilon(0)
def run_test(env: Env, agent: Agent, episodes_to_run: int = 100):
agent.set_epsilon(0.1)
agent.policy_net.eval()
rewards_history = []
for episode in range(episodes_to_run):
total_reward = 0

Expand All @@ -23,8 +25,12 @@ def run_test(env: Env, agent: Agent, episodes_to_run: int = 10):
state = next_state
total_reward += reward

rewards_history.append(total_reward)

print(f"Episode {episode + 1} finished with reward {total_reward}!")

print(f"Average reward over {episodes_to_run} episodes: {sum(rewards_history) / episodes_to_run}")


if __name__ == "__main__":
parser = ArgumentParser()
Expand All @@ -40,25 +46,12 @@ def run_test(env: Env, agent: Agent, episodes_to_run: int = 10):
raise ValueError("Please specify either --dqn or --ddqn flag!")

# To take the best results in one episode models
if algorithm == "ddqn":
policy_net_path = f"models/policy_net_{algorithm}_solo.pth"
target_net_path = f"models/target_net_{algorithm}_solo.pth"
else:
policy_net_path = f"models/policy_net_{algorithm}_solo.pth"
target_net_path = None
policy_net_path = f"models/policy_net_{algorithm}_solo.pth"

# To take the best average results models in 100 episodes
# if algorithm == "ddqn":
# policy_net_path = f"models/policy_net_{algorithm}_avg.pth"
# target_net_path = f"models/target_net_{algorithm}_avg.pth"
# else:
# policy_net_path = f"models/policy_net_{algorithm}_avg.pth"
# target_net_path = None
# policy_net_path = f"models/policy_net_{algorithm}_avg.pth"

env: Env = make("ALE/Skiing-v5")
agent = Agent(action_space=env.action_space, algorithm=algorithm)
if algorithm == "ddqn":
agent.load(policy_net_path, target_net_path)
else:
agent.load(policy_net_path)
agent.policy_net.load_state_dict(torch.load(policy_net_path, map_location=torch.device("cpu")))
run_test(env, agent)

0 comments on commit c152b77

Please sign in to comment.