-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: cnn.py
40 lines (30 loc) · 1.42 KB
/
cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- coding: utf-8 -*-
import tensorflow as tf
from keras.layers import (Activation, Conv2D, Dense, Flatten, Input)
from keras.layers.merge import dot
from keras.models import Model
from keras.optimizers import Adam
import keras.backend as K
'''
Code reference from:
"Beat Atari with Deep Reinforcement Learning!" by Adrien Lucas Ecoffet
Following is the link:
https://becominghuman.ai/lets-build-an-atari-ai-part-1-dqn-df57e8ff3b26
'''
def nn_model(frame=4, input_shape=(5, 5), num_actions=5):
    """Build a deep Q-network with a dot-product action head.

    Args:
        frame: Number of stacked frames in one state (channels-first).
        input_shape: Spatial (height, width) of each frame.
        num_actions: Size of the discrete action space.

    Returns:
        network_model: Keras ``Model`` taking ``[state, one_hot_action]``
            and outputting the scalar Q-value of that single action
            (the trainable head).
        q_values_func: Backend function mapping a batch of states to the
            full vector of Q-values over all actions (used for greedy
            action selection).
    """
    # NOTE: default changed from the mutable list [5, 5] to a tuple —
    # identical behavior (only indexed), but avoids the shared-mutable-
    # default-argument pitfall.
    with tf.name_scope('deep_q_network'):
        with tf.name_scope('input'):
            # State tensor, e.g. 4 stacked 5x5 frames -> shape (4, 5, 5).
            input_state = Input(shape=(frame, input_shape[0], input_shape[1]))
            # One-hot mask selecting which action's Q-value to emit.
            input_action = Input(shape=(num_actions,))
        with tf.name_scope('fc2'):
            flattened = Flatten()(input_state)
            dense2 = Dense(128, kernel_initializer='glorot_uniform',
                           activation='relu')(flattened)
        with tf.name_scope('output'):
            # Linear head: one unbounded Q-value per action.
            q_values = Dense(num_actions, activation=None)(dense2)
            # Dot with the one-hot action mask -> scalar Q(s, a).
            q_v = dot([q_values, input_action], axes=1)
    # Variant 1: input (state, action), output a single q_value.
    network_model = Model(inputs=[input_state, input_action], outputs=q_v)
    # Variant 2: input a state, output the q_values of all actions.
    q_values_func = K.function([input_state], [q_values])
    network_model.summary()
    return network_model, q_values_func
if __name__ == '__main__':
    # Build the network only when run as a script — importing this module
    # should not construct a model or print a summary as a side effect.
    nn_model()