diff --git a/gsp_rl/src/actors/__pycache__/learning_aids.cpython-310.pyc b/gsp_rl/src/actors/__pycache__/learning_aids.cpython-310.pyc
index ab3a7ae..373ff89 100644
Binary files a/gsp_rl/src/actors/__pycache__/learning_aids.cpython-310.pyc and b/gsp_rl/src/actors/__pycache__/learning_aids.cpython-310.pyc differ
diff --git a/gsp_rl/src/actors/learning_aids.py b/gsp_rl/src/actors/learning_aids.py
index a64e0f4..6d9e36e 100644
--- a/gsp_rl/src/actors/learning_aids.py
+++ b/gsp_rl/src/actors/learning_aids.py
@@ -157,8 +157,7 @@ def TD3_choose_action(self, observation, networks, n_actions):
         return mu_prime.unsqueeze(0).cpu().detach().numpy()

     def Attention_choose_action(self, observation, networks):
-        state = T.tensor(observation, dtype = T.float).to(networks['attention'].device)
-        return networks['attention'](state).cpu().detach().numpy()
+        return networks['attention'](observation).cpu().detach().numpy()

     def learn_DQN(self, networks):
diff --git a/pyproject.toml b/pyproject.toml
index abdbdb3..87c72a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gsp-rl"
-version = "1.1.4"
+version = "1.1.5"
 description = "Stable with config support"
 authors = ["jdbloom"]
 readme = "README.md"
diff --git a/tests/test_learning_aids/__pycache__/test_base_choose_action.cpython-310-pytest-7.4.4.pyc b/tests/test_learning_aids/__pycache__/test_base_choose_action.cpython-310-pytest-7.4.4.pyc
index c9ad69c..218e23c 100644
Binary files a/tests/test_learning_aids/__pycache__/test_base_choose_action.cpython-310-pytest-7.4.4.pyc and b/tests/test_learning_aids/__pycache__/test_base_choose_action.cpython-310-pytest-7.4.4.pyc differ
diff --git a/tests/test_learning_aids/test_base_choose_action.py b/tests/test_learning_aids/test_base_choose_action.py
index 9db27ac..e6f87e1 100644
--- a/tests/test_learning_aids/test_base_choose_action.py
+++ b/tests/test_learning_aids/test_base_choose_action.py
@@ -155,5 +155,5 @@ def test_Attention_choose_action():
     }
     networks = NA.make_Attention_Encoder(nn_args)
     random_input = np.random.uniform(0, 1, (1, nn_args['max_length'],nn_args['input_size']))
-
+    random_input = T.Tensor(random_input).to(networks['attention'].device)
     assert(tuple(NA.Attention_choose_action(random_input, networks).shape) == (1, nn_args['output_size']))
\ No newline at end of file
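
Note on the change above: as of 1.1.5, Attention_choose_action no longer converts the observation to a float tensor or moves it to the attention network's device; callers must do both before calling it, exactly as the updated test now does. The following is a minimal caller-side sketch, not part of the diff, assuming `NA` (the actor object) and `nn_args` (the encoder config with 'max_length', 'input_size', 'output_size' keys) are set up as in test_Attention_choose_action:

import numpy as np
import torch as T

# Hypothetical setup mirroring the test: `NA` and `nn_args` come from the caller.
networks = NA.make_Attention_Encoder(nn_args)

# Raw observation as a numpy array of shape (batch, max_length, input_size).
observation = np.random.uniform(0, 1, (1, nn_args['max_length'], nn_args['input_size']))

# The caller is now responsible for tensor conversion and device placement.
observation = T.Tensor(observation).to(networks['attention'].device)

# Returns a numpy array of shape (1, output_size).
action = NA.Attention_choose_action(observation, networks)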