Skip to content

Commit

Permalink
Merge pull request #6 from yaacov/fix-train
Browse files Browse the repository at this point in the history
Fix-train
  • Loading branch information
yaacov authored Jan 24, 2024
2 parents 125c0a6 + 1dad3dd commit bbb9c38
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
Binary file modified checkpoints/driver.pth
Binary file not shown.
29 changes: 16 additions & 13 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,6 @@

from model import DriverModel, action_to_outputs, actions, obstacles, view_to_inputs

OBSTACLE_TO_INDEX = {
"": 0,
"crack": 1,
"trash": 2,
"penguin": 3,
"bike": 4,
"water": 5,
"barrier": 6,
}

# Training parameters
num_epochs = 0
batch_size = 0
Expand All @@ -61,12 +51,25 @@ def generate_obstacle_array():
Returns:
list[list[str]]: 4x6 2D array with random obstacles.
"""
OBSTACLE_TO_INDEX = {
"": 0,
"crack": 1,
"trash": 2,
"penguin": 3,
"bike": 4,
"water": 5,
"barrier": 6,
}

array = [["" for _ in range(6)] for _ in range(4)]

for i in range(4):
obstacle = random.choice(list(OBSTACLE_TO_INDEX.keys()))
position = random.randint(0, 5)
# lane A
position = random.randint(0, 2)
array[i][position] = obstacle
# lane B
array[i][3 + position] = obstacle

return array

Expand Down Expand Up @@ -172,10 +175,10 @@ def main():
"--checkpoint-out", default="", help="Path to the output checkpoint file."
)
parser.add_argument(
"--num-epochs", type=int, default=100, help="Number of epochs for training."
"--num-epochs", type=int, default=30, help="Number of epochs for training."
)
parser.add_argument(
"--batch-size", type=int, default=250, help="Batch size for training."
"--batch-size", type=int, default=200, help="Batch size for training."
)
parser.add_argument(
"--learning-rate", type=float, default=0.001, help="Learning rate for training."
Expand Down

0 comments on commit bbb9c38

Please sign in to comment.