From da39fdb4287a61472b2d784e3c0be2d97f058019 Mon Sep 17 00:00:00 2001 From: pierfabre Date: Sun, 14 Jan 2024 12:20:03 +0100 Subject: [PATCH] adjust gains on the robot, move files and rename folders --- .pre-commit-config.yaml | 2 +- README.md | 11 +- {furuta_gym => furuta}/__init__.py | 0 .../controls}/controllers.py | 0 .../hardware => furuta/logging}/__init__.py | 0 .../logging/protobuf}/__init__.py | 0 .../logging/protobuf/pendulum_state.proto | 2 +- furuta/logging/protobuf/pendulum_state_pb2.py | 25 ++ {furuta_gym/envs => furuta/rl}/__init__.py | 0 {furuta_gym => furuta/rl}/algos.py | 0 {furuta_gym => furuta/rl}/envs/furuta_base.py | 2 +- {furuta_gym => furuta/rl}/envs/furuta_real.py | 33 +- {furuta_gym => furuta/rl}/envs/furuta_sim.py | 2 +- {furuta_gym => furuta/rl}/wrappers.py | 4 +- {furuta_gym => furuta}/robot.py | 4 +- {furuta_gym => furuta}/utils.py | 0 furuta_gym/envs/hardware/LS7366R.py | 116 ------- furuta_gym/envs/hardware/motor.py | 60 ---- furuta_gym/logging/protobuf/__init__.py | 0 .../logging/protobuf/pendulum_state_pb2.py | 25 -- pyproject.toml | 4 +- scripts/configs/parameters.json | 10 +- scripts/control.py | 23 +- scripts/replay_mcap.py | 14 +- scripts/robot_inference.py | 99 +++--- scripts/train.py | 4 +- scripts/train_sac.py | 290 ++++++++++-------- 27 files changed, 305 insertions(+), 425 deletions(-) rename {furuta_gym => furuta}/__init__.py (100%) rename {robot/software => furuta/controls}/controllers.py (100%) rename {furuta_gym/envs/hardware => furuta/logging}/__init__.py (100%) rename {furuta_gym/logging => furuta/logging/protobuf}/__init__.py (100%) rename {furuta_gym => furuta}/logging/protobuf/pendulum_state.proto (99%) create mode 100644 furuta/logging/protobuf/pendulum_state_pb2.py rename {furuta_gym/envs => furuta/rl}/__init__.py (100%) rename {furuta_gym => furuta/rl}/algos.py (100%) rename {furuta_gym => furuta/rl}/envs/furuta_base.py (99%) rename {furuta_gym => furuta/rl}/envs/furuta_real.py (84%) rename {furuta_gym => furuta/rl}/envs/furuta_sim.py (98%) rename {furuta_gym => furuta/rl}/wrappers.py (98%) rename {furuta_gym => furuta}/robot.py (90%) rename {furuta_gym => furuta}/utils.py (100%) delete mode 100644 furuta_gym/envs/hardware/LS7366R.py delete mode 100755 furuta_gym/envs/hardware/motor.py delete mode 100644 furuta_gym/logging/protobuf/__init__.py delete mode 100644 furuta_gym/logging/protobuf/pendulum_state_pb2.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1c5207..9cf1b56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,7 +55,7 @@ repos: "--extend-ignore", "E203,E402,E501,F401,F841", "--exclude", - "logs/*,data/*,furuta_gym/logging/protobuf/*", + "logs/*,data/*,furuta/logging/protobuf/*", ] # yaml formatting diff --git a/README.md b/README.md index 9021cbc..288ae80 100644 --- a/README.md +++ b/README.md @@ -20,25 +20,30 @@ In this repository you will find everything you need to build and train a rotary If you have any question feel free to open an issue or DM me [@armand_dpl](twitter.com/armand_dpl). +## Usage +1. Plug-in the robot +2. Run `sudo dmesg | grep tty` in terminal to find which port is used for the device +3. run `python tests/interactive_robot_self_test.py` + ## MLOps During this project we leveraged [Weights and Biases](https://wandb.ai/site) MLOps tools to make our life easier. You can find our experiments, pre-trained models and reports [on our dashboard](https://wandb.ai/armandpl/furuta). 
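The self-test in the new Usage section boils down to opening the serial link and exchanging one frame. A minimal sketch, assuming the board enumerated as /dev/ttyACM0 (whatever `sudo dmesg | grep tty` reported) and using the Robot class this patch moves to furuta/robot.py:

```python
# Quick connectivity check for the real robot; adjust the device path to the
# port reported by dmesg. Robot.step(0) sends a zero command and reads back
# the two encoder angles, as done at the top of scripts/control.py.
from furuta.robot import Robot

robot = Robot(device="/dev/ttyACM0", baudrate=921600)
robot.reset_encoders()
motor_angle, pendulum_angle = robot.step(0)
print(motor_angle, pendulum_angle)
```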
You can also read more about [Training Reproducible Robots with W&B here](https://wandb.ai/armandpl/furuta/reports/Training-Reproducible-Robots-with-W-B--VmlldzoxMTY5NTM5). ## Credits To make this robot work we built on top of existing work! - We got the encoder precision and the idea to use a direct drive motor from the [Quanser Qube design](https://quanserinc.box.com/shared/static/5wnibclu7rp6xihm7mbxqxincu6dogur.pdf). -- We re-used bits from [Quanser's code](https://git.ias.informatik.tu-darmstadt.de/quanser/clients/-/tree/master/quanser_robots/qube). Notably: +- We re-used bits from [Quanser's code](https://git.ias.informatik.tu-darmstadt.de/quanser/clients/-/tree/master/quanser_robots/qube). Notably: * their VelocityFilter class to compute the angular speeds * their GentlyTerminating wrapper to send a zero command to the robot at the end of each episode * their rotary inverted pendulum simulation * their ActionLimiter class - The arm assembly is inspired by this [YouTube video](https://www.youtube.com/watch?v=xowrt6ShdCw) by Mack Tang. - The visualization we use for the simulation is copy-pasted from https://github.com/angelolovatto/gym-cartpole-swingup -- We use the [StableBaselines3](https://github.com/DLR-RM/stable-baselines3) library to train the robot. +- We use the [StableBaselines3](https://github.com/DLR-RM/stable-baselines3) library to train the robot. - We implemented tricks from [Antonin Raffin's talk at RLVS 2021](https://www.youtube.com/watch?v=Ikngt0_DXJg). * HistoryWrapper and continuity cost * [gSDE](https://arxiv.org/abs/2005.05719) - We use [code from Federico Bolanos](https://github.com/fbolanos/LS7366R/blob/master/LS7366R.py) to read the encoders counters. ## Authors -[Armand du Parc Locmaria](https://armandpl.com) +[Armand du Parc Locmaria](https://armandpl.com) [Pierre Fabre](https://www.linkedin.com/in/p-fabre/) diff --git a/furuta_gym/__init__.py b/furuta/__init__.py similarity index 100% rename from furuta_gym/__init__.py rename to furuta/__init__.py diff --git a/robot/software/controllers.py b/furuta/controls/controllers.py similarity index 100% rename from robot/software/controllers.py rename to furuta/controls/controllers.py diff --git a/furuta_gym/envs/hardware/__init__.py b/furuta/logging/__init__.py similarity index 100% rename from furuta_gym/envs/hardware/__init__.py rename to furuta/logging/__init__.py diff --git a/furuta_gym/logging/__init__.py b/furuta/logging/protobuf/__init__.py similarity index 100% rename from furuta_gym/logging/__init__.py rename to furuta/logging/protobuf/__init__.py diff --git a/furuta_gym/logging/protobuf/pendulum_state.proto b/furuta/logging/protobuf/pendulum_state.proto similarity index 99% rename from furuta_gym/logging/protobuf/pendulum_state.proto rename to furuta/logging/protobuf/pendulum_state.proto index eaccc1b..1ca63a5 100644 --- a/furuta_gym/logging/protobuf/pendulum_state.proto +++ b/furuta/logging/protobuf/pendulum_state.proto @@ -9,4 +9,4 @@ message PendulumState { bool done = 6; float action = 7; float corrected_action = 8; -} \ No newline at end of file +} diff --git a/furuta/logging/protobuf/pendulum_state_pb2.py b/furuta/logging/protobuf/pendulum_state_pb2.py new file mode 100644 index 0000000..3cdfa87 --- /dev/null +++ b/furuta/logging/protobuf/pendulum_state_pb2.py @@ -0,0 +1,25 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: pendulum_state.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x14pendulum_state.proto"\xc3\x01\n\rPendulumState\x12\x13\n\x0bmotor_angle\x18\x01 \x01(\x02\x12\x16\n\x0ependulum_angle\x18\x02 \x01(\x02\x12\x1c\n\x14motor_angle_velocity\x18\x03 \x01(\x02\x12\x1f\n\x17pendulum_angle_velocity\x18\x04 \x01(\x02\x12\x0e\n\x06reward\x18\x05 \x01(\x02\x12\x0c\n\x04\x64one\x18\x06 \x01(\x08\x12\x0e\n\x06\x61\x63tion\x18\x07 \x01(\x02\x12\x18\n\x10\x63orrected_action\x18\x08 \x01(\x02\x62\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "pendulum_state_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _PENDULUMSTATE._serialized_start = 25 + _PENDULUMSTATE._serialized_end = 220 +# @@protoc_insertion_point(module_scope) diff --git a/furuta_gym/envs/__init__.py b/furuta/rl/__init__.py similarity index 100% rename from furuta_gym/envs/__init__.py rename to furuta/rl/__init__.py diff --git a/furuta_gym/algos.py b/furuta/rl/algos.py similarity index 100% rename from furuta_gym/algos.py rename to furuta/rl/algos.py diff --git a/furuta_gym/envs/furuta_base.py b/furuta/rl/envs/furuta_base.py similarity index 99% rename from furuta_gym/envs/furuta_base.py rename to furuta/rl/envs/furuta_base.py index a778bef..bdd4246 100644 --- a/furuta_gym/envs/furuta_base.py +++ b/furuta/rl/envs/furuta_base.py @@ -6,7 +6,7 @@ import numpy as np from gym.spaces import Box -from furuta_gym.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, Timing +from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, Timing def alpha_reward(state): diff --git a/furuta_gym/envs/furuta_real.py b/furuta/rl/envs/furuta_real.py similarity index 84% rename from furuta_gym/envs/furuta_real.py rename to furuta/rl/envs/furuta_real.py index 80a0b1e..f4e9af7 100644 --- a/furuta_gym/envs/furuta_real.py +++ b/furuta/rl/envs/furuta_real.py @@ -4,21 +4,25 @@ import numpy as np -from furuta_gym.common import VelocityFilter -from furuta_gym.envs.furuta_base import FurutaBase -from furuta_gym.envs.hardware.robot import FurutaRobot +from furuta.common import VelocityFilter +from furuta.envs.furuta_base import FurutaBase +from furuta.envs.hardware.robot import FurutaRobot class FurutaReal(FurutaBase): - - def __init__(self, fs=100, fs_ctrl=100, - action_limiter=False, safety_th_lim=1.5, - reward="simple", state_limits='low', - config_file="robot.ini"): - super().__init__(fs, fs_ctrl, action_limiter, safety_th_lim, - reward, state_limits) - - self.robot = FurutaRobot() + def __init__( + self, + fs=100, + fs_ctrl=100, + action_limiter=False, + safety_th_lim=1.5, + reward="simple", + state_limits="low", + config_file="robot.ini", + ): + super().__init__(fs, fs_ctrl, action_limiter, safety_th_lim, reward, state_limits) + + self.robot = FurutaRobot() self.vel_filt = VelocityFilter(2, dt=self.timing.dt) @@ -61,7 +65,6 @@ def get_state(self): def _reset_pendulum(self, tolerance=10, still_time=1, clear=True): pass - def reset(self): logging.info("Reset env...") # reset pendulum @@ -69,7 +72,7 @@ def reset(self): # reset motor 
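The regenerated pendulum_state_pb2 module above is what the MCAP logging wrapper serializes on every step. A short round-trip sketch, using only the field names declared in pendulum_state.proto:

```python
from furuta.logging.protobuf.pendulum_state_pb2 import PendulumState

msg = PendulumState(
    motor_angle=0.12,
    pendulum_angle=3.14,
    motor_angle_velocity=0.0,
    pendulum_angle_velocity=0.5,
    reward=1.0,
    done=False,
    action=0.2,
    corrected_action=0.18,
)
payload = msg.SerializeToString()            # bytes handed to the MCAP writer
decoded = PendulumState.FromString(payload)  # parse the frame back
```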
logging.debug("Reset motor") while True: - state = self._read_state()*180/np.pi + state = self._read_state() * 180 / np.pi motor_angle = state[0] motor_speed = state[2] @@ -83,7 +86,7 @@ def reset(self): elif motor_angle > 0: self.motor.set_speed(0.3) - sleep(10/100) + sleep(10 / 100) self.motor.set_speed(0) diff --git a/furuta_gym/envs/furuta_sim.py b/furuta/rl/envs/furuta_sim.py similarity index 98% rename from furuta_gym/envs/furuta_sim.py rename to furuta/rl/envs/furuta_sim.py index 9d06dad..a6ae5e4 100644 --- a/furuta_gym/envs/furuta_sim.py +++ b/furuta/rl/envs/furuta_sim.py @@ -4,7 +4,7 @@ import numpy as np from numpy.linalg import inv -from furuta_gym.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, VelocityFilter +from furuta.utils import ALPHA, ALPHA_DOT, THETA, THETA_DOT, VelocityFilter from .furuta_base import FurutaBase diff --git a/furuta_gym/wrappers.py b/furuta/rl/wrappers.py similarity index 98% rename from furuta_gym/wrappers.py rename to furuta/rl/wrappers.py index 500ed8b..7d9003f 100644 --- a/furuta_gym/wrappers.py +++ b/furuta/rl/wrappers.py @@ -5,11 +5,11 @@ import gym import numpy as np +import wandb from gym.spaces import Box from mcap_protobuf.writer import Writer -import wandb -from furuta_gym.logging.protobuf.pendulum_state_pb2 import PendulumState +from furuta.logging.protobuf.pendulum_state_pb2 import PendulumState class GentlyTerminating(gym.Wrapper): diff --git a/furuta_gym/robot.py b/furuta/robot.py similarity index 90% rename from furuta_gym/robot.py rename to furuta/robot.py index 5be9848..eaa1b2d 100644 --- a/furuta_gym/robot.py +++ b/furuta/robot.py @@ -1,9 +1,7 @@ import struct -import matplotlib.pyplot as plt import numpy as np import serial -import simple_pid class Robot: @@ -11,7 +9,7 @@ def __init__(self, device="dev/ttyACM0", baudrate=921600): self.ser = serial.Serial(device, baudrate) def step(self, motor_command: float): - """motor command is a float between -1 and 1.""" + """Motor command is a float between -1 and 1.""" direction = motor_command < 0 # convert motor command to 16 bit unsigned int int_motor_command = int(np.abs(motor_command) * (2**16 - 1)) diff --git a/furuta_gym/utils.py b/furuta/utils.py similarity index 100% rename from furuta_gym/utils.py rename to furuta/utils.py diff --git a/furuta_gym/envs/hardware/LS7366R.py b/furuta_gym/envs/hardware/LS7366R.py deleted file mode 100644 index db77c20..0000000 --- a/furuta_gym/envs/hardware/LS7366R.py +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/python - -# Python library to interface with the chip LS7366R for the Raspberry Pi -# Written by Federico Bolanos -# Last Edit: May 12th 2020 -# Reason: Updating to python3... better late than never eh? - -from time import sleep -import spidev - -# Usage: import LS7366R then call enc = LS7366R(CSX, CLK, BTMD) -# CSX is either CE0 or CE1, CLK is the speed, -# BTMD is the bytemode 1-4 the resolution of your counter. -# example: lever.Encoder(0, 1000000, 4) -# These are the values I normally use. 
- - -class LS7366R(): - # ------------------------------------------- - # Constants - - # Commands - CLEAR_COUNTER = 0x20 - CLEAR_STATUS = 0x30 - READ_COUNTER = 0x60 - READ_STATUS = 0x70 - WRITE_MODE0 = 0x88 - WRITE_MODE1 = 0x90 - - # Modes - FOURX_COUNT = 0x03 - - FOURBYTE_COUNTER = 0x00 - THREEBYTE_COUNTER = 0x01 - TWOBYTE_COUNTER = 0x02 - ONEBYTE_COUNTER = 0x03 - - BYTE_MODE = [ONEBYTE_COUNTER, TWOBYTE_COUNTER, - THREEBYTE_COUNTER, FOURBYTE_COUNTER] - - # Values - max_val = 4294967295 - - # Global Variables - - counterSize = 4 # Default 4 - - # ---------------------------------------------- - # Constructor - - def __init__(self, CSX, CLK, BTMD, BUS): - self.counterSize = BTMD # Sets the byte mode that will be used - - self.spi = spidev.SpiDev() # Initialize object - self.spi.open(BUS, CSX) # Which CS line will be used - self.spi.max_speed_hz = CLK # Speed of clk - - # Init the Encoder - print(f'Clearing Encoder CS{CSX}\'s Count...\t{self.clearCounter()}') - print(f'Clearing Encoder CS{CSX}\'s Status..\t{self.clearStatus}') - - self.spi.xfer2([self.WRITE_MODE0, self.FOURX_COUNT]) - - sleep(.1) # Rest - - self.spi.xfer2([self.WRITE_MODE1, self.BYTE_MODE[self.counterSize-1]]) - - def close(self): - print('\nClosing SPI port') - self.spi.close() - - def clearCounter(self): - self.spi.xfer2([self.CLEAR_COUNTER]) - - return '[DONE]' - - def clearStatus(self): - self.spi.xfer2([self.CLEAR_STATUS]) - - return '[DONE]' - - def readCounter(self): - readTransaction = [self.READ_COUNTER] - - for i in range(self.counterSize): - readTransaction.append(0) - - data = self.spi.xfer2(readTransaction) - - EncoderCount = 0 - for i in range(self.counterSize): - EncoderCount = (EncoderCount << 8) + data[i+1] - - if data[1] != 255: - return EncoderCount - else: - return (EncoderCount - (self.max_val+1)) - - def readStatus(self): - data = self.spi.xfer2([self.READ_STATUS, 0xFF]) - - return data[1] - - -if __name__ == "__main__": - - encoder = LS7366R(1, 1000000, 4, 1) - try: - while True: - print("Encoder count: ", - str(encoder.readCounter()).zfill(8), - end="\r") - sleep(0.05) - except KeyboardInterrupt: - encoder.close() - print("Test programming ending.") diff --git a/furuta_gym/envs/hardware/motor.py b/furuta_gym/envs/hardware/motor.py deleted file mode 100755 index 3fd10a2..0000000 --- a/furuta_gym/envs/hardware/motor.py +++ /dev/null @@ -1,60 +0,0 @@ -import Jetson.GPIO as GPIO - - -class Motor(): - def __init__(self, D2, IN1, IN2, freq=500): - self.D2 = D2 - self.IN1 = IN1 - self.IN2 = IN2 - - GPIO.setmode(GPIO.BOARD) - - GPIO.setup(self.IN1, GPIO.OUT) - GPIO.setup(self.IN2, GPIO.OUT) - GPIO.setup(self.D2, GPIO.OUT, initial=GPIO.HIGH) - - self.pwm = GPIO.PWM(self.D2, freq) - self.pwm.start(0) - - def set_speed(self, action): - if action >= 0: - self.set_direction(1) - elif action < 0: - self.set_direction(-1) - - # make sure the motor gets the minimum voltage - # TODO: should depend on motor specs - action = abs(action) - if action > 0.0: - action = 0.2 + 0.8 * action - - duty_cyle = action*100.0 - self.pwm.ChangeDutyCycle(duty_cyle) - - def set_direction(self, direction): - if direction == -1: - GPIO.output(self.IN2, GPIO.LOW) - GPIO.output(self.IN1, GPIO.HIGH) - elif direction == 1: - GPIO.output(self.IN1, GPIO.LOW) - GPIO.output(self.IN2, GPIO.HIGH) - - def close(self): - self.set_speed(0) - self.pwm.stop() - GPIO.cleanup() - - -if __name__ == "__main__": - from time import sleep - - motor = Motor(32, 29, 31) - - print("go") - for i in range(10): - motor.set_speed(0.3) - sleep(1/3) - 
motor.set_speed(-0.3) - sleep(1/3) - - motor.close() diff --git a/furuta_gym/logging/protobuf/__init__.py b/furuta_gym/logging/protobuf/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/furuta_gym/logging/protobuf/pendulum_state_pb2.py b/furuta_gym/logging/protobuf/pendulum_state_pb2.py deleted file mode 100644 index 9263c67..0000000 --- a/furuta_gym/logging/protobuf/pendulum_state_pb2.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: pendulum_state.proto -"""Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x14pendulum_state.proto\"\xc3\x01\n\rPendulumState\x12\x13\n\x0bmotor_angle\x18\x01 \x01(\x02\x12\x16\n\x0ependulum_angle\x18\x02 \x01(\x02\x12\x1c\n\x14motor_angle_velocity\x18\x03 \x01(\x02\x12\x1f\n\x17pendulum_angle_velocity\x18\x04 \x01(\x02\x12\x0e\n\x06reward\x18\x05 \x01(\x02\x12\x0c\n\x04\x64one\x18\x06 \x01(\x08\x12\x0e\n\x06\x61\x63tion\x18\x07 \x01(\x02\x12\x18\n\x10\x63orrected_action\x18\x08 \x01(\x02\x62\x06proto3') - -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'pendulum_state_pb2', globals()) -if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - _PENDULUMSTATE._serialized_start=25 - _PENDULUMSTATE._serialized_end=220 -# @@protoc_insertion_point(module_scope) diff --git a/pyproject.toml b/pyproject.toml index d2af5af..6d95111 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [tool.poetry] -name = "furuta-gym" +name = "furuta" version = "0.1.0" description = "" authors = ["Your Name "] readme = "README.md" -packages = [{include = "furuta_gym"}] +packages = [{include = "furuta"}] [tool.poetry.dependencies] python = "^3.8" diff --git a/scripts/configs/parameters.json b/scripts/configs/parameters.json index cd17486..4d22635 100644 --- a/scripts/configs/parameters.json +++ b/scripts/configs/parameters.json @@ -3,17 +3,17 @@ "pendulum_controller":{ "controller_type": "PIDController", "control_frequency": 2500, - "Kp": 1.0, + "Kp": 3.5, "Ki": 20.0, - "Kd": 0.01, + "Kd": 0.1, "setpoint": 180.0 }, "motor_controller":{ "controller_type": "PIDController", - "Kp": 1.0, - "Ki": 0.0, + "Kp": 0.2, + "Ki": 0.001, "Kd": 0.0, "setpoint": 0.0 }, - "angle_threshold": 30.0 + "angle_threshold": 60.0 } diff --git a/scripts/control.py b/scripts/control.py index 4bfd18e..bf39ad0 100644 --- a/scripts/control.py +++ b/scripts/control.py @@ -1,16 +1,10 @@ import json -import matplotlib - -# Tkinter needs to be installed for matplotlib to work -# sudo apt-get install python3-tk -matplotlib.use("TkAgg") - import matplotlib.pyplot as plt import numpy as np -from furuta_gym.robot import Robot -from robot.software.controllers import Controller +from furuta.controls.controllers import Controller +from furuta.robot import Robot def read_parameters_file(): @@ -20,7 +14,6 @@ def read_parameters_file(): def has_pendulum_fallen(pendulum_angle: float, parameters: dict): - # TODO: Implement in firware.ino for better safety? 
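The retuned gains in scripts/configs/parameters.json are consumed by the two PID controllers built in scripts/control.py. The Controller class itself is outside this hunk, so the standalone update below only illustrates what the Kp/Ki/Kd/setpoint entries mean; it is not that class's API:

```python
import json

with open("scripts/configs/parameters.json") as f:
    parameters = json.load(f)

pendulum_cfg = parameters["pendulum_controller"]  # Kp=3.5, Ki=20.0, Kd=0.1, setpoint=180.0


def pid_update(cfg, measurement, state, dt):
    """One textbook PID step; `state` carries the integral and previous error."""
    error = cfg["setpoint"] - measurement
    state["integral"] += error * dt
    derivative = (error - state["prev_error"]) / dt
    state["prev_error"] = error
    return cfg["Kp"] * error + cfg["Ki"] * state["integral"] + cfg["Kd"] * derivative


state = {"integral": 0.0, "prev_error": 0.0}
command = pid_update(pendulum_cfg, measurement=175.0, state=state, dt=1 / 100)
```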
setpoint = parameters["pendulum_controller"]["setpoint"] angle_threshold = parameters["angle_threshold"] return np.abs(pendulum_angle - setpoint) > angle_threshold @@ -79,21 +72,23 @@ def plot_data(actions, motor_angles, pendulum_angles): # Reset encoders robot.reset_encoders() - print("GO") # Wait for user input to start the control loop - input() + input("Encoders reset, lift the pendulum and press enter to start the control loop.") # Get the initial motor and pendulum angles motor_angle, pendulum_angle = robot.step(0) # Control loop while True: - # Compute the motor command from pendulum controller - action = pendulum_controller.compute_command(pendulum_angle) + # Init action + action = 0 + + # Add the motor command from pendulum controller + action -= pendulum_controller.compute_command(pendulum_angle) # Add the motor command from motor controller - action += motor_controller.compute_command(motor_angle) + action -= motor_controller.compute_command(motor_angle) # Clip the command between -1 and 1 action = np.clip(action, -1, 1) diff --git a/scripts/replay_mcap.py b/scripts/replay_mcap.py index 231c4ce..59de778 100644 --- a/scripts/replay_mcap.py +++ b/scripts/replay_mcap.py @@ -1,8 +1,10 @@ -from mcap_protobuf.reader import read_protobuf_messages -from furuta_gym.envs.furuta_base import CartPoleSwingUpViewer -import numpy as np import time +import numpy as np +from mcap_protobuf.reader import read_protobuf_messages + +from furuta.envs.furuta_base import CartPoleSwingUpViewer + if __name__ == "__main__": # TODO use smth like argparse to choose the file # or a text interface @@ -13,7 +15,9 @@ for msg in read_protobuf_messages(mcap_path, log_time_order=True): p = msg.proto_msg - state = np.array([p.motor_angle, p.pendulum_angle, p.motor_angle_velocity, p.pendulum_angle_velocity]) + state = np.array( + [p.motor_angle, p.pendulum_angle, p.motor_angle_velocity, p.pendulum_angle_velocity] + ) viewer.update(state) viewer.render(return_rgb_array=False) - time.sleep(dt) \ No newline at end of file + time.sleep(dt) diff --git a/scripts/robot_inference.py b/scripts/robot_inference.py index fe9d27c..b127102 100644 --- a/scripts/robot_inference.py +++ b/scripts/robot_inference.py @@ -1,22 +1,20 @@ import argparse -from collections import namedtuple -from distutils.util import strtobool import logging import os +from collections import namedtuple +from distutils.util import strtobool import gym +import wandb from gym.wrappers import TimeLimit from stable_baselines3 import SAC from stable_baselines3.common.monitor import Monitor -import wandb -import furuta_gym # noqa F420 -from furuta_gym.envs.wrappers import GentlyTerminating, \ - HistoryWrapper, \ - ControlFrequency +import furuta # noqa F420 +from furuta.envs.wrappers import ControlFrequency, GentlyTerminating, HistoryWrapper -class Robot(): +class Robot: def __init__(self, args): self.args = args self.load_model() @@ -45,9 +43,7 @@ def run_episode(self): total_reward = 0 obs = self.env.reset() while True: - action, _states = \ - self.model.predict(obs, - deterministic=self.args.deterministic) + action, _states = self.model.predict(obs, deterministic=self.args.deterministic) obs, reward, done, info = self.env.step(action) total_reward += reward if done: @@ -67,7 +63,7 @@ def main(args): config=vars(args), sync_tensorboard=True, save_code=True, - job_type="inference" + job_type="inference", ) robot = Robot(run.config) @@ -84,10 +80,14 @@ def main(args): def setup_env(args): # base env - env = gym.make(args.gym_id, fs=args.fs, fs_ctrl=args.fs_ctrl, 
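scripts/replay_mcap.py above replays a recorded episode through the viewer; the same reader works for quick offline checks, e.g. summing the logged reward (the .mcap path is a placeholder):

```python
from mcap_protobuf.reader import read_protobuf_messages

episode_return = 0.0
for msg in read_protobuf_messages("data/some_run.mcap", log_time_order=True):
    episode_return += msg.proto_msg.reward  # PendulumState.reward field
print(f"episode return: {episode_return:.2f}")
```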
- action_limiter=args.action_limiter, - safety_th_lim=args.safety_th_lim, - state_limits=args.state_limits) + env = gym.make( + args.gym_id, + fs=args.fs, + fs_ctrl=args.fs_ctrl, + action_limiter=args.action_limiter, + safety_th_lim=args.safety_th_lim, + state_limits=args.state_limits, + ) wandb.run.summary["state_max"] = env.state_max @@ -108,37 +108,45 @@ def setup_env(args): def parse_args(): - parser = argparse.ArgumentParser(description='TD3 agent') + parser = argparse.ArgumentParser(description="TD3 agent") # Common arguments - parser.add_argument('model_artifact', type=str, - help="the artifact version of the model to load") - - parser.add_argument('--deterministic', default=True, - type=lambda x: bool(strtobool(x)), - help='Whether to use a deterministic policy or not') - parser.add_argument('--gym_id', type=str, default="FurutaReal-v0", - help='the id of the gym environment') - parser.add_argument('--wandb_project', type=str, default="furuta", - help="the wandb's project name") - parser.add_argument('--wandb_entity', type=str, default=None, - help="the entity (team) of wandb's project") + parser.add_argument( + "model_artifact", type=str, help="the artifact version of the model to load" + ) + + parser.add_argument( + "--deterministic", + default=True, + type=lambda x: bool(strtobool(x)), + help="Whether to use a deterministic policy or not", + ) + parser.add_argument( + "--gym_id", type=str, default="FurutaReal-v0", help="the id of the gym environment" + ) + parser.add_argument( + "--wandb_project", type=str, default="furuta", help="the wandb's project name" + ) + parser.add_argument( + "--wandb_entity", type=str, default=None, help="the entity (team) of wandb's project" + ) parser.add_argument("-d", "--debug", action="store_true") # env params - parser.add_argument('--fs', type=int, - help='Sampling frequency') - parser.add_argument('--fs_ctrl', type=int, - help='control frequency') - parser.add_argument('--episode_length', type=int, - help='the maximum length of each episode. \ - -1 = infinite') - parser.add_argument('--safety_th_lim', type=float, - help='Max motor (theta) angle in rad.') - parser.add_argument('--action_limiter', - type=lambda x: bool(strtobool(x)), - help='Restrict actions') - parser.add_argument('--state_limits', type=str, - help='Wether to use high or low limits. See code.') + parser.add_argument("--fs", type=int, help="Sampling frequency") + parser.add_argument("--fs_ctrl", type=int, help="control frequency") + parser.add_argument( + "--episode_length", + type=int, + help="the maximum length of each episode. \ + -1 = infinite", + ) + parser.add_argument("--safety_th_lim", type=float, help="Max motor (theta) angle in rad.") + parser.add_argument( + "--action_limiter", type=lambda x: bool(strtobool(x)), help="Restrict actions" + ) + parser.add_argument( + "--state_limits", type=str, help="Wether to use high or low limits. See code." 
+ ) args = parser.parse_args() @@ -148,8 +156,5 @@ def parse_args(): if __name__ == "__main__": args = parse_args() logging_level = logging.DEBUG if args.debug else logging.INFO - logging.basicConfig( - format='%(levelname)s: %(message)s', - level=logging_level - ) + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level) main(args) diff --git a/scripts/train.py b/scripts/train.py index 58ef1df..b1590ae 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -5,12 +5,12 @@ import random from pathlib import Path -import furuta_gym # noqa F420 import gym import hydra import numpy as np import stable_baselines3 import torch +import wandb from omegaconf import DictConfig, OmegaConf, open_dict from stable_baselines3.common.callbacks import ( EvalCallback, @@ -18,7 +18,7 @@ ) from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder -import wandb +import furuta # noqa F420 # TODO # - save model/replay buffer X diff --git a/scripts/train_sac.py b/scripts/train_sac.py index 24c4b21..13ae9e5 100644 --- a/scripts/train_sac.py +++ b/scripts/train_sac.py @@ -1,22 +1,24 @@ import argparse import configparser -from distutils.util import strtobool import logging -from pathlib import Path import os +from distutils.util import strtobool +from pathlib import Path import gym +import wandb from gym.wrappers import TimeLimit from stable_baselines3 import SAC from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder -import wandb -import furuta_gym # noqa F420 -from furuta_gym.wrappers import GentlyTerminating, \ - HistoryWrapper, \ - ControlFrequency, \ - MCAPLogger +import furuta # noqa F420 +from furuta.rl.wrappers import ( + ControlFrequency, + GentlyTerminating, + HistoryWrapper, + MCAPLogger, +) def main(args): @@ -26,7 +28,7 @@ def main(args): config=args, sync_tensorboard=True, monitor_gym=args.capture_video, - save_code=True + save_code=True, ) args = run.config @@ -35,29 +37,32 @@ def main(args): verbose = 2 if args.debug else 0 model = SAC( - "MlpPolicy", env, verbose=verbose, - learning_rate=args.learning_rate, seed=args.seed, - buffer_size=args.buffer_size, tau=args.tau, - gamma=args.gamma, batch_size=args.batch_size, - target_update_interval=args.target_update_interval, - learning_starts=args.learning_starts, - use_sde=args.use_sde, use_sde_at_warmup=args.use_sde_at_warmup, - sde_sample_freq=args.sde_sample_freq, - train_freq=(args.train_freq, args.train_freq_unit), - gradient_steps=args.gradient_steps, - tensorboard_log=f"runs/{run.id}", - ) + "MlpPolicy", + env, + verbose=verbose, + learning_rate=args.learning_rate, + seed=args.seed, + buffer_size=args.buffer_size, + tau=args.tau, + gamma=args.gamma, + batch_size=args.batch_size, + target_update_interval=args.target_update_interval, + learning_starts=args.learning_starts, + use_sde=args.use_sde, + use_sde_at_warmup=args.use_sde_at_warmup, + sde_sample_freq=args.sde_sample_freq, + train_freq=(args.train_freq, args.train_freq_unit), + gradient_steps=args.gradient_steps, + tensorboard_log=f"runs/{run.id}", + ) # TODO seed everything if args.model_artifact: - model.load(download_artifact_file(f"sac_model:{args.model_artifact}", - "sac.zip")) + model.load(download_artifact_file(f"sac_model:{args.model_artifact}", "sac.zip")) if args.rb_artifact: - rb_path = download_artifact_file( - f"sac_replay_buffer:{args.model_artifact}", - "buffer.pkl") + rb_path = download_artifact_file(f"sac_replay_buffer:{args.model_artifact}", "buffer.pkl") 
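Stripped of the W&B plumbing, the SAC setup in scripts/train_sac.py reduces to a handful of Stable-Baselines3 calls. A minimal sketch using the same defaults as the argparse flags further down (gSDE on, one update pass per collected episode):

```python
import gym
from stable_baselines3 import SAC

import furuta  # noqa: F401  (registers FurutaSim-v0)

env = gym.make("FurutaSim-v0", fs=100)
model = SAC(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    buffer_size=1_000_000,
    batch_size=256,
    learning_starts=25_000,
    use_sde=True,               # gSDE exploration instead of unstructured action noise
    use_sde_at_warmup=True,
    train_freq=(1, "episode"),  # collect a full episode, then train
    gradient_steps=-1,          # one gradient step per collected env step
)
model.learn(total_timesteps=1_000_000)
model.save("sac.zip")
```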
model.load_replay_buffer(rb_path) try: @@ -80,10 +85,13 @@ def main(args): def setup_env(args): # base env - env = gym.make(args.gym_id, fs=args.fs, - action_limiter=args.action_limiter, - safety_th_lim=args.safety_th_lim, - state_limits=args.state_limits) + env = gym.make( + args.gym_id, + fs=args.fs, + action_limiter=args.action_limiter, + safety_th_lim=args.safety_th_lim, + state_limits=args.state_limits, + ) wandb.run.summary["state_max"] = env.state_max @@ -95,9 +103,7 @@ def setup_env(args): if args.log_mcap: env = MCAPLogger( - env, - f"../data/{wandb.run.id}", - use_sim_time=(args.gym_id == "FurutaSim-v0") + env, f"../data/{wandb.run.id}", use_sim_time=(args.gym_id == "FurutaSim-v0") ) if args.episode_length != -1: @@ -117,16 +123,18 @@ def setup_env(args): # import pyvirtualdisplay # pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start() env = DummyVecEnv([lambda: env]) - env = VecVideoRecorder(env, f"videos/{wandb.run.id}", - record_video_trigger=lambda x: x % 3000 == 0, - video_length=300) + env = VecVideoRecorder( + env, + f"videos/{wandb.run.id}", + record_video_trigger=lambda x: x % 3000 == 0, + video_length=300, + ) return env def download_artifact_file(artifact_alias, filename): - """ - Download artifact and returns path to filename. + """Download artifact and returns path to filename. :param artifact_name: wandb artifact alias :param filename: filename in the artifact @@ -164,95 +172,136 @@ def load_sim_params(env, param_pth): def parse_args(): - parser = argparse.ArgumentParser(description='TD3 agent') + parser = argparse.ArgumentParser(description="TD3 agent") # Common arguments - parser.add_argument('--gym_id', type=str, default="FurutaSim-v0", - help='the id of the gym environment') - parser.add_argument('--learning_rate', type=float, default=3e-4, - help='the learning rate for the optimizer') - parser.add_argument('--seed', type=int, default=1, - help='seed of the experiment') - parser.add_argument('--total_timesteps', type=int, default=1000000, - help='total timesteps of the experiments') - parser.add_argument('--capture_video', - type=lambda x: bool(strtobool(x)), default=False, - help='capture videos of the agent\ - (check out `videos` folder)') # TODO make that an an int n and if = 0 then no video - # else it captures every n steps - parser.add_argument('--log_mcap', - type=lambda x: bool(strtobool(x)), default=False, - help='log mcap data') - parser.add_argument('--wandb_project', type=str, default="furuta", - help="the wandb's project name") - parser.add_argument('--wandb_entity', type=str, default=None, - help="the entity (team) of wandb's project") + parser.add_argument( + "--gym_id", type=str, default="FurutaSim-v0", help="the id of the gym environment" + ) + parser.add_argument( + "--learning_rate", type=float, default=3e-4, help="the learning rate for the optimizer" + ) + parser.add_argument("--seed", type=int, default=1, help="seed of the experiment") + parser.add_argument( + "--total_timesteps", type=int, default=1000000, help="total timesteps of the experiments" + ) + parser.add_argument( + "--capture_video", + type=lambda x: bool(strtobool(x)), + default=False, + help="capture videos of the agent\ + (check out `videos` folder)", + ) # TODO make that an an int n and if = 0 then no video + # else it captures every n steps + parser.add_argument( + "--log_mcap", type=lambda x: bool(strtobool(x)), default=False, help="log mcap data" + ) + parser.add_argument( + "--wandb_project", type=str, default="furuta", help="the wandb's project name" + ) + 
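The MCAPLogger wrapper wired into setup_env() can also record a standalone random rollout; the output path below is a placeholder and the (env, path, use_sim_time) call mirrors the one above:

```python
import gym

import furuta  # noqa: F401
from furuta.rl.wrappers import MCAPLogger

env = MCAPLogger(gym.make("FurutaSim-v0", fs=100), "data/random_rollout", use_sim_time=True)
obs = env.reset()
for _ in range(300):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
env.close()
```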
parser.add_argument( + "--wandb_entity", type=str, default=None, help="the entity (team) of wandb's project" + ) parser.add_argument("-d", "--debug", action="store_true") # Algorithm specific arguments - parser.add_argument('--buffer_size', type=int, default=int(1e6), - help='the replay memory buffer size') - parser.add_argument('--tau', type=float, default=0.005, - help="target smoothing coefficient.") - parser.add_argument('--gamma', type=float, default=0.99, - help="target smoothing coefficient.") - parser.add_argument('--batch_size', type=int, default=256, - help="the batch size of sample from the replay memory") - parser.add_argument('--target_update_interval', - type=int, default=1, - help="the frequency of training policy (delayed)") - parser.add_argument('--learning_starts', - type=int, default=25e3, - help="when to start learning") - parser.add_argument('--sde_sample_freq', - type=int, default=-1, - help="Sample a new noise matrix every n steps when using gSDE \ + parser.add_argument( + "--buffer_size", type=int, default=int(1e6), help="the replay memory buffer size" + ) + parser.add_argument("--tau", type=float, default=0.005, help="target smoothing coefficient.") + parser.add_argument("--gamma", type=float, default=0.99, help="target smoothing coefficient.") + parser.add_argument( + "--batch_size", + type=int, + default=256, + help="the batch size of sample from the replay memory", + ) + parser.add_argument( + "--target_update_interval", + type=int, + default=1, + help="the frequency of training policy (delayed)", + ) + parser.add_argument("--learning_starts", type=int, default=25e3, help="when to start learning") + parser.add_argument( + "--sde_sample_freq", + type=int, + default=-1, + help="Sample a new noise matrix every n steps when using gSDE \ Default: -1 \ - (only sample at the beginning of the rollout)") - parser.add_argument('--use_sde', - type=lambda x: bool(strtobool(x)), default=True, - help="Whether to use generalized State Dependent Exploration (gSDE) \ - instead of action noise exploration") - parser.add_argument('--use_sde_at_warmup', - type=lambda x: bool(strtobool(x)), default=True, - help="Whether to use gSDE instead of uniform sampling during the warm up \ - phase (before learning starts)") + (only sample at the beginning of the rollout)", + ) + parser.add_argument( + "--use_sde", + type=lambda x: bool(strtobool(x)), + default=True, + help="Whether to use generalized State Dependent Exploration (gSDE) \ + instead of action noise exploration", + ) + parser.add_argument( + "--use_sde_at_warmup", + type=lambda x: bool(strtobool(x)), + default=True, + help="Whether to use gSDE instead of uniform sampling during the warm up \ + phase (before learning starts)", + ) # params to accomodate embedded system - parser.add_argument('--model_artifact', type=str, default=None, - help="the artifact version of the model to load") - parser.add_argument('--rb_artifact', type=str, default=None, - help="Artifact version of the replay buffer to load") - parser.add_argument('--train_freq', - type=int, default=1, - help="The frequency of training critics/q functions") - parser.add_argument('--train_freq_unit', - type=str, default="episode", - help="The frequency unit") - parser.add_argument('--gradient_steps', - type=int, default=-1, - help="How many training iterations.") + parser.add_argument( + "--model_artifact", + type=str, + default=None, + help="the artifact version of the model to load", + ) + parser.add_argument( + "--rb_artifact", + type=str, + default=None, + 
help="Artifact version of the replay buffer to load", + ) + parser.add_argument( + "--train_freq", type=int, default=1, help="The frequency of training critics/q functions" + ) + parser.add_argument( + "--train_freq_unit", type=str, default="episode", help="The frequency unit" + ) + parser.add_argument( + "--gradient_steps", type=int, default=-1, help="How many training iterations." + ) # env params - parser.add_argument('--fs', type=int, default=100, - help='Sampling frequency') - parser.add_argument('--episode_length', type=int, default=3000, - help='the maximum length of each episode. \ - -1 = infinite') - parser.add_argument('--safety_th_lim', type=float, default=1.5, - help='Max motor (theta) angle in rad.') - parser.add_argument('--action_limiter', - type=lambda x: bool(strtobool(x)), default=False, - help='Restrict actions') - parser.add_argument('--state_limits', type=str, default="low", - help='Wether to use high or low limits. See code.') - parser.add_argument('--continuity_cost', - type=lambda x: bool(strtobool(x)), default=False, - help='If true use continuity cost from HistoryWrapper') - parser.add_argument('--history', type=int, default=1, - help='If >1 use HistoryWrapper') - parser.add_argument('--custom_sim', - type=str, default=None, - help='Use params from the provided file.') + parser.add_argument("--fs", type=int, default=100, help="Sampling frequency") + parser.add_argument( + "--episode_length", + type=int, + default=3000, + help="the maximum length of each episode. \ + -1 = infinite", + ) + parser.add_argument( + "--safety_th_lim", type=float, default=1.5, help="Max motor (theta) angle in rad." + ) + parser.add_argument( + "--action_limiter", + type=lambda x: bool(strtobool(x)), + default=False, + help="Restrict actions", + ) + parser.add_argument( + "--state_limits", + type=str, + default="low", + help="Wether to use high or low limits. See code.", + ) + parser.add_argument( + "--continuity_cost", + type=lambda x: bool(strtobool(x)), + default=False, + help="If true use continuity cost from HistoryWrapper", + ) + parser.add_argument("--history", type=int, default=1, help="If >1 use HistoryWrapper") + parser.add_argument( + "--custom_sim", type=str, default=None, help="Use params from the provided file." + ) args = parser.parse_args() @@ -267,8 +316,5 @@ def parse_args(): args = parse_args() logging_level = logging.DEBUG if args.debug else logging.INFO - logging.basicConfig( - format='%(levelname)s: %(message)s', - level=logging_level - ) + logging.basicConfig(format="%(levelname)s: %(message)s", level=logging_level) main(args)