Skip to content

Commit

Permalink
updating attention
Browse files Browse the repository at this point in the history
  • Loading branch information
jdbloom committed Apr 14, 2024
1 parent c82755c commit cf38c49
Show file tree
Hide file tree
Showing 5 changed files with 3 additions and 4 deletions.
Binary file modified gsp_rl/src/actors/__pycache__/learning_aids.cpython-310.pyc
Binary file not shown.
3 changes: 1 addition & 2 deletions gsp_rl/src/actors/learning_aids.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ def TD3_choose_action(self, observation, networks, n_actions):
return mu_prime.unsqueeze(0).cpu().detach().numpy()

def Attention_choose_action(self, observation, networks):
state = T.tensor(observation, dtype = T.float).to(networks['attention'].device)
return networks['attention'](state).cpu().detach().numpy()
return networks['attention'](observation).cpu().detach().numpy()


def learn_DQN(self, networks):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gsp-rl"
version = "1.1.4"
version = "1.1.5"
description = "Stable with config support"
authors = ["jdbloom"]
readme = "README.md"
Expand Down
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_learning_aids/test_base_choose_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,5 @@ def test_Attention_choose_action():
}
networks = NA.make_Attention_Encoder(nn_args)
random_input = np.random.uniform(0, 1, (1, nn_args['max_length'],nn_args['input_size']))

random_input = T.Tensor(random_input).to(networks['attention'].device)
assert(tuple(NA.Attention_choose_action(random_input, networks).shape) == (1, nn_args['output_size']))

0 comments on commit cf38c49

Please sign in to comment.