Commit

Merge branch 'nikita'
NorthPhoenix committed May 3, 2023
2 parents ead06f3 + 4f772de commit d9af193
Showing 2 changed files with 60 additions and 20 deletions.
48 changes: 35 additions & 13 deletions Game/agent.py
@@ -1,10 +1,12 @@
from agent_based_game import Game, Action
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import time
import os

#%matplotlib inline

class QLearning:
def __init__(
@@ -21,15 +23,15 @@ def __init__(
self.alpha = alpha
self.gamma = gamma
self.episodes = episodes
self.history = np.empty((0, 2), float)
if start_decay_episode is None:
self.start_decay_episode = episodes // 2
else:
self.start_decay_episode = min(episodes - 1, start_decay_episode)
self.epsilon_start = epsilon_start
self.epsilon_end = epsilon_end
        self.epsilon_decay = (self.epsilon_start - self.epsilon_end) / (
            self.episodes - self.start_decay_episode)
self.Q = {}
# State space
self.state_space = [
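
The constructor sets up a linear epsilon schedule: epsilon holds at epsilon_start until start_decay_episode, then drops by epsilon_decay per episode toward epsilon_end. A minimal standalone sketch of that schedule, assuming the collapsed part of train() subtracts epsilon_decay once per episode past the threshold, and with an illustrative epsilon_end default:

def epsilon_at(episode, episodes=1000, epsilon_start=0.9,
               epsilon_end=0.05, start_decay_episode=None):
    # Mirrors the constructor: decay begins halfway through by default
    if start_decay_episode is None:
        start_decay_episode = episodes // 2
    decay = (epsilon_start - epsilon_end) / (episodes - start_decay_episode)
    steps = max(0, episode - start_decay_episode)
    return max(epsilon_end, epsilon_start - steps * decay)

# epsilon_at(0) -> 0.9, epsilon_at(500) -> 0.9, epsilon_at(999) -> ~0.05
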
@@ -69,12 +71,11 @@ def train(self):
# Training loop
epsilon = self.epsilon_start
for episode in range(self.episodes):
            # visualize every 500 episodes
            # if (episode+1) % 500 == 0:
            #     self.game.reset(visualizeNext=True)
            # else:
            #     self.game.reset()
            self.game.reset(visualizeNext=False)
gameover = False
state, _ = self.game.getState()
@@ -83,9 +84,7 @@

            while not gameover:
                # Epsilon-greedy action selection: with probability epsilon the
                # agent explores with a random action, otherwise it exploits the
                # action with the highest Q-value
                if random.uniform(0, 1) < epsilon:
                    action = random.choice(self.action_space)
                else:
                    action = max(
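
The greedy branch is cut off by the collapsed diff; over a dict-backed Q-table keyed by (state, action), argmax selection is conventionally written as below. A sketch with toy values, where the strings stand in for the game's Action enum members:

state = (1, 0)
action_space = ["UP", "DOWN", "LEFT", "RIGHT"]
Q = {(state, a): q for a, q in zip(action_space, [0.2, 0.5, -0.1, 0.0])}
# Pick the action whose Q-value is highest for the current state
action = max(action_space, key=lambda a: Q[(state, a)])
print(action)  # DOWN
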
@@ -95,7 +94,6 @@
# Perform the action and receive the new state, reward, and gameover status
next_state, reward, gameover, score = self.game.act(action)
total_reward += reward

# Update the Q-value for the current state-action pair using the Q-learning update rule
old_q_value = self.Q[(state, action)]
max_future_q_value = max(
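
The collapsed lines apply the standard tabular Q-learning update. Written out with the snippet's variable names and toy numbers (the alpha and gamma values here are illustrative, not the run's settings):

alpha, gamma = 0.5, 0.9           # illustrative values
old_q_value = 0.0
reward = 1.0
max_future_q_value = 2.0          # max over Q[(next_state, a)]
new_q_value = old_q_value + alpha * (
    reward + gamma * max_future_q_value - old_q_value
)
print(new_q_value)  # 0.0 + 0.5 * (1.0 + 0.9 * 2.0 - 0.0) = 1.4
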
@@ -114,6 +112,7 @@
# self.printQTable()

print(f"Episode {episode + 1}/{self.episodes} completed")
self.history = np.append(self.history, [[total_reward, score]], axis=0)
# print(f"Score: {score}")
# print(f"Total Reward (Train): {total_reward}")

@@ -151,13 +150,36 @@ def printQTable(self):
)
)

    def graphHistory(self):
        # Plot the training history: total agent reward and game score for
        # each episode, as two stacked subplots in one large figure
        fig = plt.figure(figsize=(30, 40))
        fig.suptitle('Agent Performance', fontsize=30)
        plt.rcParams.update({'font.size': 22})

plt.subplot(2, 1, 1)
plt.plot(self.history[:, 0], label='Total Agent Reward per Episode')
plt.ylabel('Reward')
plt.xlabel('Episodes')
plt.legend(loc='lower right')

plt.subplot(2, 1, 2)
plt.plot(self.history[:, 1], label='Game Score per Episode')
plt.ylabel('Score')
plt.xlabel('Episodes')
plt.legend(loc='lower right')

plt.show()
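
graphHistory() reads from self.history, the (episodes, 2) array that train() grows by one row per episode. A quick shape check with an illustrative row:

import numpy as np

history = np.empty((0, 2), float)
history = np.append(history, [[12.5, 3]], axis=0)  # one episode: [total_reward, score]
print(history.shape)  # (1, 2): column 0 = reward, column 1 = score
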

if __name__ == "__main__":
    game = Game(runtime=15, move_reward=1, target_reward=100)
    agent = QLearning(game, episodes=1000, alpha=0.5, epsilon_start=0.9)

agent.printQTable()
agent.train()
agent.printQTable()
agent.graphHistory()
while True:
agent.test()
32 changes: 25 additions & 7 deletions Game/agent_based_game.py
@@ -17,7 +17,7 @@ class Action(Enum):

class Game:
    def __init__(
        self, runtime=15, fps=60, target_reward=100, move_reward=1, miss_reward=-1, visualize=False
    ):
# Initialize Pygame after __main__ is executed
self.initialized = False
@@ -38,7 +38,8 @@ def __init__(

# Setting up game variables
self.score = 0
        self.TARGET_REWARD = target_reward  # set to 100 from parameters
        self.MOVE_DIRECTION_REWARD = move_reward  # set to 1 from parameters
self.MISS_REWARD = miss_reward # set to -1 from parameters
self.PLAYER_DIMENTIONS = (80, 80) # The size of the snake head
self.TARGET_DIMENTIONS = (40, 40) # Size of the targets
@@ -264,14 +265,15 @@ def act(self, action: Action):
distanceToTargetAfter = self._getDistanceToClosestTarget()
# If player is closer to the target than in the previous frame, give reward
if distanceToTargetAfter < distanceToTargetBefore:
            reward = self.MOVE_DIRECTION_REWARD
else:
reward = self.MISS_REWARD

# To be run if collision occurs between Player and Target
collision = pygame.sprite.spritecollideany(self.player, self.targets)
if collision:
self.score += 1 # If the snake moved and ate a target add 1 to the score
reward += self.TARGET_REWARD
collision.kill()
newTarget = self.Target(self.TARGET_DIMENTIONS, self.BLUE)
while True:
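
Condensed, the reward logic act() now applies is: a small move_reward when the player closes the distance to the nearest target, miss_reward otherwise, plus target_reward on top when a target is eaten. A standalone sketch, with the function name and argument defaults assumed for illustration:

def step_reward(dist_before, dist_after, ate_target,
                move_reward=1, miss_reward=-1, target_reward=100):
    # Dense shaping signal: reward movement toward the closest target
    reward = move_reward if dist_after < dist_before else miss_reward
    # Sparse signal: large bonus when a target is actually collected
    if ate_target:
        reward += target_reward
    return reward

print(step_reward(10.0, 8.0, False))  # 1
print(step_reward(10.0, 8.0, True))   # 101
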
@@ -379,21 +381,37 @@ def getState(self):
return (0, 0), None

# Calculate the state of the game in horizontal axis
        if self.player.rect.x < closestTarget.rect.x:
            x = 1
        elif self.player.rect.x > closestTarget.rect.x:
            x = -1
else:
x = 0

# Calculate the state of the game in vertical axis
        if self.player.rect.y > closestTarget.rect.y:
            y = 1
        elif self.player.rect.y < closestTarget.rect.y:
            y = -1
else:
y = 0


# if self.player.rect.right < closestTarget.rect.left:
# x = 1
# elif self.player.rect.left > closestTarget.rect.right:
# x = -1
# else:
# x = 0

# # Calculate the state of the game in vertical axis
# if self.player.rect.top > closestTarget.rect.bottom:
# y = 1
# elif self.player.rect.bottom < closestTarget.rect.top:
# y = -1
# else:
# y = 0

return (x, y), closestTarget
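
The rewritten getState() compares rect origins instead of opposite edges, so each axis collapses to the sign of the player-to-target offset and the observable state space shrinks to a 3x3 grid plus the no-target case. A compact equivalent, with the helper names assumed for illustration:

def sign(v):
    return (v > 0) - (v < 0)

def state_from_positions(player_x, player_y, target_x, target_y):
    # x: +1 when the target is to the right of the player, -1 to the left
    # y: +1 when the player is below the target (pygame y grows downward)
    return (sign(target_x - player_x), sign(player_y - target_y))

print(state_from_positions(100, 100, 150, 60))  # (1, 1)
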

# Returns current game score
