diff --git a/deep_maxent_irl.py b/deep_maxent_irl.py
index bc44959..b641b1a 100644
--- a/deep_maxent_irl.py
+++ b/deep_maxent_irl.py
@@ -16,26 +16,29 @@ class DeepIRLFC:
-  def __init__(self, n_input, n_actions, lr, n_h1=400, n_h2=300, l2=10, name='deep_irl_fc'):
+  def __init__(self, n_input, n_actions, lr, n_h1=400, n_h2=300, l2=10, sparse=False, name='deep_irl_fc'):
     self.n_input = n_input
     self.lr = lr
     self.n_h1 = n_h1
     self.n_h2 = n_h2
     self.name = name
+    self.sparse = sparse
 
-    config = tf.ConfigProto()
-    config.gpu_options.allow_growth = True
-    self.sess = tf.Session(config=config)
+    self.sess = tf.Session()
     self.input_s, self.reward, self.theta = self._build_network(self.name)
 
     # value iteration
-    self.P_a = tf.placeholder(tf.float32, shape=(n_input, n_actions, n_input))
+    if sparse:
+      self.P_a = tf.sparse_placeholder(tf.float32, shape=(n_input, n_actions, n_input))
+    else:
+      self.P_a = tf.placeholder(tf.float32, shape=(n_input, n_actions, n_input))
+
     self.gamma = tf.placeholder(tf.float32)
     self.epsilon = tf.placeholder(tf.float32)
     self.values, self.policy = self._vi(self.reward)
 
     self.optimizer = tf.train.GradientDescentOptimizer(lr)
-    
+
     self.grad_r = tf.placeholder(tf.float32, [n_input, 1])
     self.l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.theta])
     self.grad_l2 = tf.gradients(self.l2_loss, self.theta)
@@ -67,7 +70,12 @@ def _vi(self, rewards):
 
     def body(i, c, t):
       old_values = t.read(i)
-      new_values = tf.reduce_max(tf.reduce_sum(self.P_a * (rewards + self.gamma * old_values), axis=2), axis=1)
+      if self.sparse:
+        new_values = tf.sparse_reduce_max(
+          tf.sparse_reduce_sum_sparse(self.P_a * (rewards + self.gamma * old_values), axis=2), axis=1)
+      else:
+        new_values = tf.reduce_max(tf.reduce_sum(self.P_a * (rewards + self.gamma * old_values), axis=2), axis=1)
+
       c = tf.reduce_max(tf.abs(new_values - old_values)) > self.epsilon
       c.set_shape(())
       t = t.write(i + 1, new_values)
@@ -82,7 +90,10 @@ def condition(i, c, t):
 
     i, _, values = tf.while_loop(condition, body, [0, True, t], parallel_iterations=1, back_prop=False,
                                  name='VI_loop')
     values = values.read(i)
-    policy = tf.argmax(tf.reduce_sum(self.P_a * (rewards + self.gamma * values), axis=2), axis=1)
+    if self.sparse:
+      policy = tf.argmax(tf.sparse_tensor_to_dense(tf.sparse_reduce_sum_sparse(self.P_a * (rewards + self.gamma * values), axis=2)), axis=1)
+    else:
+      policy = tf.argmax(tf.reduce_sum(self.P_a * (rewards + self.gamma * values), axis=2), axis=1)
 
     return values, policy
@@ -100,14 +111,14 @@ def get_policy(self, states, P_a, gamma, epsilon=0.01):
 
   def apply_grads(self, feat_map, grad_r):
     grad_r = np.reshape(grad_r, [-1, 1])
     feat_map = np.reshape(feat_map, [-1, self.n_input])
-    _, grad_theta, l2_loss, grad_norms = self.sess.run([self.optimize, self.grad_theta, self.l2_loss, self.grad_norms], 
+    _, grad_theta, l2_loss, grad_norms = self.sess.run([self.optimize, self.grad_theta, self.l2_loss, self.grad_norms],
       feed_dict={self.grad_r: grad_r, self.input_s: feat_map})
     return grad_theta, l2_loss, grad_norms
 
 
 def compute_state_visition_freq(P_a, gamma, trajs, policy, deterministic=True):
-  """compute the expected states visition frequency p(s| theta, T) 
+  """compute the expected states visition frequency p(s| theta, T)
   using dynamic programming
 
   inputs:
@@ -116,7 +127,7 @@ def compute_state_visition_freq(P_a, gamma, trajs, policy, deterministic=True):
     trajs   list of list of Steps - collected from expert
     policy  Nx1 vector (or NxN_ACTIONS if deterministic=False) - policy
-    
+
   returns:
     p       Nx1 vector - state visitation frequencies
   """
@@ -161,11 +172,11 @@ def step(t, start, end):
 def demo_svf(trajs, n_states):
   """
   compute state visitation frequences from demonstrations
-  
+
   input:
     trajs   list of list of Steps - collected from expert
   returns:
-    p       Nx1 vector - state visitation frequences   
+    p       Nx1 vector - state visitation frequences
   """
 
   p = np.zeros(n_states)
@@ -175,7 +186,7 @@ def demo_svf(trajs, n_states):
 
   p = p/len(trajs)
   return p
 
-def deep_maxent_irl(feat_map, P_a, gamma, trajs, lr, n_iters):
+def deep_maxent_irl(feat_map, P_a, gamma, trajs, lr, n_iters, sparse):
   """
   Maximum Entropy Inverse Reinforcement Learning (Maxent IRL)
 
@@ -194,16 +205,22 @@ def deep_maxent_irl(feat_map, P_a, gamma, trajs, lr, n_iters):
   """
 
   # tf.set_random_seed(1)
-  
+
   N_STATES, _, N_ACTIONS = np.shape(P_a)
 
   # init nn model
-  nn_r = DeepIRLFC(feat_map.shape[1], N_ACTIONS, lr, 3, 3)
+  nn_r = DeepIRLFC(feat_map.shape[1], N_ACTIONS, lr, 3, 3, sparse=sparse)
 
   # find state visitation frequencies using demonstrations
   mu_D = demo_svf(trajs, N_STATES)
 
-  # training
+  P_a_t = P_a.transpose(0, 2, 1)
+  if sparse:
+    mask = P_a_t > 0
+    indices = np.argwhere(mask)
+    P_a_t = tf.SparseTensorValue(indices, P_a_t[mask], P_a_t.shape)
+
+  # training
   for iteration in range(n_iters):
     if iteration % (n_iters/10) == 0:
       print 'iteration: {}'.format(iteration)
@@ -216,7 +233,7 @@ def deep_maxent_irl(feat_map, P_a, gamma, trajs, lr, n_iters):
 
     # compute rewards and policy at the same time
     t = time.time()
-    rewards, _, policy = nn_r.get_policy(feat_map, P_a.transpose(0, 2, 1), gamma, 0.01)
+    rewards, _, policy = nn_r.get_policy(feat_map, P_a_t, gamma, 0.01)
     print('tensorflow VI', time.time() - t)
 
     # compute expected svf
diff --git a/deep_maxent_irl_gridworld.py b/deep_maxent_irl_gridworld.py
index 720b379..845dd2e 100644
--- a/deep_maxent_irl_gridworld.py
+++ b/deep_maxent_irl_gridworld.py
@@ -27,6 +27,7 @@ PARSER.set_defaults(rand_start=True)
 PARSER.add_argument('-lr', '--learning_rate', default=0.02, type=float, help='learning rate')
 PARSER.add_argument('-ni', '--n_iters', default=20, type=int, help='number of iterations')
+PARSER.add_argument('-s', '--sparse', default=False, action='store_true', help='flag to use sparse tensors in tf')
 ARGS = PARSER.parse_args()
 print ARGS
 
@@ -100,7 +101,7 @@ def main():
 
   print 'Deep Max Ent IRL training ..'
   t = time.time()
-  rewards = deep_maxent_irl(feat_map, P_a, GAMMA, trajs, LEARNING_RATE, N_ITERS)
+  rewards = deep_maxent_irl(feat_map, P_a, GAMMA, trajs, LEARNING_RATE, N_ITERS, ARGS.sparse)
   print('time for dirl', time.time() - t)
 
   values, _ = value_iteration.value_iteration(P_a, rewards, GAMMA, error=0.01, deterministic=True)
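
For reference, the core of the sparse path in deep_maxent_irl.py is that the dense transition tensor P_a, transposed to [state, action, next_state], is converted to coordinate form on the NumPy side and fed into the tf.sparse_placeholder as a tf.SparseTensorValue. The snippet below is a minimal, self-contained sketch of that round trip, assuming a TF 1.x runtime; the toy sizes and hand-filled transition entries are illustrative assumptions and not part of the patch or the gridworld dynamics.

# Minimal sketch of the sparse feed used in the patch (assumes TF 1.x).
import numpy as np
import tensorflow as tf

n_states, n_actions = 5, 4
P_a = np.zeros((n_states, n_states, n_actions), dtype=np.float32)  # [s, s', a]
P_a[0, 1, 2] = 1.0
P_a[1, 1, 0] = 0.7
P_a[1, 4, 0] = 0.3

# Same re-ordering the patch applies before feeding: [s, s', a] -> [s, a, s'].
P_a_t = P_a.transpose(0, 2, 1)

# Coordinate (COO) form: np.argwhere returns the non-zero indices in
# row-major order, which is the canonical ordering TF's sparse ops expect.
mask = P_a_t > 0
indices = np.argwhere(mask)                      # shape [nnz, 3], int64
values = P_a_t[mask]                             # shape [nnz], float32
sparse_value = tf.SparseTensorValue(indices, values, P_a_t.shape)

sp = tf.sparse_placeholder(tf.float32, shape=P_a_t.shape)
dense_again = tf.sparse_tensor_to_dense(sp)      # densify, as the policy branch does
row_sums = tf.sparse_reduce_sum(sp, axis=2)      # dense [n_states, n_actions]

with tf.Session() as sess:
  back, sums = sess.run([dense_again, row_sums], feed_dict={sp: sparse_value})
  assert np.allclose(back, P_a_t)                # the feed round-trips losslessly
  assert np.allclose(sums, P_a_t.sum(axis=2))    # sparse reduction matches dense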
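
With the new flag in deep_maxent_irl_gridworld.py, the sparse code path is opt-in from the command line, e.g. (keeping the existing defaults) python deep_maxent_irl_gridworld.py --learning_rate 0.02 --n_iters 20 --sparse; omitting --sparse leaves the original dense placeholder path unchanged.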