-
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c62bad2
commit 987e39c
Showing
8 changed files
with
600 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Nova AI | ||
|
||
RL implementation to swap out the play system | ||
|
||
## Features | ||
|
||
- Feature 1 | ||
- Feature 2 | ||
- Feature 3 | ||
|
||
## Usage | ||
|
||
|
||
## Explanation .py scripts | ||
- getRefereeState.py gets the state of the referee | ||
- getState.py gets the combined game state and contains two functions: one to get the ball position and one to get the robot positions | ||
- sentActionCommand.py sends a command using proto to the legacy AI system | ||
- teleportBall.py teleports the ball to a location we can define in our environment |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
# The environment | ||
|
||
import gymnasium | ||
from gymnasium import spaces | ||
import numpy as np | ||
from google.protobuf.message import DecodeError | ||
|
||
|
||
# Import our functions | ||
from sentActionCommand import send_action_command | ||
from getState import get_ball_state, get_robot_state | ||
from getRefereeState import get_referee_state | ||
from teleportBall import teleport_ball | ||
from resetRefereeState import reset_referee_state | ||
|
||
""" | ||
This environment file is in the form of a gymnasium environment. | ||
We are yellow and we play against blue. | ||
Yellow cards do not stop the game, but maybe in the future it is nice to implement a punishment | ||
""" | ||
|
||
class RoboTeamEnv(gymnasium.Env):
    """Gymnasium environment in which the agent chooses the role split of our team.

    We are yellow and we play against blue. An action is a pair
    ``[attackers, defenders]``; the remaining robots become wallers.
    Yellow cards do not stop the game and are currently not punished.
    """

    def __init__(self):
        super().__init__()

        self.MAX_ROBOTS_US = 10

        # Number of robots that are present in each grid cell + ball location
        self.robot_grid = np.zeros((4, 2), dtype=int)  # left up, right up, left down, right down
        self.ball_position = np.zeros((1, 2))  # Single row with 2 columns
        self.ball_quadrant = 4  # This refers to the center

        # Initialize the observation space
        self.observation_space = spaces.Dict({
            'robot_positions': spaces.Box(
                low=0,
                high=self.MAX_ROBOTS_US,
                shape=(4, 2),
                dtype=np.int32
            ),
            'ball_position': spaces.Discrete(5),  # 0-3 for quadrants, 4 for center
            'is_yellow_dribbling': spaces.Discrete(2)  # 0 for false, 1 for true
        })

        # Action space: [attackers, defenders]
        # Wallers will be automatically calculated
        self.action_space = spaces.MultiDiscrete([self.MAX_ROBOTS_US + 1, self.MAX_ROBOTS_US + 1])

        # The xy coordinates of the designated ball position
        self.x = 0
        self.y = 0

        # Referee state variables
        self.yellow_yellow_cards = 0
        self.blue_yellow_cards = 0
        self.ref_command = ""  # Empty string until the referee state is read
        self.yellow_score = 0
        self.blue_score = 0  # Read by calculate_reward/is_terminated, so it must exist up front

        # Reward shaping
        self.shaped_reward_given = False  # A reward that is given once per episode
        self.is_yellow_dribbling = False
        self.is_blue_dribbling = False

    def check_ball_placement(self):
        """Teleport the ball to the designated position when the referee orders ball placement."""
        # A membership test is required here: `a == "X" or "Y"` is always truthy
        # because the non-empty string "Y" is evaluated on its own.
        if self.ref_command in ("BALL_PLACEMENT_US", "BALL_PLACEMENT_THEM"):
            teleport_ball(self.x, self.y)

    def get_referee_state(self):
        """Refresh the cached referee fields from the imported ``get_referee_state`` function.

        Inside a method the bare name resolves to the module-level import, not
        to this method, so the call below does not recurse.
        """
        # NOTE(review): assumes the imported function returns these seven values
        # in this order (as the original inline comments indicate) - confirm
        # against getRefereeState.py.
        (
            self.x, self.y,                                    # Designated pos
            self.yellow_yellow_cards, self.blue_yellow_cards,  # Yellow cards
            self.ref_command,                                  # Ref command, such as HALT, STOP
            self.yellow_score, self.blue_score,                # Scores
        ) = get_referee_state()

    def calculate_reward(self):
        """Return the reward for the current cached state.

        +1 when yellow has scored, -1 when blue has scored, plus a one-off
        shaping bonus of 0.1 the first time in an episode that yellow dribbles
        while the ball is in quadrant 1 or 3.
        """
        # Both terms default to 0 so the sum is defined when nothing happened
        goal_scored_reward = 0
        shaped_reward = 0

        # When a goal is scored the ref command is HALT
        if self.yellow_score == 1:
            goal_scored_reward = 1
        elif self.blue_score == 1:
            goal_scored_reward = -1

        # Reward shaping, handed out at most once per episode
        if (
            not self.shaped_reward_given
            and self.is_yellow_dribbling
            and self.ball_quadrant in (1, 3)
        ):
            self.shaped_reward_given = True
            shaped_reward = 0.1

        # TODO: punish yellow cards / repeated fouls once that is designed

        return goal_scored_reward + shaped_reward

    def _read_observation(self):
        """Fetch robot and ball state, cache it on ``self`` and return the observation dict."""
        # Matrix of 4 by 2 + 2 booleans
        self.robot_grid, self.is_yellow_dribbling, self.is_blue_dribbling = get_robot_state()

        # x,y coordinates, quadrant
        self.ball_position, self.ball_quadrant = get_ball_state()

        # NOTE(review): ball_quadrant may fall outside Discrete(5) if
        # get_ball_state reports an invalid quadrant - confirm how to handle it.
        return {
            # Cast so the values match the dtypes declared in observation_space
            'robot_positions': np.asarray(self.robot_grid, dtype=np.int32),
            'ball_position': self.ball_quadrant,
            'is_yellow_dribbling': int(self.is_yellow_dribbling)
        }

    def get_observation(self):
        """Refresh the cached world state and return ``(observation, reward)``."""
        observation = self._read_observation()
        return observation, self.calculate_reward()

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, terminated, truncated, info)``.

        Args:
            action: pair ``[attackers, defenders]`` as sampled from ``action_space``.
        """
        # Refresh the cached referee data so the checks below do not use stale values
        self.get_referee_state()

        # Only carry out the "normal" loop during regular gameplay
        if self.ref_command == "RUNNING":  # Maybe this needs to change to normal_start
            attackers, defenders = (int(value) for value in action)

            # Ensure non-negative values and a total of MAX_ROBOTS_US
            attackers = max(0, min(attackers, self.MAX_ROBOTS_US))
            defenders = max(0, min(defenders, self.MAX_ROBOTS_US - attackers))
            wallers = self.MAX_ROBOTS_US - (attackers + defenders)

            # Sends the action command over proto to legacy AI
            send_action_command(num_attacker=attackers, num_defender=defenders, num_waller=wallers)

        # If the game is halted, stopped or ball placement is happening:
        # teleport the ball when either side has ball placement
        self.check_ball_placement()

        # Read the new state first so the reward reflects the result of this step
        observation = self._read_observation()
        reward = self.calculate_reward()

        terminated = self.is_terminated()  # A goal was scored
        truncated = self.is_truncated()    # Episode cut short (time limit, ...)

        return observation, reward, terminated, truncated, {}

    def is_terminated(self):
        """Return True when the episode ended because either team scored."""
        # HALT command indicates that either team scored
        return self.ref_command == "HALT" and (self.yellow_score == 1 or self.blue_score == 1)

    def is_truncated(self):
        """Return True when the episode must end prematurely (e.g. after 5 minutes).

        Not implemented yet, so the episode is never truncated.
        """
        # TODO: implement logic to reset the game if no goal is scored
        return False

    def reset(self, seed=None, options=None):
        """Reset the simulation and return ``(observation, info)`` for a new episode."""
        super().reset(seed=seed)

        # Teleport ball to middle position
        teleport_ball(0, 0)

        # This resets the cards, goals and initiates a kickoff.
        reset_referee_state()

        # Mirror that reset in the cached referee fields so the previous
        # episode's goal is not rewarded again
        self.yellow_yellow_cards = 0
        self.blue_yellow_cards = 0
        self.yellow_score = 0
        self.blue_score = 0

        # Reset reward-shaping flags
        self.shaped_reward_given = False
        self.is_yellow_dribbling = False
        self.is_blue_dribbling = False

        return self._read_observation(), {}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import sys | ||
import os | ||
import zmq | ||
from google.protobuf.message import DecodeError | ||
import numpy as np | ||
|
||
''' | ||
GetState.py is a script to get the state of the game (where the robots and ball are and more information). | ||
This data will be fed into the RL. It uses the proto:State object. | ||
The two functions get_ball_state and get_robot_state get the states of the ball and robots respectively. | ||
''' | ||
|
||
# Make sure to go back to the main roboteam directory | ||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) | ||
|
||
# Add to sys.path | ||
sys.path.append(roboteam_path) | ||
|
||
# Now import the generated protobuf classes | ||
from roboteam_networking.proto.State_pb2 import State | ||
|
||
def get_ball_state():
    """Receive one State message and report the ball position and its quadrant.

    Returns:
        tuple: ``(ball_position, ball_quadrant)``. ``ball_position`` is a NumPy
        array ``[x, y]``; ``ball_quadrant`` is 0-3 for the four quadrants, 4 for
        the center, or -1 when no ball information could be read.
    """
    center_tolerance = 0.01  # Largest |x| and |y| still treated as center; adjust as needed

    # Values handed back unchanged when nothing usable is received
    ball_xy = np.zeros(2)  # [x, y]
    quadrant = -1  # Invalid quadrant

    # Subscribe to every message published on the endpoint
    zmq_context = zmq.Context()
    subscriber = zmq_context.socket(zmq.SUB)
    subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
    subscriber.connect("tcp://127.0.0.1:5558")

    try:
        state = State.FromString(subscriber.recv())

        # Without processed vision packets, or without a ball, keep the defaults
        has_vision = len(state.processed_vision_packets) != 0
        if has_vision and state.last_seen_world.HasField("ball"):
            ball = state.last_seen_world.ball
            ball_xy[0] = ball.pos.x
            ball_xy[1] = ball.pos.y

            # Assuming the field is centered at (0,0) and extends from -6 to 6 in x and -4.5 to 4.5 in y
            ball_x, ball_y = ball_xy
            if abs(ball_x) <= center_tolerance and abs(ball_y) <= center_tolerance:
                quadrant = 4  # Center
            else:
                # Left half contributes 0, right half 1; upper half 0, lower half 2
                quadrant = (0 if ball_x < 0 else 1) + (0 if ball_y > 0 else 2)

    except DecodeError:
        print("Failed to decode protobuf message")
    except zmq.ZMQError as e:
        print(f"ZMQ Error: {e}")
    finally:
        subscriber.close()
        zmq_context.term()

    return ball_xy, quadrant
|
||
def get_robot_state():
    """Receive one State message and summarise robot placement and dribbling.

    Returns:
        tuple: ``(grid_array, yellow_team_dribbling, blue_team_dribbling)``.
        ``grid_array`` is a 4x2 integer array: one row per grid cell, column 0
        counting yellow robots and column 1 counting blue robots. Each flag is
        True when any robot of that team (seen or unseen) reports
        ``dribbler_sees_ball``.
    """
    # Defaults returned when nothing usable is received
    yellow_dribbles = False
    blue_dribbles = False
    counts = np.zeros((4, 2), dtype=int)

    # Subscribe to every message published on the endpoint
    zmq_context = zmq.Context()
    subscriber = zmq_context.socket(zmq.SUB)
    subscriber.setsockopt_string(zmq.SUBSCRIBE, "")
    subscriber.connect("tcp://127.0.0.1:5558")

    try:
        state = State.FromString(subscriber.recv())

        if len(state.processed_vision_packets) == 0:
            return counts, yellow_dribbles, blue_dribbles

        world = state.last_seen_world

        def cell_index(x, y):
            # x < 0 selects column offset 0, otherwise 1; y > 0 selects row offset 0, otherwise 2
            return (0 if x < 0 else 1) + (0 if y > 0 else 2)

        # Count the visible robots of each team per grid cell
        for bot in world.yellow:
            counts[cell_index(bot.pos.x, bot.pos.y), 0] += 1
        for bot in world.blue:
            counts[cell_index(bot.pos.x, bot.pos.y), 1] += 1

        # A team is dribbling if any of its robots, seen or unseen, sees the ball
        yellow_dribbles = (
            any(bot.feedbackInfo.dribbler_sees_ball for bot in world.yellow)
            or any(bot.feedbackInfo.dribbler_sees_ball for bot in world.yellow_unseen_robots)
        )
        blue_dribbles = (
            any(bot.feedbackInfo.dribbler_sees_ball for bot in world.blue)
            or any(bot.feedbackInfo.dribbler_sees_ball for bot in world.blue_unseen_robots)
        )

    except DecodeError:
        print("Failed to decode protobuf message")
    except zmq.ZMQError as e:
        print(f"ZMQ Error: {e}")
    finally:
        subscriber.close()
        zmq_context.term()

    return counts, yellow_dribbles, blue_dribbles
|
||
if __name__ == "__main__":
    # Manual smoke test: fetch one robot state and print it in readable form
    grid_array, yellow_team_dribbling, blue_team_dribbling = get_robot_state()
    print("Grid-based Robot Array:")
    print(grid_array)
    print("\nInterpretation:")
    # Labels follow the row indices assigned in get_robot_state:
    # 0 = x < 0, y > 0; 1 = x >= 0, y > 0; 2 = x < 0, y <= 0; 3 = x >= 0, y <= 0
    quadrants = ["Top-Left", "Top-Right", "Bottom-Left", "Bottom-Right"]
    for i, quadrant in enumerate(quadrants):
        print(f"{quadrant}: {grid_array[i, 0]} yellow robots, {grid_array[i, 1]} blue robots")

    print("\nDribbler Information:")
    print(f"Yellow Team: {'Dribbling' if yellow_team_dribbling else 'Not Dribbling'}")
    print(f"Blue Team: {'Dribbling' if blue_team_dribbling else 'Not Dribbling'}")
Oops, something went wrong.