forked from vivek3141/super-mario-neat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcont_train.py
46 lines (39 loc) · 1.44 KB
/
cont_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import neat
import process
import gym, ppaquette_gym_super_mario
import pickle
import multiprocessing as mp
import visualize
import train
# Raise gym's logger threshold so lower-severity messages are not printed
# during the long training run.
# NOTE(review): 40 presumably corresponds to gym.logger.ERROR (i.e. warnings
# and below are suppressed) — confirm against the installed gym version.
gym.logger.set_level(40)
class Train(train.Train):
    """Continue NEAT training from a previously saved checkpoint.

    Behaves like ``train.Train``, except that ``_run`` restores the
    population stored in ``file_name`` instead of creating a new one.
    """

    def __init__(self, generations, file_name, parallel, level):
        """Set up a resumed training run.

        Args:
            generations: forwarded unchanged to ``train.Train.__init__``.
            file_name: path of the neat-python checkpoint file to restore.
            parallel: forwarded unchanged to ``train.Train.__init__``.
            level: forwarded unchanged to ``train.Train.__init__``.
        """
        super().__init__(generations, parallel, level)
        # Assigned after super().__init__() so this restricted action list
        # takes precedence over anything the base initializer may set.
        # NOTE(review): each row looks like a per-button on/off mask for the
        # Mario environment — confirm against train.Train and the env.
        self.actions = [
            [0, 0, 0, 1, 0, 1],
            [0, 0, 0, 1, 1, 1],
        ]
        self.lock = mp.Lock()
        self.file_name = file_name

    def _run(self, config_file, n):
        """Restore the checkpoint, evolve it, then save and plot the results.

        Args:
            config_file: path to the NEAT config file. It is used only to
                build the ``config`` handed to ``visualize.draw_net``; the
                population (and its config) come from ``self.file_name``.
            n: number of generations passed to ``Population.run``.

        Side effects:
            Writes ``winner.pkl`` and ``real_winner.pkl`` to the current
            working directory, registers a ``neat.Checkpointer(5)`` reporter
            (first argument is the generation interval in neat-python), and
            calls the ``visualize`` plotting helpers.
        """
        config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
                             neat.DefaultSpeciesSet, neat.DefaultStagnation,
                             config_file)

        p = neat.Checkpointer.restore_checkpoint(self.file_name)
        p.add_reporter(neat.StdOutReporter(True))
        p.add_reporter(neat.Checkpointer(5))
        stats = neat.StatisticsReporter()
        p.add_reporter(stats)
        print("loaded checkpoint...")

        winner = p.run(self._eval_genomes, n)
        # NOTE(review): neat-python's Population.run returns best_genome, so
        # these two objects are likely identical — kept separate to preserve
        # the two output files existing callers may rely on.
        win = p.best_genome

        # Use context managers so each file is flushed and closed even if
        # pickling raises (the original left both handles open).
        with open('winner.pkl', 'wb') as winner_file:
            pickle.dump(winner, winner_file)
        with open('real_winner.pkl', 'wb') as real_winner_file:
            pickle.dump(win, real_winner_file)

        visualize.draw_net(config, winner, True)
        visualize.plot_stats(stats, ylog=False, view=True)
        visualize.plot_species(stats, view=True)
if __name__ == "__main__":
    # Resume from a fixed checkpoint file and train on level "1-1".
    trainer = Train(
        generations=1000,
        file_name="./Files/neat-checkpoint-2492",
        parallel=2,
        level="1-1",
    )
    trainer.main()