Commit f640797 (0 parents): 10 changed files with 449 additions and 0 deletions.
@@ -0,0 +1,15 @@

```
# IDEs
.idea/
.vscode/

# OS & miscellaneous
.DS_Store
out/

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

models/
```
@@ -0,0 +1,64 @@

# reinforcement-learning-skiing

## Required libraries

```bash
conda create -n rl-skiing python=3.11.5
conda activate rl-skiing
pip install numpy matplotlib jupyter opencv-python gymnasium
pip install "gymnasium[accept-rom-license, atari]"
```

- For Windows (GPU):

```bash
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
```

- For Windows (CPU) / macOS (CPU) / Linux (GPU):

```bash
pip install torch torchvision
```

- For Linux (CPU):

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```

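Not part of the original README: a minimal sketch for confirming the install worked, assuming only the packages listed above. It mirrors how the repository's own code creates the environment.

```python
# Sanity-check sketch: confirm gymnasium can build the Atari env and torch sees a device.
import gymnasium as gym
import torch

env = gym.make("ALE/Skiing-v5")            # raises if the Atari ROMs were not installed
obs, info = env.reset()
print("Observation shape:", obs.shape)      # raw RGB frame from the ALE screen
print("Action space:", env.action_space)    # Discrete(3) for Skiing
print("CUDA available:", torch.cuda.is_available())
env.close()
```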
## How to run the project

### 1. Set the PYTHONPATH

The source files import modules through the `src` package (e.g. `from src.dqn.agent import Agent`), so the repository root has to be on `PYTHONPATH`.

- Windows - PowerShell:

```bash
$env:PYTHONPATH='.'
```

- Windows - CMD:

```bash
set PYTHONPATH=.
```

- Linux / macOS:

```bash
export PYTHONPATH=.
```

### 2. Run the project

- Run DQN:

```bash
python src/dqn/main.py
```

- Run PPO:

```bash
python src/ppo/main.py
```
@@ -0,0 +1,82 @@

```python
import numpy as np
import torch
from random import random

from src.dqn.dqn import DQN
from src.dqn.replay_memory import ReplayMemory
from src.dqn.constants import (
    POLICY_NET_PATH,
    TARGET_NET_PATH,
    MODELS_PATH,
    EPSILON_MAX,
    EPSILON_MIN,
    EPSILON_DECAY,
    GAMMA,
    BATCH_SIZE,
    DEVICE,
)
from src.utils.helpers import check_if_dirs_exist


class Agent:
    """Epsilon-greedy DQN agent with a policy network, a target network and replay memory."""

    def __init__(self, action_space):
        self.action_space = action_space
        self.gamma: float = GAMMA
        self.device = DEVICE

        self.epsilon: float = EPSILON_MAX
        self.epsilon_min: float = EPSILON_MIN
        self.epsilon_decay: float = EPSILON_DECAY

        self.replay_memory: ReplayMemory = ReplayMemory()

        self.policy_net: DQN = DQN().to(self.device)
        self.target_net: DQN = DQN().to(self.device)
        self.update_target_net()

    def update_target_net(self):
        # Copy the policy network weights into the target network.
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def select_action(self, state):
        # Explore with probability epsilon, otherwise act greedily w.r.t. the policy net.
        if random() < self.epsilon:
            return self.action_space.sample()

        if not torch.is_tensor(state):
            state = torch.from_numpy(np.array(state)).float().unsqueeze(0).to(self.device)

        with torch.no_grad():
            action = torch.argmax(self.policy_net(state))

        return action.item()

    def decay_epsilon(self):
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def please_learn(self):
        if len(self.replay_memory) < BATCH_SIZE:
            return

        states, actions, rewards, dones, next_states = self.replay_memory.sample(self.device)

        # Q-values of the actions that were actually taken in the sampled transitions.
        predicted_qs = self.policy_net(states).gather(1, actions)

        # Bootstrapped targets: r + gamma * max_a' Q_target(s', a'), with 0 for terminal states.
        with torch.no_grad():
            target_qs = self.target_net(next_states)
            target_qs = torch.max(target_qs, dim=1).values.reshape(-1, 1)
            target_qs[dones] = 0.0
            target_qs = rewards + (self.gamma * target_qs)

        loss = self.policy_net.loss(predicted_qs, target_qs)
        self.policy_net.optimizer.zero_grad()
        loss.backward()
        self.policy_net.optimizer.step()

    def save(self):
        check_if_dirs_exist([MODELS_PATH])
        torch.save(self.policy_net.state_dict(), POLICY_NET_PATH)
        torch.save(self.target_net.state_dict(), TARGET_NET_PATH)

    def load(self):
        self.policy_net.load_state_dict(torch.load(POLICY_NET_PATH, map_location=self.device))
        self.target_net.load_state_dict(torch.load(TARGET_NET_PATH, map_location=self.device))
        self.target_net.eval()
```
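For reference, the update in `please_learn` is the standard one-step DQN regression. Written out (an editorial summary of the code above, not part of the commit), with batch size B and discount GAMMA:

$$
y_i =
\begin{cases}
r_i & \text{if } s'_i \text{ is terminal} \\
r_i + \gamma \max_{a'} Q_{\text{target}}(s'_i, a') & \text{otherwise}
\end{cases}
\qquad
\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \left( Q_{\text{policy}}(s_i, a_i) - y_i \right)^2
$$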
@@ -0,0 +1,25 @@

```python
from pathlib import Path
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_ACTIONS = 3
INPUT_SHAPE = (1, 75, 70)  # PyTorch uses (channels, height, width) format

# TODO: Fine tuning
LEARNING_RATE = 1e-2
MEMORY_CAPACITY = 32
NUM_EPISODES = 10
BATCH_SIZE = 16
UPDATE_FREQUENCY = 20

# These might be good
GAMMA = 0.99
EPSILON_MAX = 1.0
EPSILON_MIN = 0.01
EPSILON_DECAY = 0.99

MODELS_PATH = Path("models")
POLICY_NET_PATH = MODELS_PATH / "policy_net.pth"
TARGET_NET_PATH = MODELS_PATH / "target_net.pth"
MODEL_PATH = MODELS_PATH / "model.pth"
```
@@ -0,0 +1,37 @@

```python
import torch
import torch.nn as nn
from torch import optim

from src.dqn.constants import INPUT_SHAPE, NUM_ACTIONS, LEARNING_RATE


class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(INPUT_SHAPE[0], 32, kernel_size=8, stride=4)
        self.norm1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.norm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.norm3 = nn.BatchNorm2d(64)

        self.flatten = nn.Flatten()

        # For a 1x75x70 input, the conv stack above produces 64 feature maps of size 5x5.
        self.fc = nn.Linear(64 * 5 * 5, 512)
        self.output = nn.Linear(512, NUM_ACTIONS)

        # TODO: Maybe use RMSProp?
        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
        self.loss = nn.MSELoss()

    def _forward_features(self, x):
        x = torch.relu(self.norm1(self.conv1(x)))
        x = torch.relu(self.norm2(self.conv2(x)))
        x = torch.relu(self.norm3(self.conv3(x)))
        return x

    def forward(self, x):
        x = self._forward_features(x)
        x = self.flatten(x)
        x = torch.relu(self.fc(x))
        x = self.output(x)
        return x
```
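A quick way to sanity-check the hard-coded `64 * 5 * 5` flatten size is to push a dummy batch of the configured input shape through the network. This is an editorial sketch; it assumes the module lives at `src/dqn/dqn.py`, as the imports in the other files suggest.

```python
# Sketch: verify the network accepts INPUT_SHAPE and emits one Q-value per action.
import torch

from src.dqn.constants import INPUT_SHAPE, NUM_ACTIONS
from src.dqn.dqn import DQN

net = DQN()
net.eval()  # eval mode: use BatchNorm running statistics rather than batch statistics
with torch.no_grad():
    dummy = torch.zeros(1, *INPUT_SHAPE)  # (batch, channels, height, width) = (1, 1, 75, 70)
    q_values = net(dummy)
print(q_values.shape)  # expected: torch.Size([1, NUM_ACTIONS]) == torch.Size([1, 3])
```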
@@ -0,0 +1,82 @@

```python
import numpy as np
import torch
from gymnasium import make, Env

from src.dqn.agent import Agent
from src.dqn.constants import MEMORY_CAPACITY, MODEL_PATH, NUM_EPISODES, MODELS_PATH, UPDATE_FREQUENCY
from src.utils.helpers import check_if_dirs_exist
from src.utils.preprocessing import preprocess


def reset(env: Env):
    state, _info = env.reset()
    state = preprocess(state)
    return state


def step(env: Env, action: int):
    next_state, reward, terminated, truncated, info = env.step(action)
    next_state = preprocess(next_state)
    done = terminated or truncated
    return next_state, reward, done, info


def fill_memory(env: Env, agent: Agent):
    # Pre-fill the replay memory with transitions from the (initially fully random) policy.
    for _ in range(MEMORY_CAPACITY):
        state = reset(env)
        done = False
        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = step(env, action)
            agent.replay_memory.store(state, action, reward, done, next_state)
            state = next_state


def train(env, agent: Agent):
    fill_memory(env, agent)
    print("Memory filled with random actions!")

    counter = 0
    reward_history = []
    best_score = -np.inf

    for episode in range(NUM_EPISODES):
        state = reset(env)
        done = False
        episode_reward = 0

        while not done:
            action = agent.select_action(state)
            next_state, reward, done, info = step(env, action)
            agent.replay_memory.store(state, action, reward, done, next_state)
            agent.please_learn()

            # Periodically sync the target network with the policy network.
            if counter % UPDATE_FREQUENCY == 0:
                agent.update_target_net()

            state = next_state
            episode_reward += reward
            counter += 1

        agent.decay_epsilon()
        reward_history.append(episode_reward)

        current_avg_score = np.mean(reward_history[-20:])  # moving average over last 20 episodes

        print(
            f"Episode: {episode + 1}, Reward: {episode_reward}, Avg. Reward: {current_avg_score}, Epsilon: {agent.epsilon}"
        )

        # Checkpoint the policy network whenever the moving average improves.
        if current_avg_score > best_score:
            best_score = current_avg_score
            check_if_dirs_exist([MODELS_PATH])
            torch.save(agent.policy_net.state_dict(), MODEL_PATH)


if __name__ == "__main__":
    env: Env = make("ALE/Skiing-v5", max_episode_steps=1000)
    agent = Agent(action_space=env.action_space)
    train(env, agent)
    agent.save()
```
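The entry point above only trains and saves weights; a playback script for the trained agent is not part of this commit. The sketch below is an editorial addition for watching the saved policy act greedily, assuming the `models/` checkpoints and the module layout implied by the README (`src/dqn/main.py`, `src/dqn/agent.py`).

```python
# Sketch: load the saved policy network and run one rendered episode greedily.
from gymnasium import make

from src.dqn.agent import Agent
from src.dqn.main import reset, step  # reuse the preprocessing wrappers defined above

env = make("ALE/Skiing-v5", max_episode_steps=1000, render_mode="human")
agent = Agent(action_space=env.action_space)
agent.load()                 # reads models/policy_net.pth and models/target_net.pth
agent.policy_net.eval()      # use BatchNorm running statistics during playback
agent.epsilon = 0.0          # disable exploration so select_action is purely greedy

state = reset(env)
done = False
total_reward = 0
while not done:
    action = agent.select_action(state)
    state, reward, done, _info = step(env, action)
    total_reward += reward

print("Episode reward:", total_reward)
env.close()
```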
@@ -0,0 +1,45 @@

```python
from random import sample

import numpy as np
import torch

from src.dqn.constants import MEMORY_CAPACITY, BATCH_SIZE


class ReplayMemory:
    """Fixed-capacity ring buffer of (state, action, reward, done, next_state) transitions."""

    def __init__(self):
        self.capacity = MEMORY_CAPACITY
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.next_states = []
        self.index: int = 0

    def store(self, state, action, reward, done, next_state):
        if len(self.states) < self.capacity:
            self.states.append(state)
            self.actions.append(action)
            self.rewards.append(reward)
            self.dones.append(done)
            self.next_states.append(next_state)
        else:
            # Buffer is full: overwrite the oldest transition.
            self.states[self.index] = state
            self.actions[self.index] = action
            self.rewards[self.index] = reward
            self.dones[self.index] = done
            self.next_states[self.index] = next_state

        self.index = (self.index + 1) % self.capacity

    def sample(self, device):
        indices_to_sample = sample(range(len(self)), BATCH_SIZE)
        states = torch.from_numpy(np.array(self.states)[indices_to_sample]).float().to(device)
        # Actions are int64 column vectors so they can index Q-values with gather().
        actions = torch.from_numpy(np.array(self.actions)[indices_to_sample]).long().to(device).reshape((-1, 1))
        rewards = torch.from_numpy(np.array(self.rewards)[indices_to_sample]).float().to(device).reshape((-1, 1))
        dones = torch.from_numpy(np.array(self.dones)[indices_to_sample]).to(device).reshape((-1, 1))
        next_states = torch.from_numpy(np.array(self.next_states)[indices_to_sample]).float().to(device)

        return states, actions, rewards, dones, next_states

    def __len__(self):
        return len(self.states)
```
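A tiny smoke test makes the batch shapes returned by `sample` explicit. This is an editorial sketch; the fake 1x75x70 frames and the module path `src/dqn/replay_memory.py` are assumptions based on the constants and imports above.

```python
# Sketch: fill a ReplayMemory with fake transitions and inspect the sampled batch shapes.
import numpy as np

from src.dqn.constants import BATCH_SIZE, INPUT_SHAPE
from src.dqn.replay_memory import ReplayMemory

memory = ReplayMemory()
for i in range(BATCH_SIZE):
    frame = np.zeros(INPUT_SHAPE, dtype=np.float32)  # stand-in for a preprocessed frame
    memory.store(frame, action=i % 3, reward=-1.0, done=False, next_state=frame)

states, actions, rewards, dones, next_states = memory.sample(device="cpu")
print(states.shape)   # (BATCH_SIZE, 1, 75, 70)
print(actions.shape)  # (BATCH_SIZE, 1)
print(rewards.shape, dones.shape, next_states.shape)
```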
@@ -0,0 +1,27 @@

```python
import numpy as np
import gymnasium

env = gymnasium.make("ALE/Skiing-v5", render_mode="human")
env.metadata["render_fps"] = 60


def run_episode(env, policy, render=False, max_steps=10000):
    """Run a single episode with the given policy"""
    obs, _info = env.reset()
    for _ in range(max_steps):
        if render:
            env.render()
            # time.sleep(0.5)
        action = policy(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            # Stop once the episode ends instead of stepping a finished environment.
            break
    env.render()


def random_policy(obs):
    """A random policy for the Skiing environment"""
    return np.random.randint(0, 3)


run_episode(env, random_policy, render=True, max_steps=10000)
env.close()
```