diff --git a/roboteam_ai/src/RL/getRefereeState.py b/roboteam_ai/src/RL/getRefereeState.py
new file mode 100644
index 000000000..4b9fddf05
--- /dev/null
+++ b/roboteam_ai/src/RL/getRefereeState.py
@@ -0,0 +1,95 @@
+import os
+import sys
+import socket
+import struct
+import binascii
+from google.protobuf.json_format import MessageToJson
+
+'''
+getRefereeState.py is a script to get state of the referee.
+This includes the current command, designated position for ball placement, and the score for both teams.
+'''
+
+# Make sure to go back to the main roboteam directory
+current_dir = os.path.dirname(os.path.abspath(__file__))
+roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
+
+# Add to sys.path
+sys.path.append(roboteam_path)
+
+# Now import the generated protobuf classes
+from roboteam_networking.proto.ssl_gc_referee_message_pb2 import Referee
+from roboteam_networking.proto.ssl_gc_game_event_pb2 import GameEvent
+
+MULTICAST_GROUP = '224.5.23.1'
+MULTICAST_PORT = 10003
+
+
+def get_referee_state(timeout=None):
+    """Receive one Referee datagram and return the fields of interest.
+
+    Joins MULTICAST_GROUP on MULTICAST_PORT, waits for a single UDP packet,
+    parses it as a Referee protobuf message and closes the socket again.
+
+    Args:
+        timeout: Seconds to wait for a packet. None (the default) blocks
+            until a packet arrives.
+
+    Returns:
+        A tuple (command, pos_x, pos_y, yellow_score, blue_score). command is
+        the name of the Referee.Command value, pos_x/pos_y come from
+        designated_position when that field is set, and the scores come from
+        the yellow and blue team info. Values that could not be read because
+        receiving (including a timeout) or parsing failed stay None.
+
+    Raises:
+        OSError: If the socket cannot be bound or joined to the group.
+    """
+    command = None
+    pos_x = None
+    pos_y = None
+    yellow_score = None
+    blue_score = None
+
+    # The context manager closes the socket even when bind() or the multicast
+    # join raises, which previously leaked the descriptor.
+    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.settimeout(timeout)
+
+        # Bind to the server address
+        sock.bind(('', MULTICAST_PORT))
+
+        # Tell the operating system to add the socket to the multicast group.
+        # '!4sI' packs exactly the 8-byte struct ip_mreq; native 'L' is 8 bytes
+        # wide on 64-bit Linux, so '4sL' yielded a padded 16-byte value.
+        group = socket.inet_aton(MULTICAST_GROUP)
+        mreq = struct.pack('!4sI', group, socket.INADDR_ANY)
+        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
+
+        print(f"Listening for Referee messages on {MULTICAST_GROUP}:{MULTICAST_PORT}")
+
+        try:
+            data, _ = sock.recvfrom(4096)  # Increased buffer size to 4096 bytes
+            referee = Referee()
+            referee.ParseFromString(data)
+            command = Referee.Command.Name(referee.command)
+
+            if referee.HasField('designated_position'):
+                pos_x = referee.designated_position.x
+                pos_y = referee.designated_position.y
+
+            yellow_score = referee.yellow.score
+            blue_score = referee.blue.score
+        except Exception as e:
+            print(f"Error parsing message: {e}")
+
+    return command, pos_x, pos_y, yellow_score, blue_score
+
+
+if __name__ == "__main__":
+    command, pos_x, pos_y, yellow_score, blue_score = get_referee_state()
+    print(f"Command: {command}")
+    print(f"Designated Position: ({pos_x}, {pos_y})")
+    print(f"Yellow Team Score: {yellow_score}")
+    print(f"Blue Team Score: {blue_score}")
\ No newline at end of file
diff --git a/roboteam_ai/src/RL/receiveActionCommand.cpp b/roboteam_ai/src/RL/receiveActionCommand.cpp
new file mode 100644
index 000000000..1a11dc588
--- /dev/null
+++ b/roboteam_ai/src/RL/receiveActionCommand.cpp
@@ -0,0 +1,41 @@
+#include <iostream>
+#include <zmq.hpp>
+
+#include "ActionCommand.pb.h"
+
+// Subscribes to the ActionCommand publisher on tcp://localhost:5555 and prints
+// every received message that carries exactly three numrobots values.
+int main() {
+    zmq::context_t context(1);
+    zmq::socket_t socket(context, ZMQ_SUB);
+
+    std::cout << "Connecting to ActionCommand sender..." << std::endl;
+    socket.connect("tcp://localhost:5555");
+    socket.setsockopt(ZMQ_SUBSCRIBE, "", 0);
+
+    std::cout << "ActionCommand receiver started. Ctrl+C to exit." << std::endl;
+
+    while (true) {
+        zmq::message_t message;
+        socket.recv(&message);
+
+        ActionCommand action_command;
+        if (!action_command.ParseFromArray(message.data(), message.size())) {
+            std::cerr << "Failed to parse ActionCommand." << std::endl;
+            continue;
+        }
+
+        if (action_command.numrobots_size() != 3) {
+            std::cerr << "Received incorrect number of values. Expected 3, got "
+                      << action_command.numrobots_size() << std::endl;
+            continue;
+        }
+
+        std::cout << "Received: ["
+                  << action_command.numrobots(0) << ", "
+                  << action_command.numrobots(1) << ", "
+                  << action_command.numrobots(2) << "]" << std::endl;
+    }
+
+    return 0;
+}
\ No newline at end of file
diff --git a/roboteam_ai/src/RL/resetRefereeState.py b/roboteam_ai/src/RL/resetRefereeState.py
new file mode 100644
index 000000000..f57f832bc
--- /dev/null
+++ b/roboteam_ai/src/RL/resetRefereeState.py
@@ -0,0 +1,63 @@
+"""
+resetRefereeState is a script to reset the referee state of the game (it basically resets the match (clock))
+"""
+
+import asyncio
+import websockets
+from google.protobuf.json_format import MessageToJson, Parse
+import os
+import sys
+
+# Make sure to go back to the main roboteam directory
+current_dir = os.path.dirname(os.path.abspath(__file__))
+roboteam_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
+
+# Add to sys.path
+sys.path.append(roboteam_path)
+
+# Now import the generated protobuf classes
+from roboteam_networking.proto.ssl_gc_api_pb2 import Input
+from roboteam_networking.proto.ssl_gc_change_pb2 import Change
+from roboteam_networking.proto.ssl_gc_common_pb2 import Team
+from roboteam_networking.proto.ssl_gc_state_pb2 import Command
+
+
+async def reset_and_stop_match(uri='ws://localhost:8081/api/control'):
+    """Reset the match on the SSL Game Controller and then issue STOP.
+
+    Sends an Input with reset_match set, waits for the next server message,
+    then sends a NewCommand change carrying a STOP command for Team.UNKNOWN
+    and waits for one more message. Both inputs are sent as protobuf JSON.
+
+    Args:
+        uri: WebSocket URI to connect to.
+    """
+    async with websockets.connect(uri) as websocket:
+
+        # Step 1: Reset the match
+        reset_message = Input(reset_match=True)
+        await websocket.send(MessageToJson(reset_message))
+        # Wait for the next message before continuing; its content is not inspected.
+        await websocket.recv()
+
+        # Step 2: Send STOP command
+        stop_message = Input(
+            change=Change(
+                new_command_change=Change.NewCommand(
+                    command=Command(
+                        type=Command.Type.STOP,
+                        for_team=Team.UNKNOWN
+                    )
+                )
+            )
+        )
+        stop_json = MessageToJson(stop_message)
+        await websocket.send(stop_json)
+        print(f"Sent STOP command: {stop_json}")
+        await websocket.recv()
+
+        print("Reset and STOP commands sent to SSL Game Controller")
+
+
+if __name__ == "__main__":
+    asyncio.run(reset_and_stop_match())
\ No newline at end of file
diff --git a/roboteam_ai/src/RL/train.py b/roboteam_ai/src/RL/train.py
new file mode 100644
index 000000000..467a421d4
--- /dev/null
+++ b/roboteam_ai/src/RL/train.py
@@ -0,0 +1,24 @@
+import gymnasium as gym
+from stable_baselines3 import PPO
+
+# Import your custom environment
+from env import RoboTeamEnv
+
+# Create the environment
+env = RoboTeamEnv()
+
+# Create and train the PPO model
+model = PPO("MultiInputPolicy", env, verbose=1)
+model.learn(total_timesteps=1000)
+
+# Test the trained model
+obs, _ = env.reset()
+for _ in range(1000):
+    action, _states = model.predict(obs, deterministic=True)
+    obs, reward, terminated, truncated, info = env.step(action)
+    env.render()
+    if terminated or truncated:
+        obs, _ = env.reset()
+
+env.close()
+
diff --git a/roboteam_mpi/README.md b/roboteam_mpi/README.md
new file mode 100644
index 000000000..c8f3fc57b
--- /dev/null
+++ b/roboteam_mpi/README.md
@@ -0,0 +1,13 @@
+## Explanation
+roboteam_mpi is meant to house all the communication code for MPI (Message Passing Interface) so that it can run on HPC clusters.
+
+
+This will be a framework that other teams can attach their AI to.
+
+
+## Core components
+
+### MPIManager
+Handles initialization, finalization and standard operations
+
+
diff --git a/roboteam_mpi/mpi_combined.py b/roboteam_mpi/mpi_combined.py
new file mode 100644
index 000000000..7d782f99e
--- /dev/null
+++ b/roboteam_mpi/mpi_combined.py
@@ -0,0 +1,41 @@
+# mpi_combined.py
+from mpi4py import MPI
+import sys
+
+
+def main():
+    """Send the number 42 from rank 0 to rank 1 and log each step.
+
+    Every rank reports its rank and the world size. Ranks other than 0 and 1
+    print only that line. At least two processes are required.
+    """
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+
+    print(f"Process {rank}: I am rank {rank} out of {size} processes")
+    sys.stdout.flush()
+
+    # Without a rank 1 the send below targets a rank that does not exist.
+    if size < 2:
+        print(f"Process {rank}: Need at least 2 processes, got {size}")
+        sys.stdout.flush()
+        return
+
+    if rank == 0:
+        number = 42
+        print(f"Process {rank}: Sending number {number} to rank 1")
+        sys.stdout.flush()
+        comm.send(number, dest=1, tag=11)
+        print(f"Process {rank}: Number {number} sent to rank 1")
+        sys.stdout.flush()
+    elif rank == 1:
+        print(f"Process {rank}: Waiting to receive number from rank 0")
+        sys.stdout.flush()
+        number = comm.recv(source=0, tag=11)
+        print(f"Process {rank}: Received number {number} from rank 0")
+        sys.stdout.flush()
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/roboteam_mpi/mpi_message b/roboteam_mpi/mpi_message
new file mode 100755
index 000000000..6a8e71955
Binary files /dev/null and b/roboteam_mpi/mpi_message differ
diff --git a/roboteam_mpi/mpi_message.cpp b/roboteam_mpi/mpi_message.cpp
new file mode 100644
index 000000000..503ab94e0
--- /dev/null
+++ b/roboteam_mpi/mpi_message.cpp
@@ -0,0 +1,42 @@
+#include <mpi.h>
+
+#include <iostream>
+
+// Sends the integer 42 from rank 0 to rank 1 over MPI_COMM_WORLD.
+int main(int argc, char *argv[]) {
+    MPI_Init(&argc, &argv);
+
+    int world_size;  // World size is total amount of processes
+    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
+
+    int world_rank;
+    MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
+
+    // Rank 0 sends to rank 1, so a single-process run has no valid destination.
+    if (world_size < 2) {
+        std::cerr << "This program needs at least 2 processes, got " << world_size << "\n";
+        MPI_Finalize();
+        return 1;
+    }
+
+    if (world_rank == 0) {
+        // Sending a message
+        const int message = 42;
+        MPI_Send(&message,  // buffer
+                 1,         // count
+                 MPI_INT,   // datatype
+                 1,         // destination rank
+                 0,         // tag
+                 MPI_COMM_WORLD);
+
+        std::cout << "Process 0 sends number " << message << " to process 1\n";
+    } else if (world_rank == 1) {
+        // Receiving a message
+        int received_message;
+        MPI_Recv(&received_message, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+        std::cout << "Process 1 received number " << received_message << " from process 0\n";
+    }
+
+    MPI_Finalize();
+    return 0;
+}
\ No newline at end of file
diff --git a/roboteam_mpi/mpi_receiver.py b/roboteam_mpi/mpi_receiver.py
new file mode 100644
index 000000000..59cb0c90a
--- /dev/null
+++ b/roboteam_mpi/mpi_receiver.py
@@ -0,0 +1,19 @@
+from mpi4py import MPI
+
+
+def main():
+    """Print every message rank 1 receives from rank 0 with tag 11.
+
+    The receive loop never exits on its own; other ranks return immediately.
+    """
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+
+    if rank == 1:
+        while True:
+            message = comm.recv(source=0, tag=11)
+            print(f"Receiver: {message}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/roboteam_mpi/mpi_sender.py b/roboteam_mpi/mpi_sender.py
new file mode 100644
index 000000000..bd14fbd4a
--- /dev/null
+++ b/roboteam_mpi/mpi_sender.py
@@ -0,0 +1,16 @@
+from mpi4py import MPI
+
+
+def main():
+    """Send the string "Test" from rank 0 to rank 1 with tag 11."""
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+
+    if rank == 0:
+        message = "Test"
+        comm.send(message, dest=1, tag=11)
+        print(f"Sender: {message}")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/roboteam_mpi/mpi_test.py b/roboteam_mpi/mpi_test.py
new file mode 100644
index 000000000..9eea25c76
--- /dev/null
+++ b/roboteam_mpi/mpi_test.py
@@ -0,0 +1,7 @@
+from mpi4py import MPI
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+size = comm.Get_size()
+
+print(f"Hello from process {rank} out of {size} processes")
\ No newline at end of file
diff --git a/roboteam_mpi/src/MPIManager.cpp b/roboteam_mpi/src/MPIManager.cpp
new file mode 100644
index 000000000..60e41b04d
--- /dev/null
+++ b/roboteam_mpi/src/MPIManager.cpp
@@ -0,0 +1,62 @@
+#include "MPIManager.h"
+
+#include <stdexcept>
+
+// Definitions of the static members declared in MPIManager.h.
+bool MPIManager::initialized = false;
+int MPIManager::rank = 0;
+int MPIManager::size = 0;
+
+// Starts MPI once and caches this process's rank and the world size.
+void MPIManager::init(int& argc, char**& argv) {
+    if (!initialized) {
+        // MPI may be initialized only once per process, so an init() that
+        // follows finalize() must not reach MPI_Init again.
+        int finalized = 0;
+        MPI_Finalized(&finalized);
+        if (finalized) throw std::runtime_error("MPI already finalized");
+
+        MPI_Init(&argc, &argv);
+        MPI_Comm_rank(MPI_COMM_WORLD, &rank); // Unique ID for every process
+        MPI_Comm_size(MPI_COMM_WORLD, &size); // Total number of processes
+        initialized = true;
+    }
+}
+
+// Shuts MPI down if init() was called; otherwise does nothing.
+void MPIManager::finalize() {
+    if (initialized) {
+        MPI_Finalize();
+        initialized = false;
+    }
+}
+
+// Returns the rank cached by init(). Throws std::runtime_error before init().
+int MPIManager::getRank() {
+    if (!initialized) throw std::runtime_error("MPI not initialized");
+    return rank;
+}
+
+// Returns the world size cached by init(). Throws std::runtime_error before init().
+int MPIManager::getSize() {
+    if (!initialized) throw std::runtime_error("MPI not initialized");
+    return size;
+}
+
+// Blocking send of `count` elements to rank `dest` on MPI_COMM_WORLD.
+void MPIManager::send(const void* data, int count, MPI_Datatype datatype, int dest, int tag) {
+    if (!initialized) throw std::runtime_error("MPI not initialized");
+    MPI_Send(data, count, datatype, dest, tag, MPI_COMM_WORLD);
+}
+
+// Blocking receive of up to `count` elements from rank `source`; the status is ignored.
+void MPIManager::recv(void* data, int count, MPI_Datatype datatype, int source, int tag) {
+    if (!initialized) throw std::runtime_error("MPI not initialized");
+    MPI_Recv(data, count, datatype, source, tag, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
+}
+
+// Broadcasts `count` elements from rank `root` to every rank in MPI_COMM_WORLD.
+void MPIManager::bcast(void* data, int count, MPI_Datatype datatype, int root) {
+    if (!initialized) throw std::runtime_error("MPI not initialized");
+    MPI_Bcast(data, count, datatype, root, MPI_COMM_WORLD);
+}
diff --git a/roboteam_mpi/src/MPIManager.h b/roboteam_mpi/src/MPIManager.h
new file mode 100644
index 000000000..75a365a89
--- /dev/null
+++ b/roboteam_mpi/src/MPIManager.h
@@ -0,0 +1,27 @@
+#ifndef MPI_MANAGER_H
+#define MPI_MANAGER_H
+
+#include <mpi.h>
+
+// Static-only wrapper around MPI start-up, shutdown and basic messaging on
+// MPI_COMM_WORLD.
+class MPIManager {
+public:
+    static void init(int& argc, char**& argv);
+    static void finalize();
+    static int getRank();
+    static int getSize();
+
+    static void send(const void* data, int count, MPI_Datatype datatype, int dest, int tag);
+    static void recv(void* data, int count, MPI_Datatype datatype, int source, int tag);
+    static void bcast(void* data, int count, MPI_Datatype datatype, int root);
+
+private:
+    static bool initialized;
+    static int rank;
+    static int size;
+
+    MPIManager() = delete; // Prevent instantiation
+};
+
+#endif // MPI_MANAGER_H
\ No newline at end of file
diff --git a/ssl-game-controller b/ssl-game-controller
new file mode 160000
index 000000000..feb9b7636
--- /dev/null
+++ b/ssl-game-controller
@@ -0,0 +1 @@
+Subproject commit feb9b76361f30a9e9b476bc1f52516df0f33de17