Regarding the failure to delete temp files: I'm guessing you're running on Beaker. Beaker is problematic about deleting files with Python (as Oyvindt discovered), so you can skip the cleanup with --no_tmp_cleanup and delete the temp folder manually afterwards. If you're only hitting the cleanup error, the main work of the script has already completed successfully.
Your issue is that from hf_olmo import OLMoForCausalLM is the "old-style" OLMo HF integration, so you're trying to load a "new-style" checkpoint with the old-style OLMo HF code. Using from transformers import AutoModelForCausalLM and AutoModelForCausalLM.from_pretrained should make things work for you (and you can use this for old-style checkpoints too, as long as you do import hf_olmo first).
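For example, something along these lines should work with the converted checkpoint (the path below is just a placeholder for your local output directory):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: point this at the directory produced by convert_olmo_to_hf_new.py
checkpoint_dir = "path/to/converted-hf-checkpoint"

model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

inputs = tokenizer("OLMo is an open language model.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions is a tuple with one tensor per layer,
# each of shape (batch, num_heads, seq_len, seq_len).
print(len(outputs.attentions), outputs.attentions[0].shape)
```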
NB: if you convert a checkpoint that is not compatible with OLMo 1 or 1.7, the converter may silently fail, producing a checkpoint that gives incorrect outputs. In that case you'll need to use old-style HF OLMo checkpoints, but the old-style code does not have output_attentions implemented yet. See https://github.com/allenai/OLMo/blob/main/docs/Checkpoints.md for more details about the different checkpoint types.
🐛 Describe the bug
I'm trying to convert internal OLMo checkpoints to Huggingface format so I can analyze the attention matrices. The reason I need Huggingface format is that @2015aroras implemented a very useful output_attentions flag in the Huggingface API to retrieve attention matrices from the forward pass, but this is not supported natively in the OLMo repo. The conversion script seems to work for converting the checkpoint to Huggingface (see below), but when I try to load the Huggingface checkpoint, I get an error message.

Minimal Example
I'd like to be able to handle many different checkpoints, but for purposes of reproducing the bug I'll focus on s3://ai2-llm/checkpoints/OLMo-medium/mitchish7/step0. I've downloaded this locally and unsharded it using:

I then run scripts/convert_olmo_to_hf_new.py to convert to Huggingface format:
This fails with an error message when trying to delete the temporary files after converting to the Huggingface format:
However, the Huggingface model is there, and after inspecting the script, I think everything should have run correctly besides the last line which deletes the temporary scratch directory. See:

OLMo/scripts/convert_olmo_to_hf_new.py, line 191 (at commit d423c11)
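As a sanity check that the conversion produced a complete Huggingface checkpoint despite the cleanup error, I can list the output directory and look for the usual files (the file names below are just the standard HF layout, not anything specific to the conversion script):

```python
import os

# Placeholder: the output directory that convert_olmo_to_hf_new.py wrote to
hf_dir = "path/to/converted-hf-checkpoint"

present = sorted(os.listdir(hf_dir))
print(present)

# Typical Huggingface checkpoint files (names assumed from the standard HF layout)
print("config present:", "config.json" in present)
print("weights present:", any(f.startswith(("pytorch_model", "model.safetensors")) for f in present))
```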
Because I think the output of the conversion script is correct despite the error message, I proceeded to load it in Huggingface as follows:
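Roughly, the loading code looks like this (with a placeholder path standing in for my local copy of the converted checkpoint):

```python
# hf_olmo is the "old-style" OLMo HF integration; importing it registers the
# OLMo config/model classes with transformers.
from hf_olmo import OLMoForCausalLM

# Placeholder for the local path to the converted checkpoint
checkpoint_dir = "path/to/converted-hf-checkpoint"

model = OLMoForCausalLM.from_pretrained(checkpoint_dir)
```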
This throws the following error:
Possible Diagnoses
Am I converting the OLMo checkpoint to Huggingface incorrectly? Or should I be using it in a way besides passing the path to from_pretrained? Or maybe there's just a version issue here and I shouldn't expect to be able to load arbitrary checkpoints in Huggingface? If the latter, it would be nice if the error messages could better indicate the incompatibility. I would also like to know potential workarounds for retrieving attention matrices natively in the OLMo repo without having to convert to Huggingface.

Versions
Python 3.10.9
pip freeze
accelerate==0.32.1 -e git+https://github.com/allenai/OLMo@d423c11#egg=ai2_olmo -e git+https://github.com/allenai/OLMo-core@eb56a9f0c2f63cf2e79e90da878a00d1a282cec9#egg=ai2_olmo_core aiohttp==3.9.5 aiosignal==1.3.1 alabaster==0.7.16 annotated-types==0.6.0 antlr4-python3-runtime==4.9.3 anyio==3.7.0 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arrow==1.2.3 asttokens==2.2.1 async-timeout==4.0.3 attrs==23.1.0 Babel==2.15.0 backcall==0.2.0 backoff==2.1.2 backports.tarfile==1.2.0 beaker-gantry==1.1.0 beaker-py==1.26.14 beautifulsoup4==4.12.2 black==23.12.1 bleach==6.0.0 blessed==1.20.0 blinker==1.8.2 boltons==24.0.0 boto3==1.34.96 boto3-extensions==0.23.0 botocore==1.34.96 brotlipy==0.7.0 build==1.2.1 cached_path==1.6.2 cachetools==5.3.3 certifi @ file:///croot/certifi_1671487769961/work/certifi cffi @ file:///croot/cffi_1670423208954/work charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work click==8.1.7 click-aliases==1.0.4 click-help-colors==0.9.1 cloudpickle==3.0.0 cmake==3.29.2 colorama==0.4.6 comm==0.1.3 conda==23.1.0 conda-content-trust @ file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work conda-package-handling @ file:///croot/conda-package-handling_1672865015732/work conda_package_streaming @ file:///croot/conda-package-streaming_1670508151586/work contourpy==1.2.1 cryptography @ file:///croot/cryptography_1673298753778/work cycler==0.12.1 datasets==2.7.1 dateparser==1.2.0 debugpy==1.6.7 decorator==5.1.1 defusedxml==0.7.1 dill==0.3.6 diskcache==5.6.3 distro==1.9.0 docker==6.1.3 docker-pycreds==0.4.0 docutils==0.20.1 evaluate==0.4.2 exceptiongroup==1.1.1 execnet==2.1.1 executing==1.2.0 face==20.1.1 fastapi==0.110.3 fastjsonschema==2.17.1 filelock==3.13.4 Flask==3.0.3 fonttools==4.51.0 fqdn==1.5.1 frozenlist==1.4.1 fsspec==2024.3.1 ftfy==6.2.0 furo==2023.5.20 gitdb==4.0.10 GitPython==3.1.31 glom==23.5.0 google-api-core==2.19.0 google-auth==2.30.0 google-cloud-core==2.4.1 google-cloud-storage==2.17.0 google-crc32c==1.5.0 google-resumable-media==2.7.1 googleapis-common-protos==1.63.1 gpustat==1.1 h11==0.14.0 halo==0.0.31 httpcore==1.0.5 httptools==0.6.1 httpx==0.27.0 huggingface-hub==0.21.4 idna @ file:///croot/idna_1666125576474/work imagesize==1.4.1 importlib_metadata==7.2.0 importlib_resources==6.4.0 iniconfig==2.0.0 interegular==0.3.3 ipykernel==6.23.2 ipython==8.14.0 ipython-genutils==0.2.0 ipywidgets==8.0.6 isodate==0.6.1 isoduration==20.11.0 isort==5.12.0 itsdangerous==2.2.0 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.1 jedi==0.18.2 jeepney==0.8.0 Jinja2==3.1.2 jmespath==1.0.1 joblib==1.4.0 jsonpointer==2.3 jsonschema==4.17.3 jupyter==1.0.0 jupyter-console==6.6.3 jupyter-events==0.6.3 jupyter_client==8.2.0 jupyter_core==5.3.1 jupyter_server==2.6.0 jupyter_server_terminals==0.4.4 jupyterlab-pygments==0.2.2 jupyterlab-widgets==3.0.7 keyring==25.2.1 kiwisolver==1.4.5 lark==1.1.9 lightning-utilities==0.11.3.post0 livereload==2.6.3 llvmlite==0.42.0 lm-format-enforcer==0.9.8 log-symbols==0.0.14 markdown-it-py==3.0.0 MarkupSafe==2.1.3 matplotlib==3.8.4 matplotlib-inline==0.1.6 maturin==1.5.1 mdit-py-plugins==0.4.1 mdurl==0.1.2 mistune==2.0.5 more-itertools==10.3.0 mpmath==1.3.0 msgpack==1.0.8 msgspec==0.18.6 multidict==6.0.5 multiprocess==0.70.14 mypy==1.3.0 mypy-extensions==1.0.0 myst-parser==2.0.0 nbclassic==1.0.0 nbclient==0.8.0 nbconvert==7.5.0 nbformat==5.9.0 necessary==0.4.3 nest-asyncio==1.5.6 networkx==3.3 nh3==0.2.17 ninja==1.11.1.1 notebook==6.5.4 
notebook_shim==0.2.3 numba==0.59.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-ml-py==11.525.112 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 omegaconf==2.3.0 oocmap==0.3 openai==1.29.0 outlines==0.0.34 overrides==7.3.1 packaging==23.1 pandas==2.2.2 pandocfilters==1.5.0 parso==0.8.3 pathspec==0.12.1 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==10.3.0 pkginfo==1.11.1 platformdirs==3.5.3 pluggy==1.5.0 prometheus-fastapi-instrumentator==7.0.0 prometheus_client==0.20.0 prompt-toolkit==3.0.38 proto-plus==1.23.0 protobuf==4.25.3 psutil==5.9.5 ptyprocess==0.7.0 pure-eval==0.2.2 py-cpuinfo==9.0.0 pyarrow==16.0.0 pyasn1==0.6.0 pyasn1_modules==0.4.0 pycosat @ file:///croot/pycosat_1666805502580/work pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work pydantic==2.7.1 pydantic_core==2.18.2 Pygments==2.15.1 pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work pyparsing==3.1.2 pyproject_hooks==1.1.0 pyrsistent==0.19.3 PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work pytest==8.2.2 pytest-sphinx==0.6.3 pytest-xdist==3.6.1 python-dateutil==2.8.2 python-dotenv==1.0.1 python-json-logger==2.0.7 pytz==2024.1 PyYAML==6.0 pyzmq==25.1.0 qtconsole==5.4.3 QtPy==2.3.1 ray==2.12.0 readme_renderer==43.0 referencing==0.35.0 regex==2024.4.28 requests @ file:///opt/conda/conda-bld/requests_1657734628632/work requests-toolbelt==1.0.0 requirements-parser==0.9.0 responses==0.18.0 rfc3339-validator==0.1.4 rfc3986==2.0.0 rfc3986-validator==0.1.1 rich==13.4.2 rpds-py==0.18.0 rsa==4.9 ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work ruff==0.4.10 rusty-dawg @ file:///home/willm/rusty-dawg/bindings/python/target/wheels/rusty_dawg-0.1.0-cp310-cp310-manylinux_2_31_x86_64.whl#sha256=b780f6524d32a76e9dfff4137425f799401afc86deed08bbe1c73ac4885e3baf s3transfer==0.10.1 safetensors==0.4.3 scikit-learn==1.5.1 scipy==1.13.0 seaborn==0.13.2 SecretStorage==3.3.3 Send2Trash==1.8.2 sentencepiece==0.2.0 sentry-sdk==2.7.1 setproctitle==1.3.3 six @ file:///tmp/build/80754af9/six_1644875935023/work smart-open==7.0.4 smashed==0.21.5 smmap==5.0.0 sniffio==1.3.0 snowballstemmer==2.2.0 soupsieve==2.4.1 Sphinx==7.0.1 sphinx-autobuild==2021.3.14 sphinx-autodoc-typehints==1.23.3 sphinx-basic-ng==1.0.0b2 sphinx-copybutton==0.5.2 sphinxcontrib-applehelp==1.0.8 sphinxcontrib-devhelp==1.0.6 sphinxcontrib-htmlhelp==2.0.5 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.7 sphinxcontrib-serializinghtml==1.1.10 spinners==0.0.24 stack-data==0.6.2 starlette==0.37.2 sympy==1.12 tabulate==0.9.0 termcolor==2.4.0 terminado==0.17.1 threadpoolctl==3.5.0 tiktoken==0.6.0 tinycss2==1.2.1 tokenizers==0.19.1 tomli==2.0.1 toolz @ file:///croot/toolz_1667464077321/work torch==2.1.2 torchmetrics==1.4.0.post0 tornado==6.3.2 tqdm @ file:///opt/conda/conda-bld/tqdm_1664392687731/work traitlets==5.9.0 transformers==4.40.1 triton==2.1.0 trouting==0.3.3 twine==5.1.0 typeguard==2.13.3 types-setuptools==70.1.0.20240627 typing_extensions==4.11.0 tzdata==2024.1 tzlocal==5.2 uri-template==1.2.0 urllib3 @ file:///croot/urllib3_1673575502006/work uvicorn==0.29.0 uvloop==0.19.0 vllm==0.4.2 vllm-nccl-cu12==2.18.1.0.4.0 wandb==0.17.4 watchfiles==0.21.0 
wcwidth==0.2.13 webcolors==1.13 webencodings==0.5.1 websocket-client==1.5.3 websockets==12.0 Werkzeug==3.0.3 widgetsnbextension==4.0.7 wrapt==1.16.0 xformers==0.0.26.post1 xxhash==3.4.1 yarl==1.9.4 zipp==3.19.2 zstandard @ file:///opt/conda/conda-bld/zstandard_1663827383994/work