Fix numpy VI
magnusja committed Nov 29, 2017
1 parent c603c6b commit 49ee38a
Showing 2 changed files with 16 additions and 10 deletions.
21 changes: 12 additions & 9 deletions deep_maxent_irl.py
@@ -309,7 +309,7 @@ def deep_maxent_irl(feat_map, P_a, gamma, trajs, lr, n_iters, sparse):
N_STATES, _, N_ACTIONS = np.shape(P_a)

# init nn model
-  nn_r = DeepIRLFC(feat_map.shape[1], N_ACTIONS, lr, len(trajs[0]), 3, 3, deterministic=False, sparse=sparse)
+  nn_r = DeepIRLFC(feat_map.shape[1], N_ACTIONS, lr, len(trajs[0]), 3, 3, deterministic=True, sparse=sparse)

# find state visitation frequencies using demonstrations
mu_D = demo_svf(trajs, N_STATES)
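
For context, mu_D here is the empirical state visitation frequency of the demonstrations; the gradient further down is simply mu_D - mu_exp. A minimal sketch of that counting step (assuming, as elsewhere in this repository, that each trajectory is a list of Step records exposing cur_state):

import numpy as np

def demo_svf_sketch(trajs, n_states):
    # Count how often each state is visited across all demonstration
    # trajectories, then average over the number of trajectories.
    mu = np.zeros(n_states)
    for traj in trajs:
        for step in traj:
            mu[step.cur_state] += 1
    return mu / len(trajs)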
@@ -330,30 +330,33 @@ def deep_maxent_irl(feat_map, P_a, gamma, trajs, lr, n_iters, sparse):
rewards = nn_r.get_rewards(feat_map)

# compute policy
-#_, policy = value_iteration.value_iteration(P_a, rewards, gamma, error=0.01, deterministic=False)
+#_, policy = value_iteration.value_iteration(P_a, rewards, gamma, error=0.01, deterministic=True)

# compute rewards and policy at the same time
#t = time.time()
#rewards, _, policy = nn_r.get_policy(feat_map, P_a_t, gamma, 0.01)
#print('tensorflow VI', time.time() - t)

# compute expected svf
-#mu_exp = compute_state_visition_freq(P_a, gamma, trajs, policy, deterministic=False)
+#mu_exp = compute_state_visition_freq(P_a, gamma, trajs, policy, deterministic=True)

-rewards, _, policy, mu_exp = nn_r.get_policy_svf(feat_map, P_a_t, gamma, p_start_state, 0.000001)
+rewards, values, policy, mu_exp = nn_r.get_policy_svf(feat_map, P_a_t, gamma, p_start_state, 0.000001)
# compute gradients on rewards:
grad_r = mu_D - mu_exp

-assert_values, assert_policy = value_iteration.value_iteration(P_a, rewards, gamma, error=0.000001, deterministic=False)
-assert_values_old, assert_policy_old = value_iteration.value_iteration_old(P_a, rewards, gamma, error=0.000001, deterministic=False)
+assert_values, assert_policy = value_iteration.value_iteration(P_a, rewards, gamma, error=0.000001, deterministic=True)
+assert_values_old, assert_policy_old = value_iteration.value_iteration_old(P_a, rewards, gamma, error=0.000001, deterministic=True)

-assert (np.abs(assert_values - assert_values_old) < 0.0001).all()
-assert (np.abs(assert_policy - assert_policy) < 0.0001).all()
+assert (np.abs(values - assert_values) < 0.0001).all()
+assert (np.abs(values - assert_values_old) < 0.0001).all()
+
+assert (np.abs(assert_policy - assert_policy_old) < 0.0001).all()
+assert (np.abs(policy - assert_policy) < 0.001).all()
+assert (np.abs(policy - assert_policy_old) < 0.001).all()

-assert (np.abs(mu_exp - compute_state_visition_freq(P_a, gamma, trajs, policy, deterministic=False)) < 0.00001).all()
-assert (np.abs(mu_exp - compute_state_visition_freq_old(P_a, gamma, trajs, policy, deterministic=False)) < 0.00001).all()
+assert (np.abs(mu_exp - compute_state_visition_freq(P_a, gamma, trajs, policy, deterministic=True)) < 0.00001).all()
+assert (np.abs(mu_exp - compute_state_visition_freq_old(P_a, gamma, trajs, policy, deterministic=True)) < 0.00001).all()

# apply gradients to the neural network
grad_theta, l2_loss, grad_norm = nn_r.apply_grads(feat_map, grad_r)
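The assertions above pin the fused TensorFlow computation (get_policy_svf returns rewards, values, policy, and expected SVF in one pass) against the standalone numpy references. As a reference point, a minimal sketch of the deterministic expected-SVF computation being compared against, assuming the P_a[s, s', a] transition convention used in this repository and a horizon T equal to the demonstration length:

import numpy as np

def expected_svf_sketch(P_a, policy, p_start, T):
    # Propagate the start-state distribution forward for T steps under a
    # deterministic policy, then sum the visitation mass over time.
    n_states = P_a.shape[0]
    mu = np.zeros((n_states, T))
    mu[:, 0] = p_start
    for t in range(1, T):
        for s_next in range(n_states):
            mu[s_next, t] = sum(mu[s, t - 1] * P_a[s, s_next, int(policy[s])]
                                for s in range(n_states))
    return mu.sum(axis=1)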
5 changes: 4 additions & 1 deletion mdp/value_iteration.py
@@ -113,7 +113,10 @@ def value_iteration(P_a, rewards, gamma, error=0.01, deterministic=True):
values_tmp = values.copy()

def step(start, end):
-        values[start:end] = (P[start:end, :, :] * (rewards + gamma * values_tmp)).sum(axis=2).max(axis=1)
+        tmp = rewards[start:end, np.newaxis].repeat(N_STATES, axis=1) + gamma * values_tmp
+        tmp = tmp[:, :, np.newaxis].repeat(N_ACTIONS, axis=2)
+        tmp = np.transpose(tmp, (0, 2, 1))
+        values[start:end] = (P[start:end, :, :] * tmp).sum(axis=2).max(axis=1)

with ThreadPoolExecutor(max_workers=num_cpus) as e:
futures = list()
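This hunk is the actual numpy VI fix: in the old one-liner, rewards + gamma * values_tmp has shape (N_STATES,) and broadcasts along the last (next-state) axis of P, so each backup term pairs the transition into s' with rewards[s'] instead of the intended current-state reward rewards[s]. The repeat/transpose sequence builds tmp with tmp[s, a, s'] = rewards[s] + gamma * values_tmp[s'] before taking the product. A toy sketch of the two backups side by side (names and shapes here are illustrative only; P[s, a, s'] stands for the transposed transition model this function operates on):

import numpy as np

N_STATES, N_ACTIONS, gamma = 4, 2, 0.9
rng = np.random.default_rng(0)
P = rng.dirichlet(np.ones(N_STATES), size=(N_STATES, N_ACTIONS))  # P[s, a, s']
rewards = rng.random(N_STATES)
values_tmp = rng.random(N_STATES)

# Buggy backup: the 1-D array broadcasts over the next-state axis,
# effectively using rewards[s'].
buggy = (P * (rewards + gamma * values_tmp)).sum(axis=2).max(axis=1)

# Fixed backup: tmp[s, a, s'] = rewards[s] + gamma * values_tmp[s'],
# numerically equivalent to the repeat/transpose construction above.
tmp = rewards[:, np.newaxis, np.newaxis] + gamma * values_tmp[np.newaxis, np.newaxis, :]
fixed = (P * tmp).sum(axis=2).max(axis=1)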
