DQN gets stuck in iterator #949

Open
harryyy27 opened this issue Dec 17, 2024 · 2 comments

@harryyy27

I've explored just about every angle I can think of on this problem except my hardware (though I'm not sure where I'd begin with that), and I'm not making any progress. I suspect it may be a compatibility issue, but I'm not sure. Anyway, I have written out the code from the example linked below:

https://github.com/tensorflow/agents/blob/528cef7c4aedf54158a0564fdca446fe9942aa2a/docs/tutorials/1_dqn_tutorial.ipynb

more or less line for line. I ran it in VS Code inside a conda environment, after installing pip and then pip-installing the required packages within that environment. The code runs fine until it reaches the dataset drawn from the Reverb replay buffer; then it simply freezes and never progresses, without throwing an error or leaving any sign of what might be happening. The exact same problem occurred when I ran my own version of this code in a project I'm working on. Everything grinds to a halt at next(iterator).
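A silent freeze at next(iterator) is what a sampler blocked by a rate limiter looks like from the outside. The toy model below is a hypothetical, standard-library-only MinSizeTable class, not Reverb's implementation; it mimics a table gated like the MinSize(1) rate limiter in the script below, where sampling waits until enough items have been inserted. Without a timeout, a single-threaded caller waiting on an empty table would wait forever, because nothing else runs to insert an item.

```python
import random
import threading
from typing import Any, List, Optional


class MinSizeTable:
    """Toy stand-in (hypothetical, not Reverb's code) for a replay table whose
    sampler is gated by a MinSize-style rate limiter."""

    def __init__(self, min_size: int = 1) -> None:
        self._min_size = min_size
        self._items: List[Any] = []
        self._cond = threading.Condition()

    def insert(self, item: Any) -> None:
        with self._cond:
            self._items.append(item)
            # Wake up any sampler that is waiting for the table to fill.
            self._cond.notify_all()

    def sample(self, timeout: Optional[float] = None) -> Optional[Any]:
        """Return a uniformly sampled item, or None if `timeout` seconds pass
        before the table holds `min_size` items. With timeout=None this waits
        forever, which is how a single-threaded loop can hang silently."""
        with self._cond:
            ready = self._cond.wait_for(
                lambda: len(self._items) >= self._min_size, timeout=timeout
            )
            if not ready:
                return None
            return random.choice(self._items)


table = MinSizeTable(min_size=1)
print(table.sample(timeout=0.1))  # None: the table is still empty
table.insert("transition-0")
print(table.sample(timeout=0.1))  # transition-0
```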

My code can be seen below:

from __future__ import absolute_import, division, print_function
import os

os.environ['TF_USE_LEGACY_KERAS']='1'
# os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt 
import numpy as np 
import PIL.Image 
import pyvirtualdisplay
import reverb

import tensorflow as tf 

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common

display = pyvirtualdisplay.Display(visible=0,size=(1400,900)).start()
print(tf.test.is_built_with_gpu_support())

###HYPERPARAMETERS

num_iterations = 20000

initial_collect_steps=100
collect_steps_per_iteration=1
replay_buffer_max_length=100000
batch_size=64
learning_rate=1e-3
log_interval=200
num_eval_episodes=10
eval_interval=1000

###ENVIRONMENT

env_name="CartPole-v0"
env = suite_gym.load(env_name)

env.reset()
image =PIL.Image.fromarray(env.render())
# image.show()


time_step=env.reset()
action=np.array(1,dtype=np.int32)
next_time_step=env.step(action)

train_py_env=suite_gym.load(env_name)
eval_py_env =suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

fc_layer_params = (100, 50)
action_tensor_spec=tensor_spec.from_spec(env.action_spec())
num_actions=action_tensor_spec.maximum - action_tensor_spec.minimum+1

def dense_layer(num_units):
    return tf.keras.layers.Dense(
        num_units,
        activation=tf.keras.activations.relu,
        kernel_initializer=tf.keras.initializers.VarianceScaling(
            scale=2.0,
            mode='fan_in',
            distribution='truncated_normal'
        )
    )
dense_layers=[dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer=tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03,
        maxval=0.03
    ),
    bias_initializer=tf.keras.initializers.Constant(-0.2)
)
q_net=sequential.Sequential(dense_layers+[q_values_layer])

optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter=tf.Variable(0)

agent= dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter
)

agent.initialize()

eval_policy = agent.policy
collect_policy = agent.collect_policy

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),train_env.action_spec())

example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0')
)
time_step=example_environment.reset()
random_policy.action(time_step)

def compute_avg_return(environment,policy,num_episodes=10):
    total_return =0.0
    for _ in range(num_episodes):
        time_step=environment.reset()
        episode_return =0.0

        while not time_step.is_last():
            action_step=policy.action(time_step)
            time_step=environment.step(action_step.action)
            episode_return += time_step.reward
        total_return+=episode_return

    avg_return=total_return/num_episodes
    return avg_return.numpy()[0]

compute_avg_return(eval_env,random_policy,num_eval_episodes)

table_name='uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
    agent.collect_data_spec
)
replay_buffer_signature = tensor_spec.add_outer_dim(
    replay_buffer_signature
)

table=reverb.Table(
    table_name,
    max_size=replay_buffer_max_length,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature
)

reverb_server=reverb.Server([table])

replay_buffer= reverb_replay_buffer.ReverbReplayBuffer(
    agent.collect_data_spec,
    table_name=table_name,
    sequence_length=2,
    local_server=reverb_server
)
rb_observer=reverb_utils.ReverbAddTrajectoryObserver(
    replay_buffer.py_client,
    table_name,
    sequence_length=2
)

dataset=replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2
).prefetch(3)


iterator=iter(dataset)


agent.train=common.function(agent.train)
agent.train_step_counter.assign(0)
avg_return = compute_avg_return(eval_env,agent.policy,num_eval_episodes)
returns =[avg_return]
time_step=train_py_env.reset()

collect_driver = py_driver.PyDriver(
    env,
    py_tf_eager_policy.PyTFEagerPolicy(
        agent.collect_policy, use_tf_function=True),
    [rb_observer],
    max_steps=collect_steps_per_iteration
)

for _ in range(num_iterations):
    time_step,_=collect_driver.run(time_step)
    print('iterator next')
    experience,unused_info=next(iterator)###<---it gets stuck here!!
    train_loss=agent.train(experience).loss 
    step=agent.train_step_counter.numpy()
    if step % log_interval ==0:
        print('step = {0}: loss={1}'.format(step,train_loss))
    if step % eval_interval==0:
        avg_return = compute_avg_return(eval_env,agent.policy,num_eval_episodes)
        print('step = {0}: Average Return = {1}'.format(step,avg_return))
        returns.append(avg_return)
iterations=range(0,num_iterations+1, eval_interval)
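Reading the script against the linked tutorial, one difference may explain the hang by itself: initial_collect_steps = 100 is defined but never used. As far as I can tell, the tutorial runs a PyDriver with random_policy and rb_observer for initial_collect_steps steps before the first next(iterator), so the Reverb table already holds items. Here, the only collection before the first next(iterator) is one call to collect_driver.run with max_steps=collect_steps_per_iteration (a single step), while the observer writes items of sequence_length=2 steps. The hypothetical helper below (not a TF-Agents API) gives a rough count of the items a sliding-window writer produces from consecutive steps of one episode, ignoring episode boundaries:

```python
def items_written(num_steps: int, sequence_length: int, stride_length: int = 1) -> int:
    """Rough count of items a sliding-window trajectory writer creates from
    `num_steps` consecutive steps of ONE episode (episode boundaries ignored).
    A window of `sequence_length` steps is emitted every `stride_length` steps
    once enough steps are cached. Hypothetical helper for illustration."""
    if num_steps < sequence_length:
        return 0
    return (num_steps - sequence_length) // stride_length + 1


# Values taken from the script above.
sequence_length = 2
collect_steps_per_iteration = 1
initial_collect_steps = 100

print(items_written(collect_steps_per_iteration, sequence_length))  # 0
print(items_written(initial_collect_steps, sequence_length))        # 99
```

With zero items in the table, the MinSize(1) rate limiter keeps the sampler waiting, and because collection and sampling run on the same thread, nothing ever inserts the missing item. If that is the cause, running the tutorial's initial random-policy collection before the training loop should let next(iterator) return. Once episode boundaries are taken into account, the real count after 100 CartPole steps will be lower than 99, but it should still be well above zero.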

The terminal output was as follows:

(tf_tutorial) harry@harry-Aspire-A315-58:~/Documents/Reinforcement Learning/tf$ python intro.py
2024-12-17 14:43:38.622260: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-17 14:43:38.624297: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-17 14:43:38.651939: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-17 14:43:38.651970: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-17 14:43:38.652896: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 14:43:38.657333: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-17 14:43:38.657493: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-12-17 14:43:39.165679: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[reverb/cc/platform/tfrecord_checkpointer.cc:162]  Initializing TFRecordCheckpointer in /tmp/tmpyir39eez.
[reverb/cc/platform/tfrecord_checkpointer.cc:565] Loading latest checkpoint from /tmp/tmpyir39eez
[reverb/cc/platform/default/server.cc:71] Started replay server on port 42883
iterator next
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:165] Sampler and server are owned by the same process (21034) so Table uniform_table is accessed directly without gRPC.
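Because the process hangs without an error, it can help to see exactly where it is waiting. Below is a minimal sketch using Python's standard faulthandler module; run_with_watchdog is a hypothetical helper name, not part of TF-Agents. If the wrapped call is still running after the timeout, faulthandler dumps the Python stack of every thread. It only shows Python frames, so a wait inside TensorFlow's C++ code will appear as the Python call that entered it, such as next(iterator).

```python
import faulthandler
import sys
import time
from typing import Callable, TextIO, TypeVar

T = TypeVar("T")


def run_with_watchdog(fn: Callable[[], T], timeout_s: float, out: TextIO = sys.stderr) -> T:
    """Call fn(); if it is still running after timeout_s seconds, dump the
    Python stack of every thread to `out` (the call itself keeps running).
    `out` must be a real file with a file descriptor, e.g. sys.stderr."""
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=out)
    try:
        return fn()
    finally:
        # Disarm the watchdog once the call returns (or raises).
        faulthandler.cancel_dump_traceback_later()


# In the training loop this would wrap the call that hangs, e.g.
#   experience, unused_info = run_with_watchdog(lambda: next(iterator), 60)
result = run_with_watchdog(lambda: time.sleep(0.1) or "done", timeout_s=5.0)
print(result)  # done
```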

The conda list output:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
alsa-lib                  1.2.13               hb9d3cd8_0    conda-forge
asttokens                 3.0.0              pyhd8ed1ab_1    conda-forge
brotli                    1.1.0                hb9d3cd8_2    conda-forge
brotli-bin                1.1.0                hb9d3cd8_2    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
ca-certificates           2024.12.14           hbcca054_0    conda-forge
cairo                     1.18.2               h3394656_1    conda-forge
certifi                   2024.8.30          pyhd8ed1ab_0    conda-forge
contourpy                 1.3.0            py39h74842e3_2    conda-forge
cycler                    0.12.1             pyhd8ed1ab_1    conda-forge
cyrus-sasl                2.1.27               h54b06d7_7    conda-forge
dbus                      1.13.6               h5008d03_3    conda-forge
decorator                 5.1.1              pyhd8ed1ab_1    conda-forge
double-conversion         3.3.0                h59595ed_0    conda-forge
exceptiongroup            1.2.2              pyhd8ed1ab_1    conda-forge
executing                 2.1.0              pyhd8ed1ab_1    conda-forge
expat                     2.6.4                h5888daf_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 h77eed37_3    conda-forge
fontconfig                2.15.0               h7e30c49_1    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.55.3           py39h9399b63_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
graphite2                 1.3.13            h59595ed_1003    conda-forge
harfbuzz                  9.0.0                hda332d3_1    conda-forge
icu                       75.1                 he02047a_0    conda-forge
imageio                   2.4.0                    pypi_0    pypi
importlib-resources       6.4.5              pyhd8ed1ab_1    conda-forge
importlib_resources       6.4.5              pyhd8ed1ab_1    conda-forge
ipython                   8.18.1             pyh707e725_3    conda-forge
jedi                      0.19.2             pyhd8ed1ab_1    conda-forge
keras                     2.15.0                   pypi_0    pypi
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.7            py39h74842e3_0    conda-forge
krb5                      1.21.3               h659f571_0    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.43                 h712a8e2_2    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libblas                   3.9.0           25_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hb9d3cd8_2    conda-forge
libbrotlidec              1.1.0                hb9d3cd8_2    conda-forge
libbrotlienc              1.1.0                hb9d3cd8_2    conda-forge
libcblas                  3.9.0           25_linux64_openblas    conda-forge
libclang-cpp19.1          19.1.5          default_hb5137d0_0    conda-forge
libclang13                19.1.5          default_h9c6a7e4_0    conda-forge
libcups                   2.3.3                h4637d8d_4    conda-forge
libdeflate                1.22                 hb9d3cd8_0    conda-forge
libdrm                    2.4.124              hb9d3cd8_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libegl                    1.7.0                ha4b6fd6_2    conda-forge
libexpat                  2.6.4                h5888daf_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc                    14.2.0               h77fa898_1    conda-forge
libgcc-ng                 14.2.0               h69a702a_1    conda-forge
libgfortran               14.2.0               h69a702a_1    conda-forge
libgfortran5              14.2.0               hd5240d6_1    conda-forge
libgl                     1.7.0                ha4b6fd6_2    conda-forge
libglib                   2.82.2               h2ff4ddf_0    conda-forge
libglvnd                  1.7.0                ha4b6fd6_2    conda-forge
libglx                    1.7.0                ha4b6fd6_2    conda-forge
libgomp                   14.2.0               h77fa898_1    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0           25_linux64_openblas    conda-forge
libllvm19                 19.1.5               ha7bfdaf_0    conda-forge
liblzma                   5.6.3                hb9d3cd8_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libntlm                   1.4               h7f98852_1002    conda-forge
libopenblas               0.3.28          pthreads_h94d23a6_1    conda-forge
libopengl                 1.7.0                ha4b6fd6_2    conda-forge
libpciaccess              0.18                 hd590300_0    conda-forge
libpng                    1.6.44               hadc24fc_0    conda-forge
libpq                     17.2                 h3b95a9b_1    conda-forge
libsqlite                 3.47.2               hee588c1_0    conda-forge
libstdcxx                 14.2.0               hc0a3c3a_1    conda-forge
libstdcxx-ng              14.2.0               h4852527_1    conda-forge
libtiff                   4.7.0                hc4654cb_2    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcb                    1.17.0               h8a09558_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxkbcommon              1.7.0                h2c5496b_1    conda-forge
libxml2                   2.13.5               h8d12d68_1    conda-forge
libxslt                   1.1.39               h76b75d6_0    conda-forge
libzlib                   1.3.1                hb9d3cd8_2    conda-forge
matplotlib                3.9.4            py39hf3d152e_0    conda-forge
matplotlib-base           3.9.4            py39h16632d1_0    conda-forge
matplotlib-inline         0.1.7              pyhd8ed1ab_1    conda-forge
ml-dtypes                 0.3.2                    pypi_0    pypi
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
mysql-common              9.0.1                h266115a_3    conda-forge
mysql-libs                9.0.1                he0572af_3    conda-forge
ncurses                   6.5                  he02047a_1    conda-forge
numpy                     2.0.2            py39h9cb892a_1    conda-forge
openjpeg                  2.5.3                h5fbd93e_0    conda-forge
openldap                  2.6.9                he970967_0    conda-forge
openssl                   3.4.0                hb9d3cd8_0    conda-forge
packaging                 24.2               pyhd8ed1ab_2    conda-forge
parso                     0.8.4              pyhd8ed1ab_1    conda-forge
pcre2                     10.44                hba22ea6_2    conda-forge
pexpect                   4.9.0              pyhd8ed1ab_1    conda-forge
pickleshare               0.7.5           pyhd8ed1ab_1004    conda-forge
pillow                    11.0.0           py39h538c539_0    conda-forge
pip                       24.3.1             pyh8b19718_0    conda-forge
pixman                    0.44.2               h29eaf8c_0    conda-forge
prompt-toolkit            3.0.48             pyha770c72_1    conda-forge
pthread-stubs             0.4               hb9d3cd8_1002    conda-forge
ptyprocess                0.7.0              pyhd8ed1ab_1    conda-forge
pure_eval                 0.2.3              pyhd8ed1ab_1    conda-forge
pyglet                    2.0.20                   pypi_0    pypi
pygments                  2.18.0             pyhd8ed1ab_1    conda-forge
pyparsing                 3.2.0              pyhd8ed1ab_2    conda-forge
pyside6                   6.8.1            py39h0383914_0    conda-forge
python                    3.9.21          h9c0c6dc_1_cpython    conda-forge
python-dateutil           2.9.0.post0        pyhff2d567_1    conda-forge
python_abi                3.9                      5_cp39    conda-forge
pyvirtualdisplay          3.0                      pypi_0    pypi
qhull                     2020.2               h434a139_5    conda-forge
qt6-main                  6.8.1                h9d28a51_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
setuptools                75.6.0             pyhff2d567_1    conda-forge
six                       1.17.0             pyhd8ed1ab_0    conda-forge
stack_data                0.6.3              pyhd8ed1ab_1    conda-forge
tensorboard               2.15.2                   pypi_0    pypi
tensorflow                2.15.1                   pypi_0    pypi
tf-agents                 0.19.0                   pypi_0    pypi
tf-keras                  2.18.0                   pypi_0    pypi
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tornado                   6.4.2            py39h8cd3c5a_0    conda-forge
traitlets                 5.14.3             pyhd8ed1ab_1    conda-forge
typing_extensions         4.12.2             pyha770c72_1    conda-forge
tzdata                    2024b                hc8b5060_0    conda-forge
unicodedata2              15.1.0           py39h8cd3c5a_1    conda-forge
wayland                   1.23.1               h3e06ad9_0    conda-forge
wcwidth                   0.2.13             pyhd8ed1ab_1    conda-forge
wheel                     0.45.1             pyhd8ed1ab_1    conda-forge
xcb-util                  0.4.1                hb711507_2    conda-forge
xcb-util-cursor           0.1.5                hb9d3cd8_0    conda-forge
xcb-util-image            0.4.0                hb711507_2    conda-forge
xcb-util-keysyms          0.4.1                hb711507_0    conda-forge
xcb-util-renderutil       0.3.10               hb711507_0    conda-forge
xcb-util-wm               0.4.2                hb711507_0    conda-forge
xkeyboard-config          2.43                 hb9d3cd8_0    conda-forge
xorg-libice               1.1.2                hb9d3cd8_0    conda-forge
xorg-libsm                1.2.5                he73a12e_0    conda-forge
xorg-libx11               1.8.10               h4f16b4b_1    conda-forge
xorg-libxau               1.0.12               hb9d3cd8_0    conda-forge
xorg-libxcomposite        0.4.6                hb9d3cd8_2    conda-forge
xorg-libxcursor           1.2.3                hb9d3cd8_0    conda-forge
xorg-libxdamage           1.1.6                hb9d3cd8_0    conda-forge
xorg-libxdmcp             1.1.5                hb9d3cd8_0    conda-forge
xorg-libxext              1.3.6                hb9d3cd8_0    conda-forge
xorg-libxfixes            6.0.1                hb9d3cd8_0    conda-forge
xorg-libxi                1.8.2                hb9d3cd8_0    conda-forge
xorg-libxrandr            1.5.4                hb9d3cd8_0    conda-forge
xorg-libxrender           0.9.12               hb9d3cd8_0    conda-forge
xorg-libxtst              1.2.5                hb9d3cd8_3    conda-forge
xorg-libxxf86vm           1.1.6                hb9d3cd8_0    conda-forge
zipp                      3.21.0             pyhd8ed1ab_1    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge
harryyy27 changed the title from "DQN gets stuck in" to "DQN gets stuck in iterator" on Dec 17, 2024
@porta-logica (Contributor)

I guess the issue is hardware-related. Your computer's CPU seems to be missing the AVX2, AVX512F, AVX512_VNNI and FMA capabilities required by the pre-built version of TensorFlow 2.15.1. You can build TF from source for your specific CPU; however, you will very likely get degraded performance. Please note that these CPU features are unrelated to CUDA/GPU support, which you do not have either.
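One caveat on this reading: the cpu_feature_guard line in the log above says the binary "is optimized to use available CPU instructions" and that AVX2, AVX512F, AVX512_VNNI and FMA could be enabled "in other operations" by rebuilding. As far as I know, that message is informational and lists instructions the CPU does have; a CPU lacking an instruction that the pre-built binary requires would more typically crash with an illegal-instruction error than hang. The hardware hypothesis can be checked directly on Linux by reading the flags line of /proc/cpuinfo. A minimal sketch; missing_cpu_flags is a hypothetical helper, and the flag names are the lowercase spellings the Linux kernel uses:

```python
from pathlib import Path
from typing import Iterable, Set

# Lowercase names as the Linux kernel reports them in /proc/cpuinfo (x86).
TF_LOG_FLAGS = ("avx2", "avx512f", "avx512_vnni", "fma")


def missing_cpu_flags(cpuinfo_text: str, required: Iterable[str] = TF_LOG_FLAGS) -> Set[str]:
    """Return the required flags absent from the first 'flags' line of a
    /proc/cpuinfo dump (all cores normally report the same flags)."""
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            present = set(value.split())
            return {flag for flag in required if flag not in present}
    raise ValueError("no 'flags' line found; is this an x86 Linux /proc/cpuinfo dump?")


if __name__ == "__main__":
    try:
        missing = missing_cpu_flags(Path("/proc/cpuinfo").read_text())
        print("missing:", sorted(missing) or "none")
    except (OSError, ValueError) as err:
        # Non-Linux systems or non-x86 kernels end up here.
        print("could not check CPU flags:", err)
```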

@harryyy27 (Author)

Hi, thanks for your response. I'm thinking it might be hardware-related, or perhaps down to the fact that I'm using Ubuntu? I tried building it from source and I'm now getting strange errors, usually relating to gcc. Fixing one error makes another pop up in its place. I'm not sure what to do. What would you recommend?
