Skip to content

Commit

Permalink
add initial code for RL
Browse files Browse the repository at this point in the history
  • Loading branch information
flimdejong committed Oct 2, 2024
1 parent c62bad2 commit 987e39c
Show file tree
Hide file tree
Showing 8 changed files with 600 additions and 0 deletions.
18 changes: 18 additions & 0 deletions roboteam_ai/src/RL/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Nova AI

RL implementation to swap out the play system

## Features

- Feature 1
- Feature 2
- Feature 3

## Usage


## Explanation .py scripts
- getRefereeState.py gets the state of the referee
- getState.py gets the combined game state; it contains 2 functions, one to get the ball position and one to get the robot positions
- sentActionCommand.py sends a command using proto to the legacy AI system
- teleportBall.py teleports the ball to a location we can define in our environment
210 changes: 210 additions & 0 deletions roboteam_ai/src/RL/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
# The environment

import gymnasium
from gymnasium import spaces
import numpy as np
from google.protobuf.message import DecodeError


# Import our functions
from sentActionCommand import send_action_command
from getState import get_ball_state, get_robot_state
from getRefereeState import get_referee_state
from teleportBall import teleport_ball
from resetRefereeState import reset_referee_state

"""
This environment file is in the form of a gymnasium environment.
We are yellow and we play against blue.
Yellow cards do not stop the game, but maybe in the future it is nice to implement a punishment
"""

# BUG FIX: subclass gymnasium.Env — `gymnasium.env` does not exist and raised
# AttributeError at import time.
class RoboTeamEnv(gymnasium.Env):
    """Gymnasium environment that drives the legacy RoboTeam AI play system.

    We are yellow and we play against blue. Yellow cards do not stop the
    game, but a punishment for them could be added in the future.
    """

    def __init__(self):
        # Maximum number of robots on our (yellow) team.
        self.MAX_ROBOTS_US = 10

        # Number of robots present in each grid cell + ball location.
        self.robot_grid = np.zeros((4, 2), dtype=int)  # left up, right up, left down, right down
        self.ball_position = np.zeros((1, 2))  # Single row with 2 columns
        self.ball_quadrant = 4  # This refers to the center

        # Initialize the observation space.
        self.observation_space = spaces.Dict({
            'robot_positions': spaces.Box(
                low=0,
                high=self.MAX_ROBOTS_US,
                shape=(4, 2),
                dtype=np.int32
            ),
            'ball_position': spaces.Discrete(5),  # 0-3 for quadrants, 4 for center
            'is_yellow_dribbling': spaces.Discrete(2)  # 0 for false, 1 for true
        })

        # Action space: [attackers, defenders].
        # Wallers will be automatically calculated.
        self.action_space = spaces.MultiDiscrete([self.MAX_ROBOTS_US + 1, self.MAX_ROBOTS_US + 1])

        # The xy coordinates of the designated ball-placement position.
        self.x = 0
        self.y = 0

        # Referee state variables.
        self.yellow_yellow_cards = 0
        self.blue_yellow_cards = 0
        self.ref_command = ""   # e.g. HALT, STOP, RUNNING
        self.yellow_score = 0   # Init scores to zero
        self.blue_score = 0     # BUG FIX: was read in calculate_reward/is_terminated but never initialized

        # Reward shaping.
        self.shaped_reward_given = False  # A reward that is given once per episode
        self.is_yellow_dribbling = False
        self.is_blue_dribbling = False

    def check_ball_placement(self):
        """Teleport the ball to the designated position if ball placement is active.

        BUG FIX: the original condition `(... == "BALL_PLACEMENT_US") or
        ("BALL_PLACEMENT_THEM")` was always true because a non-empty string
        literal is truthy, so the ball was teleported on every call.
        """
        if self.ref_command in ("BALL_PLACEMENT_US", "BALL_PLACEMENT_THEM"):
            teleport_ball(self.x, self.y)  # Teleport the ball to the designated location

    def get_referee_state(self):
        """Refresh the cached referee state from the referee feed.

        BUG FIX: the original spread the unpacking over separate statements,
        so everything except the last two targets was a discarded expression
        and never assigned.
        """
        (self.x, self.y,                                    # Designated pos
         self.yellow_yellow_cards, self.blue_yellow_cards,  # Yellow cards
         self.ref_command,                                  # Ref command, such as HALT, STOP
         self.yellow_score, self.blue_score) = get_referee_state()  # Scores

    def calculate_reward(self):
        """Return the reward for the current state.

        +1 when we score, -1 when blue scores, plus a one-off +0.1 shaped
        reward for dribbling in quadrant 1 or 3.

        BUG FIX: both components are initialized to 0, so the function no
        longer raises UnboundLocalError when neither condition fires.
        """
        # When a goal is scored the ref command is HALT.
        goal_scored_reward = 0
        if self.yellow_score == 1:
            goal_scored_reward = 1
        elif self.blue_score == 1:
            goal_scored_reward = -1

        # Reward shaping: given at most once per episode.
        shaped_reward = 0
        if not self.shaped_reward_given and self.is_yellow_dribbling and self.ball_quadrant in (1, 3):
            self.shaped_reward_given = True
            shaped_reward = 0.1

        # # If it gets a yellow card / three times a foul, punish and reset
        # if self.yellow_yellow_cards or self.blue_yellow_cards >= 1:
        #     yellow_card_punishment = 1

        return goal_scored_reward + shaped_reward

    def get_observation(self):
        """Return the observation dict and the reward for the current state."""
        # Matrix of 4 by 2 robot counts + 2 dribbling booleans.
        self.robot_grid, self.is_yellow_dribbling, self.is_blue_dribbling = get_robot_state()

        # x,y coordinates and quadrant of the ball.
        self.ball_position, self.ball_quadrant = get_ball_state()

        observation_space = {
            'robot_positions': self.robot_grid,
            'ball_position': self.ball_quadrant,
            'is_yellow_dribbling': self.is_yellow_dribbling
        }

        return observation_space, self.calculate_reward()

    def step(self, action):
        """Run one RL loop iteration: apply `action`, then observe and reward.

        NOTE(review): gymnasium's `Env.step` contract is a 5-tuple ending in
        an `info` dict; this returns 4 values to stay compatible with the
        existing caller — confirm before using gymnasium wrappers.
        """
        # Only carry out the "normal" loop during active gameplay.
        if self.ref_command == "RUNNING":  # Maybe this needs to change to NORMAL_START

            attackers, defenders = action

            # Clamp to non-negative role counts totalling MAX_ROBOTS_US.
            # BUG FIX: the original referenced the nonexistent self.MAX_ROBOTS.
            attackers = max(0, min(attackers, self.MAX_ROBOTS_US))
            defenders = max(0, min(defenders, self.MAX_ROBOTS_US - attackers))
            wallers = self.MAX_ROBOTS_US - (attackers + defenders)

            # Sends the action command over proto to legacy AI.
            send_action_command(num_attacker=attackers, num_defender=defenders, num_waller=wallers)

        # If the game is halted, stopped or ball placement is happening,
        # teleport the ball when required.
        self.check_ball_placement()

        # BUG FIX: the original stored the bound method (`self.get_observation`
        # without parentheses) instead of calling it. get_observation() also
        # returns the reward, so calculate_reward() is invoked only once per
        # step (it mutates shaped_reward_given, so a double call lost rewards).
        observation_space, reward = self.get_observation()

        done = self.is_terminated()      # Task completed (a goal was scored)
        truncated = self.is_truncated()  # Episode cut short: out of time or a yellow card

        return observation_space, reward, done, truncated

    def is_terminated(self):
        """Return True when either team scored (HALT follows a goal)."""
        # BUG FIX: explicitly return False instead of implicitly returning None.
        return self.ref_command == "HALT" and (self.yellow_score == 1 or self.blue_score == 1)

    def is_truncated(self):
        """Return True when the episode ends prematurely (e.g. the 5-minute limit).

        TODO: implement the time-limit / yellow-card truncation logic.
        """
        return False  # BUG FIX: explicit False instead of implicit None

    def reset(self, seed=None):
        """Reset the environment when a game has ended.

        NOTE(review): gymnasium's `Env.reset` is expected to return
        `(observation, info)`; this returns None like the original to avoid
        a ZMQ read here — confirm against the training loop.
        """
        # Teleport ball to the middle position.
        teleport_ball(0, 0)

        # Reset referee state: cards, goals, and initiate a kickoff.
        reset_referee_state()

        # Reset per-episode bookkeeping.
        self.shaped_reward_given = False
        self.is_yellow_dribbling = False
        self.is_blue_dribbling = False









141 changes: 141 additions & 0 deletions roboteam_ai/src/RL/getState.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import sys
import os
import zmq
from google.protobuf.message import DecodeError
import numpy as np

'''
GetState.py is a script to get the state of the game (where the robots and ball are and more information).
This data will be fed into the RL. It uses the proto:State object.
The two functions get_ball_state and get_robot_state get the states of the ball and robots respectively.
'''

# Make sure to go back to the main roboteam directory
# (four levels up from roboteam_ai/src/RL/ to the repository root).
current_dir = os.path.dirname(os.path.abspath(__file__))
roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))

# Add to sys.path so the generated protobuf package resolves.
sys.path.append(roboteam_path)

# Now import the generated protobuf classes
from roboteam_networking.proto.State_pb2 import State

def get_ball_state():
    """Return the ball's (x, y) position and the quadrant it occupies.

    Quadrant encoding (from the comparisons below): left half -> 0 if y > 0
    else 2; right half -> 1 if y > 0 else 3; 4 when the ball is within
    CENTER_THRESHOLD of the origin; -1 when no vision data was available.
    """
    position = np.zeros(2)  # [x, y]
    quadrant = -1           # invalid until vision data arrives

    # How close to (0, 0) still counts as "center".
    CENTER_THRESHOLD = 0.01  # Adjust this value as needed

    context = zmq.Context()
    subscriber = context.socket(zmq.SUB)
    subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
    subscriber.connect("tcp://127.0.0.1:5558")

    try:
        state = State.FromString(subscriber.recv())

        if len(state.processed_vision_packets):
            world = state.last_seen_world

            if world.HasField("ball"):
                position[0] = world.ball.pos.x
                position[1] = world.ball.pos.y
                bx, by = position[0], position[1]

                # Field assumed centered at (0,0), extending -6..6 in x
                # and -4.5..4.5 in y.
                if abs(bx) <= CENTER_THRESHOLD and abs(by) <= CENTER_THRESHOLD:
                    quadrant = 4  # Center
                elif bx < 0:
                    quadrant = 0 if by > 0 else 2
                else:
                    quadrant = 1 if by > 0 else 3
    except DecodeError:
        print("Failed to decode protobuf message")
    except zmq.ZMQError as e:
        print(f"ZMQ Error: {e}")
    finally:
        subscriber.close()
        context.term()

    return position, quadrant

def get_robot_state():
    """Return (grid_array, yellow_dribbling, blue_dribbling).

    grid_array is a 4x2 int matrix: rows are field quadrants, column 0
    counts yellow robots and column 1 counts blue robots. The booleans flag
    whether any robot on that team reports its dribbler sees the ball.
    """
    grid_array = np.zeros((4, 2), dtype=int)
    yellow_team_dribbling = False
    blue_team_dribbling = False

    context = zmq.Context()
    subscriber = context.socket(zmq.SUB)
    subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
    subscriber.connect("tcp://127.0.0.1:5558")

    try:
        state = State.FromString(subscriber.recv())

        if not len(state.processed_vision_packets):
            return grid_array, yellow_team_dribbling, blue_team_dribbling

        world = state.last_seen_world

        def get_grid_position(x, y):
            # Left half -> rows 0/2, right half -> rows 1/3 (split on y > 0).
            if x < 0:
                return 0 if y > 0 else 2
            return 1 if y > 0 else 3

        # Count seen robots per quadrant and collect dribbler flags,
        # one pass per team (column 0 = yellow, column 1 = blue).
        for team, column in ((world.yellow, 0), (world.blue, 1)):
            for bot in team:
                grid_array[get_grid_position(bot.pos.x, bot.pos.y), column] += 1
                if bot.feedbackInfo.dribbler_sees_ball:
                    if column == 0:
                        yellow_team_dribbling = True
                    else:
                        blue_team_dribbling = True

        # Unseen robots still report dribbler feedback.
        for bot in world.yellow_unseen_robots:
            if bot.feedbackInfo.dribbler_sees_ball:
                yellow_team_dribbling = True
        for bot in world.blue_unseen_robots:
            if bot.feedbackInfo.dribbler_sees_ball:
                blue_team_dribbling = True
    except DecodeError:
        print("Failed to decode protobuf message")
    except zmq.ZMQError as e:
        print(f"ZMQ Error: {e}")
    finally:
        subscriber.close()
        context.term()

    return grid_array, yellow_team_dribbling, blue_team_dribbling

if __name__ == "__main__":
    # Smoke test: print a snapshot of the robot grid and dribbler state.
    grid_array, yellow_team_dribbling, blue_team_dribbling = get_robot_state()
    print("Grid-based Robot Array:")
    print(grid_array)
    print("\nInterpretation:")
    # BUG FIX: labels now match get_robot_state's grid mapping
    # (row 0: x<0, y>0; row 1: x>0, y>0; row 2: x<0, y<=0; row 3: x>0, y<=0),
    # assuming y > 0 is the top of the field — confirm axis orientation.
    # The original order mixed up the rows regardless of orientation.
    quadrants = ["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right"]
    for i, quadrant in enumerate(quadrants):
        print(f"{quadrant}: {grid_array[i, 0]} yellow robots, {grid_array[i, 1]} blue robots")

    print("\nDribbler Information:")
    print(f"Yellow Team: {'Dribbling' if yellow_team_dribbling else 'Not Dribbling'}")
    print(f"Blue Team: {'Dribbling' if blue_team_dribbling else 'Not Dribbling'}")
Loading

0 comments on commit 987e39c

Please sign in to comment.