diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..410bcd8 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,22 @@ +## What does this PR do? + + + +Fixes #\ + +## Before submitting + +- [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? +- [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? +- [ ] Did you list all the **breaking changes** introduced by this pull request? +- [ ] Did you **test your PR locally** with `pytest` command? +- [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? + +## Did you have fun? + +Make sure you had fun coding πŸ™ƒ diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 0000000..c66853c --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,15 @@ +coverage: + status: + # measures overall project coverage + project: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + # measures PR or single commit coverage + patch: + default: + threshold: 100% # how much decrease in coverage is needed to not consider success + + + # project: off + # patch: off diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..5a861fd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,16 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "daily" + ignore: + - dependency-name: "pytorch-lightning" + update-types: ["version-update:semver-patch"] + - dependency-name: "torchmetrics" + update-types: ["version-update:semver-patch"] diff --git a/.github/release-drafter.yml b/.github/release-drafter.yml new file mode 100644 index 0000000..59af159 --- /dev/null +++ b/.github/release-drafter.yml @@ -0,0 +1,44 @@ +name-template: "v$RESOLVED_VERSION" +tag-template: "v$RESOLVED_VERSION" + +categories: + - title: "πŸš€ Features" + labels: + - "feature" + - "enhancement" + - title: "πŸ› Bug Fixes" + labels: + - "fix" + - "bugfix" + - "bug" + - title: "🧹 Maintenance" + labels: + - "maintenance" + - "dependencies" + - "refactoring" + - "cosmetic" + - "chore" + - title: "πŸ“οΈ Documentation" + labels: + - "documentation" + - "docs" + +change-template: "- $TITLE @$AUTHOR (#$NUMBER)" +change-title-escapes: '\<*_&' # You can add # and @ to disable mentions + +version-resolver: + major: + labels: + - "major" + minor: + labels: + - "minor" + patch: + labels: + - "patch" + default: patch + +template: | + ## Changes + + $CHANGES diff --git a/.github/workflows/code-quality-main.yaml b/.github/workflows/code-quality-main.yaml new file mode 100644 index 0000000..88b7220 --- /dev/null +++ b/.github/workflows/code-quality-main.yaml @@ -0,0 +1,22 @@ +# Same as `code-quality-pr.yaml` but triggered on commit to main branch +# and runs on all files (instead of only the changed ones) + +name: Code Quality Main + +on: + push: + branches: [main] + +jobs: + code-quality: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python 
+ uses: actions/setup-python@v2 + + - name: Run pre-commits + uses: pre-commit/action@v2.0.3 diff --git a/.github/workflows/code-quality-pr.yaml b/.github/workflows/code-quality-pr.yaml new file mode 100644 index 0000000..e58df42 --- /dev/null +++ b/.github/workflows/code-quality-pr.yaml @@ -0,0 +1,36 @@ +# This workflow finds which files were changed, prints them, +# and runs `pre-commit` on those files. + +# Inspired by the sktime library: +# https://github.com/alan-turing-institute/sktime/blob/main/.github/workflows/test.yml + +name: Code Quality PR + +on: + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + code-quality: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + + - name: Find modified files + id: file_changes + uses: trilom/file-changes-action@v1.2.4 + with: + output: " " + + - name: List modified files + run: echo '${{ steps.file_changes.outputs.files}}' + + - name: Run pre-commits + uses: pre-commit/action@v2.0.3 + with: + extra_args: --files ${{ steps.file_changes.outputs.files}} diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml new file mode 100644 index 0000000..6a45e15 --- /dev/null +++ b/.github/workflows/release-drafter.yml @@ -0,0 +1,27 @@ +name: Release Drafter + +on: + push: + # branches to consider in the event; optional, defaults to all + branches: + - main + +permissions: + contents: read + +jobs: + update_release_draft: + permissions: + # write permission is required to create a github release + contents: write + # write permission is required for autolabeler + # otherwise, read permission is required at least + pull-requests: write + + runs-on: ubuntu-latest + + steps: + # Drafts your next Release notes as Pull Requests are merged into "master" + - uses: release-drafter/release-drafter@v5 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..e205ee5 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,139 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + branches: [main, "release/*", "dev"] + +jobs: + run_tests_ubuntu: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + python-version: ["3.8", "3.9", "3.10"] + + timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + pip install sh + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v + + run_tests_macos: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: ["macos-latest"] + python-version: ["3.8", "3.9", "3.10"] + + timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + pip install sh + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v + + run_tests_windows: + runs-on: ${{ matrix.os 
}} + + strategy: + fail-fast: false + matrix: + os: ["windows-latest"] + python-version: ["3.8", "3.9", "3.10"] + + timeout-minutes: 20 + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + + - name: List dependencies + run: | + python -m pip list + + - name: Run pytest + run: | + pytest -v + + # upload code coverage report + code-coverage: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Set up Python 3.10 + uses: actions/setup-python@v2 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + pip install pytest-cov[toml] + pip install sh + + - name: Run tests and collect coverage + run: pytest --cov src # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 diff --git a/.gitignore b/.gitignore index 20c8c90..1bab7de 100644 --- a/.gitignore +++ b/.gitignore @@ -154,9 +154,27 @@ dmypy.json # Cython debug symbols cython_debug/ -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +### VisualStudioCode +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +*.code-workspace +**/.vscode + +# Data & Models +*.h5 +*.tar +*.tar.gz + +# Lightning-Hydra-Template +configs/local/default.yaml +/data/ +/logs/ +src/data/outputs/ +.env +**/.DS_Store + +# Aim logging +.aim \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..ee45ce1 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,147 @@ +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + # list of supported hooks: https://pre-commit.com/hooks.html + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: check-executables-have-shebangs + - id: check-toml + - id: check-case-conflict + - id: check-added-large-files + + # python code formatting + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + args: [--line-length, "99"] + + # python import sorting + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files"] + + # python upgrading syntax to newer version + - repo: https://github.com/asottile/pyupgrade + rev: v3.3.1 + hooks: + - id: pyupgrade + args: [--py38-plus] + + # python docstring formatting + - repo: https://github.com/myint/docformatter + rev: v1.7.4 + hooks: + - id: docformatter + args: + [ + --in-place, + --wrap-summaries=99, + --wrap-descriptions=99, + --style=sphinx, + --black, + ] + + # python docstring coverage checking + - repo: https://github.com/econchick/interrogate + rev: 1.5.0 # or master if you're bold 
+ hooks: + - id: interrogate + args: + [ + --verbose, + --fail-under=80, + --ignore-init-module, + --ignore-init-method, + --ignore-module, + --ignore-nested-functions, + -vv, + ] + + # python check (PEP8), programming errors and code complexity + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: + [ + "--extend-ignore", + "E203,E402,E501,F401,F841,RST2,RST301", + "--exclude", + "logs/*,data/*", + ] + additional_dependencies: [flake8-rst-docstrings==0.3.0] + + # python security linter + - repo: https://github.com/PyCQA/bandit + rev: "1.7.5" + hooks: + - id: bandit + args: ["-s", "B101"] + + # yaml formatting + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v3.0.0-alpha.6 + hooks: + - id: prettier + types: [yaml] + exclude: "environment.yaml" + + # shell scripts linter + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.9.0.2 + hooks: + - id: shellcheck + + # md formatting + - repo: https://github.com/executablebooks/mdformat + rev: 0.7.16 + hooks: + - id: mdformat + args: ["--number"] + additional_dependencies: + - mdformat-gfm + - mdformat-tables + - mdformat_frontmatter + # - mdformat-toc + # - mdformat-black + + # word spelling linter + - repo: https://github.com/codespell-project/codespell + rev: v2.2.4 + hooks: + - id: codespell + args: + - --skip=logs/**,data/**,*.ipynb + # - --ignore-words-list=abc,def + + # jupyter notebook cell output clearing + - repo: https://github.com/kynan/nbstripout + rev: 0.6.1 + hooks: + - id: nbstripout + + # jupyter notebook linting + - repo: https://github.com/nbQA-dev/nbQA + rev: 1.6.3 + hooks: + - id: nbqa-black + args: ["--line-length=99"] + - id: nbqa-isort + args: ["--profile=black"] + - id: nbqa-flake8 + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] diff --git a/.project-root b/.project-root new file mode 100644 index 0000000..63eab77 --- /dev/null +++ b/.project-root @@ -0,0 +1,2 @@ +# this file is required for inferring the project root directory +# do not delete diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000..56bf7f4 --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1 @@ +# this file is needed here to include configs when building project as a package diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000..c9bf2fb --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,22 @@ +defaults: + - model_checkpoint + - early_stopping + - model_summary + - rich_progress_bar + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + monitor: "val/acc" + mode: "max" + save_last: True + auto_insert_metric_name: False + +early_stopping: + monitor: "val/acc" + patience: 100 + mode: "max" + +model_summary: + max_depth: -1 diff --git a/configs/callbacks/early_stopping.yaml b/configs/callbacks/early_stopping.yaml new file mode 100644 index 0000000..c826c8d --- /dev/null +++ b/configs/callbacks/early_stopping.yaml @@ -0,0 +1,15 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html + +early_stopping: + _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: ??? # quantity to be monitored, must be specified !!! + min_delta: 0. 
# minimum change in the monitored quantity to qualify as an improvement + patience: 3 # number of checks with no improvement after which training will be stopped + verbose: False # verbosity mode + mode: "min" # "max" means higher metric value is better, can be also "min" + strict: True # whether to crash the training if monitor is not found in the validation metrics + check_finite: True # when set True, stops training when the monitor becomes NaN or infinite + stopping_threshold: null # stop training immediately once the monitored quantity reaches this threshold + divergence_threshold: null # stop training as soon as the monitored quantity becomes worse than this threshold + check_on_train_epoch_end: null # whether to run early stopping at the end of the training epoch + # log_rank_zero_only: False # this keyword argument isn't available in stable version diff --git a/configs/callbacks/model_checkpoint.yaml b/configs/callbacks/model_checkpoint.yaml new file mode 100644 index 0000000..bf946e8 --- /dev/null +++ b/configs/callbacks/model_checkpoint.yaml @@ -0,0 +1,17 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html + +model_checkpoint: + _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: null # directory to save the model file + filename: null # checkpoint filename + monitor: null # name of the logged metric which determines when model is improving + verbose: False # verbosity mode + save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_top_k: 1 # save k best models (determined by above metric) + mode: "min" # "max" means higher metric value is better, can be also "min" + auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name + save_weights_only: False # if True, then only the model’s weights will be saved + every_n_train_steps: null # number of training steps between checkpoints + train_time_interval: null # checkpoints are monitored at the specified time interval + every_n_epochs: null # number of epochs between checkpoints + save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation diff --git a/configs/callbacks/model_summary.yaml b/configs/callbacks/model_summary.yaml new file mode 100644 index 0000000..b75981d --- /dev/null +++ b/configs/callbacks/model_summary.yaml @@ -0,0 +1,5 @@ +# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html + +model_summary: + _target_: lightning.pytorch.callbacks.RichModelSummary + max_depth: 1 # the maximum depth of layer nesting that the summary will include diff --git a/evaluation/__init__.py b/configs/callbacks/none.yaml similarity index 100% rename from evaluation/__init__.py rename to configs/callbacks/none.yaml diff --git a/configs/callbacks/rich_progress_bar.yaml b/configs/callbacks/rich_progress_bar.yaml new file mode 100644 index 0000000..de6f1cc --- /dev/null +++ b/configs/callbacks/rich_progress_bar.yaml @@ -0,0 +1,4 @@ +# https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html + +rich_progress_bar: + _target_: lightning.pytorch.callbacks.RichProgressBar diff --git a/configs/data/default.yaml b/configs/data/default.yaml new file mode 100644 index 0000000..fa28da1 --- /dev/null +++ b/configs/data/default.yaml @@ -0,0 +1,4 @@ +# @package _global_ + +seed: 42 +train_val_test_split: [0.8, 0.1, 0.1] \ No newline at end of file diff --git 
a/configs/data/mnist.yaml b/configs/data/mnist.yaml new file mode 100644 index 0000000..51bfaff --- /dev/null +++ b/configs/data/mnist.yaml @@ -0,0 +1,6 @@ +_target_: src.data.mnist_datamodule.MNISTDataModule +data_dir: ${paths.data_dir} +batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +train_val_test_split: [55_000, 5_000, 10_000] +num_workers: 0 +pin_memory: False diff --git a/configs/data/pvr.yaml b/configs/data/pvr.yaml new file mode 100644 index 0000000..e04a86c --- /dev/null +++ b/configs/data/pvr.yaml @@ -0,0 +1,11 @@ +_target_: src.data.pvr_datamodule.PVRDataModule +batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +train_val_test_split: [55_000, 5_000, 10_000] +num_workers: 0 +pin_memory: False + +seed: 0 +pointer_size: 1 +agg_func: "sum_mod_10" +window_size: 3 +trim_window: False \ No newline at end of file diff --git a/configs/data/read_activation_from_file.yaml b/configs/data/read_activation_from_file.yaml new file mode 100644 index 0000000..46f5cfe --- /dev/null +++ b/configs/data/read_activation_from_file.yaml @@ -0,0 +1,19 @@ +# # @package _global_ +# defaults: +# - _self_ +# - default + +_target_: src.data.read_activation_from_file_datamodule.ReadActivationFromFileDataModule + +# format: 'chunks' # 'chunks' or 'one_file' or other +# max_num_chunks: 1 # Number of chunks to read. If -1, read all chunks +format: 'one_file' +path: /dlabdata1/masani/symbolic_probing/data/activation_data/gelu-2l_blocks.1.hook_mlp_out + +batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +num_workers: 0 +pin_memory: False +seed: ${seed} +train_val_test_split: [0.8, 0.1, 0.1] + + diff --git a/configs/debug/default.yaml b/configs/debug/default.yaml new file mode 100644 index 0000000..1886902 --- /dev/null +++ b/configs/debug/default.yaml @@ -0,0 +1,35 @@ +# @package _global_ + +# default debugging setup, runs 1 full epoch +# other debugging configs can inherit from this one + +# overwrite task name so debugging logs are stored in separate folder +task_name: "debug" + +# disable callbacks and loggers during debugging +callbacks: null +logger: null + +extras: + ignore_warnings: False + enforce_tags: False + +# sets level of all command line loggers to 'DEBUG' +# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ +hydra: + job_logging: + root: + level: DEBUG + + # use this to also set hydra loggers to 'DEBUG' + # verbose: True + +trainer: + max_epochs: 1 + accelerator: cpu # debuggers don't like gpus + devices: 1 # debuggers don't like multiprocessing + detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor + +data: + num_workers: 0 # debuggers don't like multiprocessing + pin_memory: False # disable gpu memory pin diff --git a/configs/debug/fdr.yaml b/configs/debug/fdr.yaml new file mode 100644 index 0000000..7f2d34f --- /dev/null +++ b/configs/debug/fdr.yaml @@ -0,0 +1,9 @@ +# @package _global_ + +# runs 1 train, 1 validation and 1 test step + +defaults: + - default + +trainer: + fast_dev_run: true diff --git a/configs/debug/limit.yaml b/configs/debug/limit.yaml new file mode 100644 index 0000000..514d77f --- /dev/null +++ b/configs/debug/limit.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# uses only 1% of the training data and 5% of validation/test data + +defaults: + - default + +trainer: + max_epochs: 3 + limit_train_batches: 0.01 + limit_val_batches: 0.05 + limit_test_batches: 0.05 diff --git 
a/configs/debug/overfit.yaml b/configs/debug/overfit.yaml new file mode 100644 index 0000000..9906586 --- /dev/null +++ b/configs/debug/overfit.yaml @@ -0,0 +1,13 @@ +# @package _global_ + +# overfits to 3 batches + +defaults: + - default + +trainer: + max_epochs: 20 + overfit_batches: 3 + +# model ckpt and early stopping need to be disabled during overfitting +callbacks: null diff --git a/configs/debug/profiler.yaml b/configs/debug/profiler.yaml new file mode 100644 index 0000000..2bd7da8 --- /dev/null +++ b/configs/debug/profiler.yaml @@ -0,0 +1,12 @@ +# @package _global_ + +# runs with execution time profiling + +defaults: + - default + +trainer: + max_epochs: 1 + profiler: "simple" + # profiler: "advanced" + # profiler: "pytorch" diff --git a/configs/eval.yaml b/configs/eval.yaml new file mode 100644 index 0000000..be31299 --- /dev/null +++ b/configs/eval.yaml @@ -0,0 +1,18 @@ +# @package _global_ + +defaults: + - _self_ + - data: mnist # choose datamodule with `test_dataloader()` for evaluation + - model: mnist + - logger: null + - trainer: default + - paths: default + - extras: default + - hydra: default + +task_name: "eval" + +tags: ["dev"] + +# passing checkpoint path is necessary for evaluation +ckpt_path: ??? diff --git a/configs/experiment/example.yaml b/configs/experiment/example.yaml new file mode 100644 index 0000000..9a93b54 --- /dev/null +++ b/configs/experiment/example.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: mnist + - override /model: mnist + - override /callbacks: default + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["mnist", "simple_dense_net"] + +seed: 12345 + +trainer: + min_epochs: 10 + max_epochs: 10 + gradient_clip_val: 0.5 + +model: + optimizer: + lr: 0.002 + net: + lin1_size: 128 + lin2_size: 256 + lin3_size: 64 + compile: false + +data: + batch_size: 64 + +logger: + wandb: + tags: ${tags} + group: "mnist" + aim: + experiment: "mnist" diff --git a/configs/experiment/pvr.yaml b/configs/experiment/pvr.yaml new file mode 100644 index 0000000..6ccb541 --- /dev/null +++ b/configs/experiment/pvr.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: pvr + - override /model: transformer_dbn_classifier # transformer_dbn_classifier, gpt2 + - override /callbacks: default + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters +name: "pvr" +# run_name: "${model.key}-${model.discrete_layer.key}" +run_name: "dbn transformer" + +tags: [] + +seed: 42 + +trainer: + min_epochs: 200 + max_epochs: 200 + # gradient_clip_val: 0.5 # TODO: add gradient clipping + check_val_every_n_epoch: 1 + accelerator: "gpu" + +model: + optimizer: + lr: 0.00001 + + +data: + batch_size: 256 + +logger: + wandb: + tags: ${tags} + # group: "mnist" \ No newline at end of file diff --git a/configs/extras/default.yaml b/configs/extras/default.yaml new file mode 100644 index 0000000..3255045 --- /dev/null +++ b/configs/extras/default.yaml @@ -0,0 +1,8 @@ +# disable python warnings if they annoy you +ignore_warnings: False + +# ask user for tags if none are provided in the config +enforce_tags: False + +# pretty print config tree at the 
start of the run using Rich library +print_config: True diff --git a/configs/hparams_search/mnist_optuna.yaml b/configs/hparams_search/mnist_optuna.yaml new file mode 100644 index 0000000..1391183 --- /dev/null +++ b/configs/hparams_search/mnist_optuna.yaml @@ -0,0 +1,52 @@ +# @package _global_ + +# example hyperparameter optimization of some experiment with Optuna: +# python train.py -m hparams_search=mnist_optuna experiment=example + +defaults: + - override /hydra/sweeper: optuna + +# choose metric which will be optimized by Optuna +# make sure this is the correct name of some metric logged in lightning module! +optimized_metric: "val/acc_best" + +# here we define Optuna hyperparameter search +# it optimizes for value returned from function with @hydra.main decorator +# docs: https://hydra.cc/docs/next/plugins/optuna_sweeper +hydra: + mode: "MULTIRUN" # set hydra to multirun by default if this config is attached + + sweeper: + _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper + + # storage URL to persist optimization results + # for example, you can use SQLite if you set 'sqlite:///example.db' + storage: null + + # name of the study to persist optimization results + study_name: null + + # number of parallel workers + n_jobs: 1 + + # 'minimize' or 'maximize' the objective + direction: maximize + + # total number of runs that will be executed + n_trials: 20 + + # choose Optuna hyperparameter sampler + # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others + # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html + sampler: + _target_: optuna.samplers.TPESampler + seed: 1234 + n_startup_trials: 10 # number of random sampling runs before optimization starts + + # define hyperparameter search space + params: + model.optimizer.lr: interval(0.0001, 0.1) + data.batch_size: choice(32, 64, 128, 256) + model.net.lin1_size: choice(64, 128, 256) + model.net.lin2_size: choice(64, 128, 256) + model.net.lin3_size: choice(32, 64, 128, 256) diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000..a61e9b3 --- /dev/null +++ b/configs/hydra/default.yaml @@ -0,0 +1,19 @@ +# https://hydra.cc/docs/configure_hydra/intro/ + +# enable color logging +defaults: + - override hydra_logging: colorlog + - override job_logging: colorlog + +# output directory, generated dynamically on each run +run: + dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} +sweep: + dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.num} + +job_logging: + handlers: + file: + # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 + filename: ${hydra.runtime.output_dir}/${task_name}.log diff --git a/configs/local/.gitkeep b/configs/local/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/configs/logger/aim.yaml b/configs/logger/aim.yaml new file mode 100644 index 0000000..8f9f6ad --- /dev/null +++ b/configs/logger/aim.yaml @@ -0,0 +1,28 @@ +# https://aimstack.io/ + +# example usage in lightning module: +# https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py + +# open the Aim UI with the following command (run in the folder containing the `.aim` folder): +# `aim up` + +aim: + _target_: aim.pytorch_lightning.AimLogger + repo: ${paths.root_dir} # .aim folder will be created here + # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server 
which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# + + # aim allows to group runs under experiment name + experiment: null # any string, set to "default" if not specified + + train_metric_prefix: "train/" + val_metric_prefix: "val/" + test_metric_prefix: "test/" + + # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) + system_tracking_interval: 10 # set to null to disable system metrics tracking + + # enable/disable logging of system params such as installed packages, git info, env vars, etc. + log_system_params: true + + # enable/disable tracking console logs (default value is true) + capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml new file mode 100644 index 0000000..e078927 --- /dev/null +++ b/configs/logger/comet.yaml @@ -0,0 +1,12 @@ +# https://www.comet.ml + +comet: + _target_: lightning.pytorch.loggers.comet.CometLogger + api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable + save_dir: "${paths.output_dir}" + project_name: "lightning-hydra-template" + rest_api_key: null + # experiment_name: "" + experiment_key: null # set to resume experiment + offline: False + prefix: "" diff --git a/configs/logger/csv.yaml b/configs/logger/csv.yaml new file mode 100644 index 0000000..fa028e9 --- /dev/null +++ b/configs/logger/csv.yaml @@ -0,0 +1,7 @@ +# csv logger built in lightning + +csv: + _target_: lightning.pytorch.loggers.csv_logs.CSVLogger + save_dir: "${paths.output_dir}" + name: "csv/" + prefix: "" diff --git a/configs/logger/many_loggers.yaml b/configs/logger/many_loggers.yaml new file mode 100644 index 0000000..dd58680 --- /dev/null +++ b/configs/logger/many_loggers.yaml @@ -0,0 +1,9 @@ +# train with many loggers at once + +defaults: + # - comet + - csv + # - mlflow + # - neptune + - tensorboard + - wandb diff --git a/configs/logger/mlflow.yaml b/configs/logger/mlflow.yaml new file mode 100644 index 0000000..f8fb7e6 --- /dev/null +++ b/configs/logger/mlflow.yaml @@ -0,0 +1,12 @@ +# https://mlflow.org + +mlflow: + _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger + # experiment_name: "" + # run_name: "" + tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI + tags: null + # save_dir: "./mlruns" + prefix: "" + artifact_location: null + # run_id: "" diff --git a/configs/logger/neptune.yaml b/configs/logger/neptune.yaml new file mode 100644 index 0000000..8233c14 --- /dev/null +++ b/configs/logger/neptune.yaml @@ -0,0 +1,9 @@ +# https://neptune.ai + +neptune: + _target_: lightning.pytorch.loggers.neptune.NeptuneLogger + api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable + project: username/lightning-hydra-template + # name: "" + log_model_checkpoints: True + prefix: "" diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 0000000..2bd31f6 --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,10 @@ +# https://www.tensorflow.org/tensorboard/ + +tensorboard: + _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger + save_dir: "${paths.output_dir}/tensorboard/" + name: null + log_graph: False + default_hp_metric: True + prefix: "" + # version: "" diff --git a/configs/model/default.yaml b/configs/model/default.yaml new file mode 100644 index 0000000..e69de29 diff --git 
a/configs/model/encdec_enc.yaml b/configs/model/encdec_enc.yaml new file mode 100644 index 0000000..6df82d1 --- /dev/null +++ b/configs/model/encdec_enc.yaml @@ -0,0 +1,83 @@ +_target_: src.models.encdec_enc.EncDecEncModel +compile: True + +encdec_config: + _target_: transformers.BartConfig + vocab_size: 50265 + max_position_embeddings: 256 + encoder_layers: 4 + encoder_ffn_dim: 4096 + encoder_attention_heads: 4 + decoder_layers: 4 + decoder_ffn_dim: 4096 + decoder_attention_heads: 4 + d_model: 1024 + use_cache: True + +autoreg_wrapper_config: + use_past_key_values: False + use_last_step_states: True + max_lengths: + input: 30 + output: 30 + soft_average: + p_eos_backward: True + p_eos_forward: False + word_embeds_with_scores_forward: True + + +enc_config: + _target_: transformers.BertConfig + vocab_size: 20 + hidden_size: 1024 + num_hidden_layers: 4 + num_attention_heads: 4 + intermediate_size: 3072 + max_position_embeddings: 256 + type_vocab_size: 1 + use_cache: True + +probe_discretizer: + _target_: blocks.modules.discrete_bottleneck.softmax.SoftmaxDiscreteBottleneck +probe_discretizer_config: + dimensions: + decoder_embedding_dim: 1024 + vocab_size: ${...enc_config.vocab_size} + encoder_embedding_dim: 1024 + unembedding_dim: ${...enc_config.vocab_size} + quantize_vector: True + temperature: 1.0 + encoder_embedding_trainable: True + decoder_embedding_trainable: True + linear_head_trainable: True + +input_discretizer: + _target_: blocks.modules.discrete_bottleneck.abstract_discrete_layer.AbstractDiscreteLayer +input_discretizer_config: + dimensions: + decoder_embedding_dim: 1024 + vocab_size: ${...encdec_config.vocab_size} + encoder_embedding_dim: 1024 + unembedding_dim: ${...encdec_config.vocab_size} + quantize_vector: True + temperature: 1.0 + encoder_embedding_trainable: True + decoder_embedding_trainable: True + linear_head_trainable: True + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: 'min' + factor: 0.95 + patience: 10 + cooldown: 0 + +monitor: "val/loss" +input_tokenizer_name: "facebook/bart-base" diff --git a/configs/model/gpt2.yaml b/configs/model/gpt2.yaml new file mode 100644 index 0000000..4d213c9 --- /dev/null +++ b/configs/model/gpt2.yaml @@ -0,0 +1,63 @@ +_target_: src.models.transformer_dbn_classifier.TransformerDBNClassifier + +key: "gpt2_classifier" + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: 'min' + factor: 0.95 + patience: 10 + cooldown: 0 + +monitor: "val/loss" + +#################################################### + +# compile model for faster training with pytorch 2.0 +compile: False + +nn: + _target_: src.models.components.gpt2_classifier.GPT2Classifier + embedding_dim: 256 + output_dim: ${model.nn.num_embedding} + dbn_after_each_layer: False + dbn_last_layer: False + shared_embedding_dbn: True + num_embedding: 10 + seq_len: 11 # TODO: set this automatically based on the data config file or take it from some higher-level folder. + emb_dropout: 0.1 + depth: 6 + pool: 'mean' # 'mean' or 'cls' + supervision: False # TODO: move it below?
+ + + gpt2_config: + vocab_size: 10 + n_positions: 11 + n_embd: 256 + n_layer: 6 + n_head: 4 + output_hidden_states: True + + + discrete_layer: + _target_: src.models.components.discrete_layers.vqvae.VQVAEDiscreteLayer + key: 'vqvae' + temperature: 1.0 + label_smoothing_scale: 0.0 + dist_ord: 2 + vocab_size: ${model.nn.num_embedding} + dictionary_dim: ${model.nn.embedding_dim} + hard: True + projection_method: "layer norm" # "unit-sphere" "scale" "layer norm" or "None" + beta: 0.25 + + + diff --git a/configs/model/mnist.yaml b/configs/model/mnist.yaml new file mode 100644 index 0000000..6f9c2fa --- /dev/null +++ b/configs/model/mnist.yaml @@ -0,0 +1,25 @@ +_target_: src.models.mnist_module.MNISTLitModule + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: min + factor: 0.1 + patience: 10 + +net: + _target_: src.models.components.simple_dense_net.SimpleDenseNet + input_size: 784 + lin1_size: 64 + lin2_size: 128 + lin3_size: 64 + output_size: 10 + +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/configs/model/transformer_dbn_classifier.yaml b/configs/model/transformer_dbn_classifier.yaml new file mode 100644 index 0000000..f47e9ee --- /dev/null +++ b/configs/model/transformer_dbn_classifier.yaml @@ -0,0 +1,61 @@ +_target_: src.models.transformer_dbn_classifier.TransformerDBNClassifier + +key: "transformer_dbn_classifier" + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.001 + weight_decay: 0.0 + +scheduler: + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau + _partial_: true + mode: 'min' + factor: 0.95 + patience: 10 + cooldown: 0 + +monitor: "val/loss" + +#################################################### + +# compile model for faster training with pytorch 2.0 +compile: False + +nn: + _target_: src.models.components.transformer.TransformerDBN + embedding_dim: 256 + output_dim: ${model.nn.num_embedding} + dbn_after_each_layer: False + dbn_last_layer: True + shared_embedding_dbn: False + num_embedding: 10 + seq_len: 11 # TODO: set this automatically based on the data config file or take it from some higher-level folder. + emb_dropout: 0.1 + depth: 6 + pool: 'mean' # 'mean' or 'cls' + supervision: False # TODO: move it below? + loss_coeffs: + disc_loss: 0.001 + + discrete_layer: + _target_: src.models.components.discrete_layers.vqvae.VQVAEDiscreteLayer + key: 'vqvae' + temperature: 1.0 + label_smoothing_scale: 0.0 + dist_ord: 2 + vocab_size: ${model.nn.num_embedding} + dictionary_dim: ${model.nn.embedding_dim} + hard: True + projection_method: "layer norm" # "unit-sphere" "scale" "layer norm" or "None" + beta: 0.25 + + transformer_layer: + _target_: src.models.components.transformer.TransformerLayer + num_heads: 8 + dropout: ${model.nn.emb_dropout} + dim: ${model.nn.embedding_dim} + mlp_dim: ${model.nn.embedding_dim} + + diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml new file mode 100644 index 0000000..ec81db2 --- /dev/null +++ b/configs/paths/default.yaml @@ -0,0 +1,18 @@ +# path to root directory +# this requires PROJECT_ROOT environment variable to exist +# you can replace it with "."
if you want the root to be the current working directory +root_dir: ${oc.env:PROJECT_ROOT} + +# path to data directory +data_dir: ${paths.root_dir}/data/ + +# path to logging directory +log_dir: ${paths.root_dir}/logs/ + +# path to output directory, created dynamically by hydra +# path generation pattern is specified in `configs/hydra/default.yaml` +# use it to store all files generated during the run, like ckpts and metrics +output_dir: ${hydra:runtime.output_dir} + +# path to working directory +work_dir: ${hydra:runtime.cwd} diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000..a6caba5 --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,54 @@ +# @package _global_ + +# specify here default configuration +# order of defaults determines the order in which configs override each other +defaults: + - _self_ + - data: read_activation_from_file + - model: encdec_enc + - callbacks: default + - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) + - trainer: default + - paths: default + - extras: default + - hydra: default + + # experiment configs allow for version control of specific hyperparameters + # e.g. best hyperparameters for given model and datamodule + - experiment: null + + # config for hyperparameter optimization + - hparams_search: null + + # optional local config for machine/user specific settings + # it's optional since it doesn't need to exist and is excluded from version control + - optional local: default + + # debugging config (enable through command line, e.g. `python train.py debug=default) + - debug: null + +# determines the log directory's identifier +name: "symbolic_probing" # ??? +run_name: "debugging" # ??? + + +# task name, determines output directory path +task_name: "train" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +tags: ["dev"] + +# set False to skip model training +train: True + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: null + +# seed for random number generators in pytorch, numpy and python.random +seed: 42 diff --git a/configs/trainer/cpu.yaml b/configs/trainer/cpu.yaml new file mode 100644 index 0000000..b7d6767 --- /dev/null +++ b/configs/trainer/cpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: cpu +devices: 1 diff --git a/configs/trainer/ddp.yaml b/configs/trainer/ddp.yaml new file mode 100644 index 0000000..ab8f890 --- /dev/null +++ b/configs/trainer/ddp.yaml @@ -0,0 +1,9 @@ +defaults: + - default + +strategy: ddp + +accelerator: gpu +devices: 4 +num_nodes: 1 +sync_batchnorm: True diff --git a/configs/trainer/ddp_sim.yaml b/configs/trainer/ddp_sim.yaml new file mode 100644 index 0000000..8404419 --- /dev/null +++ b/configs/trainer/ddp_sim.yaml @@ -0,0 +1,7 @@ +defaults: + - default + +# simulate DDP on CPU, useful for debugging +accelerator: cpu +devices: 2 +strategy: ddp_spawn diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000..50905e7 --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,19 @@ +_target_: lightning.pytorch.trainer.Trainer + +default_root_dir: ${paths.output_dir} + +min_epochs: 1 # prevents early stopping +max_epochs: 10 + +accelerator: cpu +devices: 1 + +# 
mixed precision for extra speed-up +# precision: 16 + +# perform a validation loop every N training epochs +check_val_every_n_epoch: 1 + +# set True to ensure deterministic results +# makes training slower but gives more reproducibility than just setting seeds +deterministic: False diff --git a/configs/trainer/gpu.yaml b/configs/trainer/gpu.yaml new file mode 100644 index 0000000..b238951 --- /dev/null +++ b/configs/trainer/gpu.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: gpu +devices: 1 diff --git a/configs/trainer/mps.yaml b/configs/trainer/mps.yaml new file mode 100644 index 0000000..1ecf6d5 --- /dev/null +++ b/configs/trainer/mps.yaml @@ -0,0 +1,5 @@ +defaults: + - default + +accelerator: mps +devices: 1 diff --git a/result/mistral_v1_outputs_test.json b/outputs/result/mistral_v1_outputs_test.json similarity index 100% rename from result/mistral_v1_outputs_test.json rename to outputs/result/mistral_v1_outputs_test.json diff --git a/result/mistral_v1_outputs_test_pause_token.json b/outputs/result/mistral_v1_outputs_test_pause_token.json similarity index 100% rename from result/mistral_v1_outputs_test_pause_token.json rename to outputs/result/mistral_v1_outputs_test_pause_token.json diff --git a/result/mistral_v1_outputs_test_pause_token_random.json b/outputs/result/mistral_v1_outputs_test_pause_token_random.json similarity index 100% rename from result/mistral_v1_outputs_test_pause_token_random.json rename to outputs/result/mistral_v1_outputs_test_pause_token_random.json diff --git a/result/mistral_v1_outputs_train.json b/outputs/result/mistral_v1_outputs_train.json similarity index 100% rename from result/mistral_v1_outputs_train.json rename to outputs/result/mistral_v1_outputs_train.json diff --git a/result/output_mistral_instruct_v2_pause_token_sft_gsm8k.json b/outputs/result/output_mistral_instruct_v2_pause_token_sft_gsm8k.json similarity index 100% rename from result/output_mistral_instruct_v2_pause_token_sft_gsm8k.json rename to outputs/result/output_mistral_instruct_v2_pause_token_sft_gsm8k.json diff --git a/result/output_mistral_instruct_v2_sft_gsm8k.json b/outputs/result/output_mistral_instruct_v2_sft_gsm8k.json similarity index 100% rename from result/output_mistral_instruct_v2_sft_gsm8k.json rename to outputs/result/output_mistral_instruct_v2_sft_gsm8k.json diff --git a/result/output_mistral_instruct_v2_sft_gsm8k_updated_embedding_layer.json b/outputs/result/output_mistral_instruct_v2_sft_gsm8k_updated_embedding_layer.json similarity index 100% rename from result/output_mistral_instruct_v2_sft_gsm8k_updated_embedding_layer.json rename to outputs/result/output_mistral_instruct_v2_sft_gsm8k_updated_embedding_layer.json diff --git a/result/rc_mistral_correctness_reward_outer_loop_1.json b/outputs/result/rc_mistral_correctness_reward_outer_loop_1.json similarity index 100% rename from result/rc_mistral_correctness_reward_outer_loop_1.json rename to outputs/result/rc_mistral_correctness_reward_outer_loop_1.json diff --git a/result/rc_mistral_correctness_reward_outer_loop_2.json b/outputs/result/rc_mistral_correctness_reward_outer_loop_2.json similarity index 100% rename from result/rc_mistral_correctness_reward_outer_loop_2.json rename to outputs/result/rc_mistral_correctness_reward_outer_loop_2.json diff --git a/result/rc_mistral_delta_reward_outer_loop_1.json b/outputs/result/rc_mistral_delta_reward_outer_loop_1.json similarity index 100% rename from result/rc_mistral_delta_reward_outer_loop_1.json rename to
outputs/result/rc_mistral_delta_reward_outer_loop_1.json diff --git a/result/rc_mistral_delta_reward_outer_loop_2.json b/outputs/result/rc_mistral_delta_reward_outer_loop_2.json similarity index 100% rename from result/rc_mistral_delta_reward_outer_loop_2.json rename to outputs/result/rc_mistral_delta_reward_outer_loop_2.json diff --git a/result/rc_mistral_outer_loop_0.json b/outputs/result/rc_mistral_outer_loop_0.json similarity index 100% rename from result/rc_mistral_outer_loop_0.json rename to outputs/result/rc_mistral_outer_loop_0.json diff --git a/pip_requirements.txt b/pip_requirements.txt index be3b9b7..42fd9df 100644 --- a/pip_requirements.txt +++ b/pip_requirements.txt @@ -1,10 +1,26 @@ +# --------- transformers --------- # accelerate==0.28.0 datasets==2.18.0 transformers==4.39.3 -hydra-core==1.3.2 trl==0.8.1 peft==0.10.0 sentencepiece==0.2.0 scipy==1.13.0 protobuf==5.26.1 matplotlib==3.8.4 + +# --------- hydra --------- # +hydra-core==1.3.2 +hydra-colorlog==1.2.0 +hydra-optuna-sweeper==1.2.0 + +# --------- loggers --------- # +wandb + +# --------- others --------- # +rootutils # standardizing the project root setup +pre-commit # hooks for applying linters on commit +rich # beautiful text formatting in terminal +pytest # tests +einops +sh # for running bash commands in some tests (linux/macos only) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index f0d91b3..0000000 --- a/requirements.txt +++ /dev/null @@ -1,137 +0,0 @@ -accelerate==0.22.0 -aiohttp==3.8.5 -aiosignal==1.3.1 -annotated-types==0.5.0 -antlr4-python3-runtime==4.9.3 -appdirs==1.4.4 -asttokens==2.2.1 -async-timeout==4.0.3 -attrs==23.1.0 -backcall==0.2.0 -beautifulsoup4==4.12.2 -blis==0.7.10 -boto3==1.28.39 -botocore==1.31.39 -bs4==0.0.1 -catalogue==2.0.9 -certifi==2023.7.22 -charset-normalizer==3.2.0 -click==8.1.7 -cmake==3.27.2 -confection==0.1.1 -contourpy==1.1.0 -cycler==0.11.0 -cymem==2.0.7 -datasets==2.15.0 -decorator==5.1.1 -dill==0.3.7 -docker-pycreds==0.4.0 -en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.6.0/en_core_web_sm-3.6.0-py3-none-any.whl#sha256=83276fc78a70045627144786b52e1f2728ad5e29e5e43916ec37ea9c26a11212 -executing==1.2.0 -filelock==3.12.3 -fonttools==4.42.1 -frozenlist==1.4.0 -fsspec==2023.6.0 -gitdb==4.0.10 -GitPython==3.1.33 -huggingface-hub==0.19.4 -Hydra==2.5 -hydra-core==1.3.2 -idna==3.4 -importlib-resources==6.0.1 -ipdb==0.13.13 -ipython==8.12.2 -jedi==0.19.0 -Jinja2==3.1.2 -jmespath==1.0.1 -joblib==1.3.2 -kiwisolver==1.4.5 -langcodes==3.3.0 -lit==16.0.6 -MarkupSafe==2.1.3 -matplotlib==3.7.2 -matplotlib-inline==0.1.6 -mpmath==1.3.0 -multidict==6.0.4 -multiprocess==0.70.15 -murmurhash==1.0.9 -networkx==3.1 -numpy==1.24.4 -nvidia-cublas-cu11==11.10.3.66 -nvidia-cuda-cupti-cu11==11.7.101 -nvidia-cuda-nvrtc-cu11==11.7.99 -nvidia-cuda-runtime-cu11==11.7.99 -nvidia-cudnn-cu11==8.5.0.96 -nvidia-cufft-cu11==10.9.0.58 -nvidia-curand-cu11==10.2.10.91 -nvidia-cusolver-cu11==11.4.0.1 -nvidia-cusparse-cu11==11.7.4.91 -nvidia-nccl-cu11==2.14.3 -nvidia-nvtx-cu11==11.7.91 -omegaconf==2.3.0 -openai==0.28.0 -packaging==23.1 -pandas==2.0.3 -parso==0.8.3 -pathtools==0.1.2 -pathy==0.10.2 -peft==0.5.0 -pexpect==4.8.0 -pickleshare==0.7.5 -Pillow==10.0.0 -preshed==3.0.8 -prompt-toolkit==3.0.39 -protobuf==4.24.2 -psutil==5.9.5 -ptyprocess==0.7.0 -pure-eval==0.2.2 -pyarrow==14.0.1 -pyarrow-hotfix==0.5 -pydantic==2.3.0 -pydantic_core==2.6.3 -Pygments==2.16.1 -pyparsing==3.0.9 -python-dateutil==2.8.2 -pytz==2023.3 -PyYAML==6.0.1 -regex==2023.8.8 
-requests==2.31.0 -s3transfer==0.6.2 -safetensors==0.3.3 -scikit-learn==1.3.0 -scipy==1.10.1 -seaborn==0.12.2 -sentencepiece==0.1.99 -sentry-sdk==1.30.0 -setproctitle==1.3.2 -six==1.16.0 -smart-open==6.3.0 -smmap==5.0.0 -soupsieve==2.5 -spacy==3.6.1 -spacy-legacy==3.0.12 -spacy-loggers==1.0.4 -srsly==2.4.7 -stack-data==0.6.2 -sympy==1.12 -tensor-parallel==2.0.0 -thinc==8.1.12 -threadpoolctl==3.2.0 -tiktoken==0.6.0 -tokenizers==0.13.3 -tomli==2.0.1 -torch==2.0.1 -tqdm==4.66.1 -traitlets==5.9.0 -transformers==4.32.1 -triton==2.0.0 -typer==0.9.0 -typing_extensions==4.7.1 -tzdata==2023.3 -urllib3==1.26.16 -wandb==0.15.9 -wasabi==1.1.2 -wcwidth==0.2.6 -xxhash==3.4.1 -yarl==1.9.2 -zipp==3.16.2 diff --git a/scripts/schedule.sh b/scripts/schedule.sh new file mode 100644 index 0000000..44b3da1 --- /dev/null +++ b/scripts/schedule.sh @@ -0,0 +1,7 @@ +#!/bin/bash +# Schedule execution of many runs +# Run from root folder with: bash scripts/schedule.sh + +python src/train.py trainer.max_epochs=5 logger=csv + +python src/train.py trainer.max_epochs=10 logger=csv diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/components/__init__.py b/src/data/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data_generation/gsm8k_pause_injector.py b/src/data/data_generation/gsm8k_pause_injector.py similarity index 100% rename from data_generation/gsm8k_pause_injector.py rename to src/data/data_generation/gsm8k_pause_injector.py diff --git a/gsm8k/dataset_dict.json b/src/data/gsm8k/dataset_dict.json similarity index 100% rename from gsm8k/dataset_dict.json rename to src/data/gsm8k/dataset_dict.json diff --git a/gsm8k/test/cache-196eeac2572b05e9.arrow b/src/data/gsm8k/test/cache-196eeac2572b05e9.arrow similarity index 100% rename from gsm8k/test/cache-196eeac2572b05e9.arrow rename to src/data/gsm8k/test/cache-196eeac2572b05e9.arrow diff --git a/gsm8k/test/cache-1a9eda85e7e31587.arrow b/src/data/gsm8k/test/cache-1a9eda85e7e31587.arrow similarity index 100% rename from gsm8k/test/cache-1a9eda85e7e31587.arrow rename to src/data/gsm8k/test/cache-1a9eda85e7e31587.arrow diff --git a/gsm8k/test/cache-1af4c0d8a167eeea.arrow b/src/data/gsm8k/test/cache-1af4c0d8a167eeea.arrow similarity index 100% rename from gsm8k/test/cache-1af4c0d8a167eeea.arrow rename to src/data/gsm8k/test/cache-1af4c0d8a167eeea.arrow diff --git a/gsm8k/test/cache-2da3ebc9746ca436.arrow b/src/data/gsm8k/test/cache-2da3ebc9746ca436.arrow similarity index 100% rename from gsm8k/test/cache-2da3ebc9746ca436.arrow rename to src/data/gsm8k/test/cache-2da3ebc9746ca436.arrow diff --git a/gsm8k/test/cache-366cf927ac943456.arrow b/src/data/gsm8k/test/cache-366cf927ac943456.arrow similarity index 100% rename from gsm8k/test/cache-366cf927ac943456.arrow rename to src/data/gsm8k/test/cache-366cf927ac943456.arrow diff --git a/gsm8k/test/cache-39c3af65a4471c45.arrow b/src/data/gsm8k/test/cache-39c3af65a4471c45.arrow similarity index 100% rename from gsm8k/test/cache-39c3af65a4471c45.arrow rename to src/data/gsm8k/test/cache-39c3af65a4471c45.arrow diff --git a/gsm8k/test/cache-45df51f38b4c4c3a.arrow b/src/data/gsm8k/test/cache-45df51f38b4c4c3a.arrow similarity index 100% rename from gsm8k/test/cache-45df51f38b4c4c3a.arrow rename to src/data/gsm8k/test/cache-45df51f38b4c4c3a.arrow diff --git a/gsm8k/test/cache-47dfc58cb94fa277.arrow b/src/data/gsm8k/test/cache-47dfc58cb94fa277.arrow 
similarity index 100% rename from gsm8k/test/cache-47dfc58cb94fa277.arrow rename to src/data/gsm8k/test/cache-47dfc58cb94fa277.arrow diff --git a/gsm8k/test/cache-7cf8cfef47ebbe5c.arrow b/src/data/gsm8k/test/cache-7cf8cfef47ebbe5c.arrow similarity index 100% rename from gsm8k/test/cache-7cf8cfef47ebbe5c.arrow rename to src/data/gsm8k/test/cache-7cf8cfef47ebbe5c.arrow diff --git a/gsm8k/test/cache-9d13276854e38e73.arrow b/src/data/gsm8k/test/cache-9d13276854e38e73.arrow similarity index 100% rename from gsm8k/test/cache-9d13276854e38e73.arrow rename to src/data/gsm8k/test/cache-9d13276854e38e73.arrow diff --git a/gsm8k/test/cache-b1882bbcd73dee91.arrow b/src/data/gsm8k/test/cache-b1882bbcd73dee91.arrow similarity index 100% rename from gsm8k/test/cache-b1882bbcd73dee91.arrow rename to src/data/gsm8k/test/cache-b1882bbcd73dee91.arrow diff --git a/gsm8k/test/cache-eea01a3da67fbad0.arrow b/src/data/gsm8k/test/cache-eea01a3da67fbad0.arrow similarity index 100% rename from gsm8k/test/cache-eea01a3da67fbad0.arrow rename to src/data/gsm8k/test/cache-eea01a3da67fbad0.arrow diff --git a/gsm8k/test/cache-eecd7d4611a6a655.arrow b/src/data/gsm8k/test/cache-eecd7d4611a6a655.arrow similarity index 100% rename from gsm8k/test/cache-eecd7d4611a6a655.arrow rename to src/data/gsm8k/test/cache-eecd7d4611a6a655.arrow diff --git a/gsm8k/test/data-00000-of-00001.arrow b/src/data/gsm8k/test/data-00000-of-00001.arrow similarity index 100% rename from gsm8k/test/data-00000-of-00001.arrow rename to src/data/gsm8k/test/data-00000-of-00001.arrow diff --git a/gsm8k/test/dataset_info.json b/src/data/gsm8k/test/dataset_info.json similarity index 100% rename from gsm8k/test/dataset_info.json rename to src/data/gsm8k/test/dataset_info.json diff --git a/gsm8k/test/state.json b/src/data/gsm8k/test/state.json similarity index 100% rename from gsm8k/test/state.json rename to src/data/gsm8k/test/state.json diff --git a/gsm8k/train/cache-081b7fc154fc7034.arrow b/src/data/gsm8k/train/cache-081b7fc154fc7034.arrow similarity index 100% rename from gsm8k/train/cache-081b7fc154fc7034.arrow rename to src/data/gsm8k/train/cache-081b7fc154fc7034.arrow diff --git a/gsm8k/train/cache-2ac2a5d0e8f825ee.arrow b/src/data/gsm8k/train/cache-2ac2a5d0e8f825ee.arrow similarity index 100% rename from gsm8k/train/cache-2ac2a5d0e8f825ee.arrow rename to src/data/gsm8k/train/cache-2ac2a5d0e8f825ee.arrow diff --git a/gsm8k/train/cache-39945ff6b579263f.arrow b/src/data/gsm8k/train/cache-39945ff6b579263f.arrow similarity index 100% rename from gsm8k/train/cache-39945ff6b579263f.arrow rename to src/data/gsm8k/train/cache-39945ff6b579263f.arrow diff --git a/gsm8k/train/cache-3e71712acde2770d.arrow b/src/data/gsm8k/train/cache-3e71712acde2770d.arrow similarity index 100% rename from gsm8k/train/cache-3e71712acde2770d.arrow rename to src/data/gsm8k/train/cache-3e71712acde2770d.arrow diff --git a/gsm8k/train/cache-56f5e92911f3bf68.arrow b/src/data/gsm8k/train/cache-56f5e92911f3bf68.arrow similarity index 100% rename from gsm8k/train/cache-56f5e92911f3bf68.arrow rename to src/data/gsm8k/train/cache-56f5e92911f3bf68.arrow diff --git a/gsm8k/train/cache-9c02a9e3a84de81c.arrow b/src/data/gsm8k/train/cache-9c02a9e3a84de81c.arrow similarity index 100% rename from gsm8k/train/cache-9c02a9e3a84de81c.arrow rename to src/data/gsm8k/train/cache-9c02a9e3a84de81c.arrow diff --git a/gsm8k/train/cache-c16a136ee2762bbb.arrow b/src/data/gsm8k/train/cache-c16a136ee2762bbb.arrow similarity index 100% rename from gsm8k/train/cache-c16a136ee2762bbb.arrow rename to 
src/data/gsm8k/train/cache-c16a136ee2762bbb.arrow diff --git a/gsm8k/train/cache-c26f1f258fd66ee1.arrow b/src/data/gsm8k/train/cache-c26f1f258fd66ee1.arrow similarity index 100% rename from gsm8k/train/cache-c26f1f258fd66ee1.arrow rename to src/data/gsm8k/train/cache-c26f1f258fd66ee1.arrow diff --git a/gsm8k/train/cache-e06c1b40f0909757.arrow b/src/data/gsm8k/train/cache-e06c1b40f0909757.arrow similarity index 100% rename from gsm8k/train/cache-e06c1b40f0909757.arrow rename to src/data/gsm8k/train/cache-e06c1b40f0909757.arrow diff --git a/gsm8k/train/cache-e1ac74e31ad7226f.arrow b/src/data/gsm8k/train/cache-e1ac74e31ad7226f.arrow similarity index 100% rename from gsm8k/train/cache-e1ac74e31ad7226f.arrow rename to src/data/gsm8k/train/cache-e1ac74e31ad7226f.arrow diff --git a/gsm8k/train/cache-e20838f408cfedf3.arrow b/src/data/gsm8k/train/cache-e20838f408cfedf3.arrow similarity index 100% rename from gsm8k/train/cache-e20838f408cfedf3.arrow rename to src/data/gsm8k/train/cache-e20838f408cfedf3.arrow diff --git a/gsm8k/train/cache-f4bd95e6f6e73bbe.arrow b/src/data/gsm8k/train/cache-f4bd95e6f6e73bbe.arrow similarity index 100% rename from gsm8k/train/cache-f4bd95e6f6e73bbe.arrow rename to src/data/gsm8k/train/cache-f4bd95e6f6e73bbe.arrow diff --git a/gsm8k/train/cache-fd7490825757a073.arrow b/src/data/gsm8k/train/cache-fd7490825757a073.arrow similarity index 100% rename from gsm8k/train/cache-fd7490825757a073.arrow rename to src/data/gsm8k/train/cache-fd7490825757a073.arrow diff --git a/gsm8k/train/data-00000-of-00001.arrow b/src/data/gsm8k/train/data-00000-of-00001.arrow similarity index 100% rename from gsm8k/train/data-00000-of-00001.arrow rename to src/data/gsm8k/train/data-00000-of-00001.arrow diff --git a/gsm8k/train/dataset_info.json b/src/data/gsm8k/train/dataset_info.json similarity index 100% rename from gsm8k/train/dataset_info.json rename to src/data/gsm8k/train/dataset_info.json diff --git a/gsm8k/train/state.json b/src/data/gsm8k/train/state.json similarity index 100% rename from gsm8k/train/state.json rename to src/data/gsm8k/train/state.json diff --git a/data/gsm8k_json/gsm8k/test.json b/src/data/gsm8k_json/gsm8k/test.json similarity index 100% rename from data/gsm8k_json/gsm8k/test.json rename to src/data/gsm8k_json/gsm8k/test.json diff --git a/data/gsm8k_json/gsm8k/train.json b/src/data/gsm8k_json/gsm8k/train.json similarity index 100% rename from data/gsm8k_json/gsm8k/train.json rename to src/data/gsm8k_json/gsm8k/train.json diff --git a/data/gsm8k_json/gsm8k_pause_token/test.json b/src/data/gsm8k_json/gsm8k_pause_token/test.json similarity index 100% rename from data/gsm8k_json/gsm8k_pause_token/test.json rename to src/data/gsm8k_json/gsm8k_pause_token/test.json diff --git a/data/gsm8k_json/gsm8k_pause_token/train.json b/src/data/gsm8k_json/gsm8k_pause_token/train.json similarity index 100% rename from data/gsm8k_json/gsm8k_pause_token/train.json rename to src/data/gsm8k_json/gsm8k_pause_token/train.json diff --git a/data/gsm8k_json/gsm8k_pause_token_random/test.json b/src/data/gsm8k_json/gsm8k_pause_token_random/test.json similarity index 100% rename from data/gsm8k_json/gsm8k_pause_token_random/test.json rename to src/data/gsm8k_json/gsm8k_pause_token_random/test.json diff --git a/data/gsm8k_json/gsm8k_pause_token_random/train.json b/src/data/gsm8k_json/gsm8k_pause_token_random/train.json similarity index 100% rename from data/gsm8k_json/gsm8k_pause_token_random/train.json rename to src/data/gsm8k_json/gsm8k_pause_token_random/train.json diff --git 
a/data/gsm8k_jsonl/gsm8k/test.json b/src/data/gsm8k_jsonl/gsm8k/test.json similarity index 100% rename from data/gsm8k_jsonl/gsm8k/test.json rename to src/data/gsm8k_jsonl/gsm8k/test.json diff --git a/data/gsm8k_jsonl/gsm8k/train.json b/src/data/gsm8k_jsonl/gsm8k/train.json similarity index 100% rename from data/gsm8k_jsonl/gsm8k/train.json rename to src/data/gsm8k_jsonl/gsm8k/train.json diff --git a/data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/test.json b/src/data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/test.json similarity index 100% rename from data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/test.json rename to src/data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/test.json diff --git a/data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/train.json b/src/data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/train.json similarity index 100% rename from data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/train.json rename to src/data/gsm8k_jsonl/gsm8k_10_random_pause_injected_mistral/train.json diff --git a/data/gsm8k_jsonl/gsm8k_pause_injected/test.json b/src/data/gsm8k_jsonl/gsm8k_pause_injected/test.json similarity index 100% rename from data/gsm8k_jsonl/gsm8k_pause_injected/test.json rename to src/data/gsm8k_jsonl/gsm8k_pause_injected/test.json diff --git a/data/gsm8k_jsonl/gsm8k_pause_injected/train.json b/src/data/gsm8k_jsonl/gsm8k_pause_injected/train.json similarity index 100% rename from data/gsm8k_jsonl/gsm8k_pause_injected/train.json rename to src/data/gsm8k_jsonl/gsm8k_pause_injected/train.json diff --git a/data/gsm8k_random_pauses_5_samples_per_dp/train.json b/src/data/gsm8k_random_pauses_5_samples_per_dp/train.json similarity index 100% rename from data/gsm8k_random_pauses_5_samples_per_dp/train.json rename to src/data/gsm8k_random_pauses_5_samples_per_dp/train.json diff --git a/src/data/read_activation_from_file_datamodule.py b/src/data/read_activation_from_file_datamodule.py new file mode 100644 index 0000000..0999786 --- /dev/null +++ b/src/data/read_activation_from_file_datamodule.py @@ -0,0 +1,224 @@
+from typing import Any, Dict, Optional, Tuple
+
+import hydra
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, random_split
+from datasets import Dataset, DatasetDict, concatenate_datasets
+from tqdm import tqdm
+import os
+from omegaconf import DictConfig
+
+
+class ReadActivationFromFileDataModule(LightningDataModule):
+    """`LightningDataModule` that loads precomputed transformer activations from disk."""
+
+    def __init__(self, **kwargs) -> None:
+        # Expected kwargs (stored as hparams): path, format ("chunks" or "one_file"), batch_size,
+        # num_workers, pin_memory, train_val_test_split, seed and max_num_chunks.
+        super().__init__()
+        self.save_hyperparameters(logger=False)
+
+        self.data_train: Optional[Dataset] = None
+        self.data_val: Optional[Dataset] = None
+        self.data_test: Optional[Dataset] = None
+
+        self.batch_size_per_device = kwargs['batch_size']
+
+    def prepare_data(self) -> None:
+        """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+        within a single process on CPU, so you can safely add your downloading logic within. In
+        case of multi-node training, the execution of this hook depends upon
+        `self.prepare_data_per_node()`.
+
+        Do not use it to assign state (self.x = y).
+        """
+        pass
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Load data.
Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and + `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after + `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to + `self.setup()` once the data is prepared and available for use. + + :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. + """ + # Divide batch size by the number of devices. This is necessary for multi-GPU training. + if self.trainer is not None: + if self.hparams.batch_size % self.trainer.world_size != 0: + raise RuntimeError( + f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." + ) + self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size + + # load and split datasets only if not loaded already + if not self.data_train and not self.data_val and not self.data_test: + generator = torch.Generator().manual_seed(self.hparams.seed) + dataset = self.__load_dataset(self.hparams.path) + if len(dataset.keys()) < 3: + dataset = dataset[list(dataset.keys())[0]] + self.data_train, self.data_val, self.data_test = random_split( + dataset=dataset, + lengths=self.hparams.train_val_test_split, + generator=generator, + ) + else: + self.data_train = dataset['train'] + self.data_val = dataset['validation'] + self.data_test = dataset['test'] + + + def __load_dataset(self, data_dir: str) -> Dataset: + """Load the dataset from the data directory. + + :param data_dir: The directory containing the dataset. + :return: The dataset. + """ + # Load the dataset from the data directory + if self.hparams.format=='chunks': + dataset = self.__load_dataset_from_chunks(data_dir) + elif self.hparams.format=='one_file': + dataset = self.__load_dataset_from_one_file(data_dir) + else: + raise ValueError(f"Invalid format: {self.hparams.format}") + return dataset + + + def __load_dataset_from_chunks(self, data_dir: str) -> Dataset: + """Load the dataset parts from the data directory. + merge them into a file. 
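+
+        Chunks are expected under `<data_dir>/<split>/chunk_<i>/`, each written with
+        `datasets.Dataset.save_to_disk` (e.g. by `GenerateActivationData.save_dataset_chunk`);
+        they are concatenated per split and the merged result is cached under `<data_dir>/final`.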
+ """ + # Load and concatenate all chunks for each split + # making a final dataset with all the chunks and splits + # removing the final directory if it exists + final_path = os.path.join(data_dir, "final") + if os.path.exists(final_path): + os.system(f"rm -r {final_path}") + final_dataset = DatasetDict() + num_chunks = len(os.listdir(data_dir)) + for split in os.listdir(data_dir): + data_split_path = os.path.join(data_dir, split) + num_chunks = len(os.listdir(data_split_path)) + # for chunk_num in range(num_chunks): + # chunk_path = os.path.join(data_split_path, f"chunk_{chunk_num}") + dir_list = os.listdir(data_split_path) + dir_list.sort() + chunk_num = 0 + for chunk_name in tqdm(dir_list, desc=f"Loading {split}"): + chunk_path = os.path.join(data_split_path, chunk_name) + if chunk_num == 0: + loaded_dset = Dataset.load_from_disk(chunk_path) + chunk_num += 1 + else: + if os.path.exists(chunk_path+"dataset_info.json"): + loaded_dset = concatenate_datasets([loaded_dset, Dataset.load_from_disk(chunk_path)]) + chunk_num += 1 + else: + print(f"Skipping {chunk_path}, file broken") + if chunk_num == self.hparams.max_num_chunks: + break + final_dataset[split] = loaded_dset + + # Optionally, save the concatenated final datasets to disk + final_dataset.save_to_disk(final_path) + + return final_dataset + + def __load_dataset_from_one_file(self, data_dir: str) -> Dataset: + """Load the dataset from a single file. + """ + try: + dataset = Dataset.load_from_disk(data_dir + "/final") + except: + dataset = DatasetDict.load_from_disk(data_dir + "/final") + return dataset + + def train_dataloader(self) -> DataLoader[Any]: + """Create and return the train dataloader. + + :return: The train dataloader. + """ + return DataLoader( + dataset=self.data_train, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Create and return the validation dataloader. + + :return: The validation dataloader. + """ + return DataLoader( + dataset=self.data_val, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Create and return the test dataloader. + + :return: The test dataloader. + """ + return DataLoader( + dataset=self.data_test, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def teardown(self, stage: Optional[str] = None) -> None: + """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, + `trainer.test()`, and `trainer.predict()`. + + :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + Defaults to ``None``. + """ + pass + + def state_dict(self) -> Dict[Any, Any]: + # TODO: Implement + """Called when saving a checkpoint. Implement to generate and save the datamodule state. + + :return: A dictionary containing the datamodule state that you want to save. + """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + # TODO: Implement + """Called when loading a checkpoint. Implement to reload datamodule state given datamodule + `state_dict()`. + + :param state_dict: The datamodule state returned by `self.state_dict()`. 
+ """ + pass + + + + + +@hydra.main(version_base="1.3", config_path="../../configs/data", config_name="read_activation_from_file.yaml") +def main(cfg: DictConfig) -> Optional[float]: + data_module = ReadActivationFromFileDataModule(**cfg) + data_module.setup() + dl = data_module.train_dataloader() + for x in dl: + print(x) + break + +if __name__ == "__main__": + main() diff --git a/src/data/transformer_activation_datamodule.py b/src/data/transformer_activation_datamodule.py new file mode 100644 index 0000000..6666a53 --- /dev/null +++ b/src/data/transformer_activation_datamodule.py @@ -0,0 +1,771 @@ +# from typing import Any, Dict, Optional, Tuple + +# import multiprocessing +# import torch +# from lightning import LightningDataModule +# from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split +# from torchvision.transforms import transforms +# import os +# from typing import Any, Iterator, cast + +# import torch +# from datasets import load_dataset +# from transformer_lens import HookedTransformer + +# class TransformerActivationDataModule(LightningDataModule): +# """ +# LightningDataModule` for the activation data of a transformer model. +# """ + +# def __init__( +# self, +# data_dir: str = "data/", +# batch_size: int = 64, +# num_workers: int = 0, +# pin_memory: bool = False, +# ) -> None: +# super().__init__() + +# # this line allows to access init params with 'self.hparams' attribute +# # also ensures init params will be stored in ckpt +# self.save_hyperparameters(logger=False) + +# self.data_train: Optional[Dataset] = None +# self.data_val: Optional[Dataset] = None +# self.data_test: Optional[Dataset] = None + +# self.batch_size_per_device = batch_size + +# self.probed_model = HookedTransformer.from_pretrained(self.hparams["model_name"]).to(DTYPES[self.hparams["enc_dtype"]]).to(self.hparams["device"]) +# self.probed_model.eval() +# self.probed_model.requires_grad_(False) +# self.probed_model_conf = { +# 'n_layers': self.probed_model.cfg.n_layers, +# 'd_model': self.probed_model.cfg.d_model, +# 'n_heads': self.probed_model.cfg.n_heads, +# 'd_head': self.probed_model.cfg.d_head, +# 'd_mlp': self.probed_model.cfg.d_mlp, +# 'd_vocab': self.probed_model.cfg.d_vocab +# } + +# def prepare_data(self) -> None: +# """Download data if needed. Lightning ensures that `self.prepare_data()` is called only +# within a single process on CPU, so you can safely add your downloading logic within. In +# case of multi-node training, the execution of this hook depends upon +# `self.prepare_data_per_node()`. + +# Do not use it to assign state (self.x = y). +# """ + + +# def setup(self, stage: Optional[str] = None) -> None: +# """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + +# This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and +# `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after +# `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to +# `self.setup()` once the data is prepared and available for use. + +# :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. +# """ +# # Divide batch size by the number of devices. 
+# if self.trainer is not None: +# if self.hparams.batch_size % self.trainer.world_size != 0: +# raise RuntimeError( +# f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." +# ) +# self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size + +# # load and split datasets only if not loaded already +# if not self.data_train and not self.data_val and not self.data_test: +# trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms) +# testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms) +# dataset = ConcatDataset(datasets=[trainset, testset]) +# self.data_train, self.data_val, self.data_test = random_split( +# dataset=dataset, +# lengths=self.hparams.train_val_test_split, +# generator=torch.Generator().manual_seed(42), +# ) + +# def train_dataloader(self) -> DataLoader[Any]: +# """Create and return the train dataloader. + +# :return: The train dataloader. +# """ +# return DataLoader( +# dataset=self.data_train, +# batch_size=self.batch_size_per_device, +# num_workers=self.hparams.num_workers, +# pin_memory=self.hparams.pin_memory, +# shuffle=True, +# ) + +# def val_dataloader(self) -> DataLoader[Any]: +# """Create and return the validation dataloader. + +# :return: The validation dataloader. +# """ +# return DataLoader( +# dataset=self.data_val, +# batch_size=self.batch_size_per_device, +# num_workers=self.hparams.num_workers, +# pin_memory=self.hparams.pin_memory, +# shuffle=False, +# ) + +# def test_dataloader(self) -> DataLoader[Any]: +# """Create and return the test dataloader. + +# :return: The test dataloader. +# """ +# return DataLoader( +# dataset=self.data_test, +# batch_size=self.batch_size_per_device, +# num_workers=self.hparams.num_workers, +# pin_memory=self.hparams.pin_memory, +# shuffle=False, +# ) + +# def load_text_dataset(self, dataset_path: str, split: str = "train", streaming: bool = True): +# """ +# Load a text dataset from Hugging Face's datasets library. +# """ +# data = load_dataset(dataset_path, split=split, streaming=streaming) +# return + + + +# def read_from_pile(address: str, max_lines: int = 100_000, start_line: int = 0): +# """Reads a file from the Pile dataset. 
Returns a generator.""" +# with open(address, "r") as f: +# for i, line in enumerate(f): +# if i < start_line: +# continue +# if i >= max_lines + start_line: +# break +# yield json.loads(line) + + +# def make_sentence_dataset(dataset_name: str, max_lines: int = 20_000, start_line: int = 0): +# """Returns a dataset from the Huggingface Datasets library.""" +# if dataset_name == "EleutherAI/pile": +# if not os.path.exists("pile0"): +# print("Downloading shard 0 of the Pile dataset (requires 50GB of disk space).") +# if not os.path.exists("pile0.zst"): +# os.system("curl https://the-eye.eu/public/AI/pile/train/00.jsonl.zst > pile0.zst") +# os.system("unzstd pile0.zst") +# dataset = Dataset.from_list(list(read_from_pile("pile0", max_lines=max_lines, start_line=start_line))) +# else: +# dataset = load_dataset(dataset_name, split="train")#, split=f"train[{start_line}:{start_line + max_lines}]") +# return dataset + + +# # Nora's Code from https://github.com/AlignmentResearch/tuned-lens/blob/main/tuned_lens/data.py +# def chunk_and_tokenize( +# data: T, +# tokenizer: PreTrainedTokenizerBase, +# *, +# format: str = "torch", +# num_proc: int = min(mp.cpu_count() // 2, 8), +# text_key: str = "text", +# max_length: int = 2048, +# return_final_batch: bool = False, +# load_from_cache_file: bool = True, +# ) -> Tuple[T, float]: +# """Perform GPT-style chunking and tokenization on a dataset. + +# The resulting dataset will consist entirely of chunks exactly `max_length` tokens +# long. Long sequences will be split into multiple chunks, and short sequences will +# be merged with their neighbors, using `eos_token` as a separator. The fist token +# will also always be an `eos_token`. + +# Args: +# data: The dataset to chunk and tokenize. +# tokenizer: The tokenizer to use. +# format: The format to return the dataset in, passed to `Dataset.with_format`. +# num_proc: The number of processes to use for tokenization. +# text_key: The key in the dataset to use as the text to tokenize. +# max_length: The maximum length of a batch of input ids. +# return_final_batch: Whether to return the final batch, which may be smaller +# than the others. +# load_from_cache_file: Whether to load from the cache file. + +# Returns: +# * The chunked and tokenized dataset. +# * The ratio of nats to bits per byte see https://arxiv.org/pdf/2101.00027.pdf, +# section 3.1. +# """ + +# def _tokenize_fn(x: Dict[str, list]): +# chunk_size = min(tokenizer.model_max_length, max_length) # tokenizer max length is 1024 for gpt2 +# sep = tokenizer.eos_token or "<|endoftext|>" +# joined_text = sep.join([""] + x[text_key]) +# output = tokenizer( +# # Concatenate all the samples together, separated by the EOS token. 
+# joined_text, # start with an eos token +# max_length=chunk_size, +# return_attention_mask=False, +# return_overflowing_tokens=True, +# truncation=True, +# ) + +# if overflow := output.pop("overflowing_tokens", None): +# # Slow Tokenizers return unnested lists of ints +# assert isinstance(output["input_ids"][0], int) + +# # Chunk the overflow into batches of size `chunk_size` +# chunks = [output["input_ids"]] + [ +# overflow[i * chunk_size : (i + 1) * chunk_size] for i in range(math.ceil(len(overflow) / chunk_size)) +# ] +# output = {"input_ids": chunks} + +# total_tokens = sum(len(ids) for ids in output["input_ids"]) +# total_bytes = len(joined_text.encode("utf-8")) + +# if not return_final_batch: +# # We know that the last sample will almost always be less than the max +# # number of tokens, and we don't want to pad, so we just drop it. +# output = {k: v[:-1] for k, v in output.items()} + +# output_batch_size = len(output["input_ids"]) + +# if output_batch_size == 0: +# raise ValueError( +# "Not enough data to create a single batch complete batch." +# " Either allow the final batch to be returned," +# " or supply more data." +# ) + +# # We need to output this in order to compute the number of bits per byte +# div, rem = divmod(total_tokens, output_batch_size) +# output["length"] = [div] * output_batch_size +# output["length"][-1] += rem + +# div, rem = divmod(total_bytes, output_batch_size) +# output["bytes"] = [div] * output_batch_size +# output["bytes"][-1] += rem + +# return output + +# data = data.map( +# _tokenize_fn, +# # Batching is important for ensuring that we don't waste tokens +# # since we always throw away the last element of the batch we +# # want to keep the batch size as large as possible +# batched=True, +# batch_size=2048, +# num_proc=num_proc, +# remove_columns=get_columns_all_equal(data), +# load_from_cache_file=load_from_cache_file, +# ) +# total_bytes: float = sum(data["bytes"]) +# total_tokens: float = sum(data["length"]) +# return data.with_format(format, columns=["input_ids"]), (total_tokens / total_bytes) / math.log(2) + + +# def get_columns_all_equal(dataset: Union[Dataset, DatasetDict]) -> List[str]: +# """Get a single list of columns in a `Dataset` or `DatasetDict`. + +# We assert the columms are the same across splits if it's a `DatasetDict`. + +# Args: +# dataset: The dataset to get the columns from. + +# Returns: +# A list of columns. 
+# """ +# if isinstance(dataset, DatasetDict): +# cols_by_split = dataset.column_names.values() +# columns = next(iter(cols_by_split)) +# if not all(cols == columns for cols in cols_by_split): +# raise ValueError("All splits must have the same columns") + +# return columns + +# return dataset.column_names + + +# # End Nora's Code from https://github.com/AlignmentResearch/tuned-lens/blob/main/tuned_lens/data.py + +# def make_activation_dataset( +# sentence_dataset: DataLoader, +# model: HookedTransformer, +# tensor_name: str, +# activation_width: int, +# dataset_folder: str, +# baukit: bool = False, +# chunk_size_gb: float = 2, +# device: torch.device = torch.device("cuda:0"), +# layer: int = 2, +# n_chunks: int = 1, +# max_length: int = 256, +# model_batch_size: int = 4, +# center_dataset: bool = False +# ) -> pd.DataFrame: +# print(f"Running model and saving activations to {dataset_folder}") +# with torch.no_grad(): +# chunk_size = chunk_size_gb * (2**30) # 2GB +# activation_size = ( +# activation_width * 2 * model_batch_size * max_length +# ) # 3072 mlp activations, 2 bytes per half, 1024 context window +# actives_per_chunk = chunk_size // activation_size +# dataset = [] +# n_saved_chunks = 0 +# for batch_idx, batch in tqdm(enumerate(sentence_dataset)): +# batch = batch["input_ids"].to(device) +# if baukit: +# # Don't have nanoGPT models integrated with transformer_lens so using baukit for activations +# with Trace(model, tensor_name) as ret: +# _ = model(batch) +# mlp_activation_data = ret.output +# mlp_activation_data = rearrange(mlp_activation_data, "b s n -> (b s) n").to(torch.float16).to(device) +# mlp_activation_data = nn.functional.gelu(mlp_activation_data) +# else: +# _, cache = model.run_with_cache(batch, stop_at_layer=layer + 1) +# mlp_activation_data = ( +# cache[tensor_name].to(device).to(torch.float16) +# ) # NOTE: could do all layers at once, but currently just doing 1 layer +# mlp_activation_data = rearrange(mlp_activation_data, "b s n -> (b s) n") + +# dataset.append(mlp_activation_data) +# if len(dataset) >= actives_per_chunk: +# if center_dataset: +# if n_saved_chunks == 0: +# chunk_mean = torch.mean(torch.cat(dataset), dim=0) +# dataset = [x - chunk_mean for x in dataset] + +# # Need to save, restart the list +# save_activation_chunk(dataset, n_saved_chunks, dataset_folder) +# n_saved_chunks += 1 +# print(f"Saved chunk {n_saved_chunks} of activations, total size: {batch_idx * activation_size} ") +# dataset = [] +# if n_saved_chunks == n_chunks: +# break + +# if n_saved_chunks < n_chunks: +# save_activation_chunk(dataset, n_saved_chunks, dataset_folder) +# print(f"Saved undersized chunk {n_saved_chunks} of activations, total size: {batch_idx * activation_size} ") + + + + +# # import os +# # from typing import Any, Iterator, cast + +# # import torch +# # from datasets import load_dataset +# # from torch.utils.data import DataLoader +# # from transformer_lens import HookedTransformer + + +# # class ActivationData: +# # """ +# # Class for streaming tokens and generating and storing activations +# # while training SAEs. 
+# # cfg: config object with the following attributes: +# # - dataset_path: path to the dataset +# # - use_cached_activations: whether to use cached activations +# # - cached_activations_path: path to the directory containing cached activations +# # - total_training_tokens: total number of tokens to train on +# # - n_batches_in_buffer: number of batches to store in the buffer +# # - store_batch_size: number of tokens to store in the buffer at a time +# # - train_batch_size: number of tokens to train on at a time +# # - context_size: number of tokens in the context +# # - d_in: input dimensionality +# # - hook_point_layer: layer to hook into +# # - hook_point_head_index: head index to hook into +# # - hook_point: name of the hook +# # - device: device to store the activations on +# # - dtype: data type to store the activations in +# # model: the model to generate activations from +# # create_dataloader: whether to create a dataloader +# # """ + +# # def __init__( +# # self, +# # cfg: Any, +# # model: HookedTransformer, +# # create_dataloader: bool = True, +# # ): +# # self.cfg = cfg +# # self.model = model +# # self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True) +# # self.iterable_dataset = iter(self.dataset) + +# # # Check if dataset is tokenized +# # dataset_sample = next(self.iterable_dataset) +# # self.cfg.is_dataset_tokenized = "tokens" in dataset_sample.keys() +# # print( +# # f"Dataset is {'tokenized' if self.cfg.is_dataset_tokenized else 'not tokenized'}! Updating config." +# # ) +# # self.iterable_dataset = iter(self.dataset) # Reset iterator after checking + +# # if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts +# # assert self.cfg.cached_activations_path is not None # keep pyright happy +# # # Sanity check: does the cache directory exist? +# # assert os.path.exists( +# # self.cfg.cached_activations_path +# # ), f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names." + +# # self.next_cache_idx = 0 # which file to open next +# # self.next_idx_within_buffer = 0 # where to start reading from in that file + +# # # Check that we have enough data on disk +# # first_buffer = torch.load(f"{self.cfg.cached_activations_path}/0.pt") +# # buffer_size_on_disk = first_buffer.shape[0] +# # n_buffers_on_disk = len(os.listdir(self.cfg.cached_activations_path)) +# # # Note: we're assuming all files have the same number of tokens +# # # (which seems reasonable imo since that's what our script does) +# # n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk +# # assert ( +# # n_activations_on_disk > self.cfg.total_training_tokens +# # ), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M." + +# # # TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF) + +# # if create_dataloader: +# # # fill buffer half a buffer, so we can mix it with a new buffer +# # self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2) +# # self.dataloader = self.get_data_loader() + +# # def get_batch_tokens(self): +# # """ +# # Streams a batch of tokens from a dataset. 
+# # """ + +# # batch_size = self.cfg.store_batch_size +# # context_size = self.cfg.context_size +# # device = self.cfg.device + +# # batch_tokens = torch.zeros( +# # size=(0, context_size), device=device, dtype=torch.long, requires_grad=False +# # ) + +# # current_batch = [] +# # current_length = 0 + +# # # pbar = tqdm(total=batch_size, desc="Filling batches") +# # while batch_tokens.shape[0] < batch_size: +# # if not self.cfg.is_dataset_tokenized: +# # s = next(self.iterable_dataset)["text"] +# # tokens = self.model.to_tokens( +# # s, +# # truncate=True, +# # move_to_device=True, +# # ).squeeze(0) +# # assert ( +# # len(tokens.shape) == 1 +# # ), f"tokens.shape should be 1D but was {tokens.shape}" +# # else: +# # tokens = torch.tensor( +# # next(self.iterable_dataset)["tokens"], +# # dtype=torch.long, +# # device=device, +# # requires_grad=False, +# # ) +# # token_len = tokens.shape[0] + +# # # TODO: Fix this so that we are limiting how many tokens we get from the same context. +# # assert self.model.tokenizer is not None # keep pyright happy +# # bos_token_id_tensor = torch.tensor( +# # [self.model.tokenizer.bos_token_id], +# # device=tokens.device, +# # dtype=torch.long, +# # ) +# # while token_len > 0 and batch_tokens.shape[0] < batch_size: +# # # Space left in the current batch +# # space_left = context_size - current_length + +# # # If the current tokens fit entirely into the remaining space +# # if token_len <= space_left: +# # current_batch.append(tokens[:token_len]) +# # current_length += token_len +# # break + +# # else: +# # # Take as much as will fit +# # current_batch.append(tokens[:space_left]) + +# # # Remove used part, add BOS +# # tokens = tokens[space_left:] +# # tokens = torch.cat( +# # ( +# # bos_token_id_tensor, +# # tokens, +# # ), +# # dim=0, +# # ) + +# # token_len -= space_left +# # token_len += 1 +# # current_length = context_size + +# # # If a batch is full, concatenate and move to next batch +# # if current_length == context_size: +# # full_batch = torch.cat(current_batch, dim=0) +# # batch_tokens = torch.cat( +# # (batch_tokens, full_batch.unsqueeze(0)), dim=0 +# # ) +# # current_batch = [] +# # current_length = 0 + +# # # pbar.n = batch_tokens.shape[0] +# # # pbar.refresh() +# # return batch_tokens[:batch_size] + +# # def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False): +# # """ +# # Returns activations of shape (batches, context, num_layers, d_in) +# # """ +# # layers = ( +# # self.cfg.hook_point_layer +# # if isinstance(self.cfg.hook_point_layer, list) +# # else [self.cfg.hook_point_layer] +# # ) +# # act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers] +# # hook_point_max_layer = max(layers) +# # if self.cfg.hook_point_head_index is not None: +# # layerwise_activations = self.model.run_with_cache( +# # batch_tokens, +# # names_filter=act_names, +# # stop_at_layer=hook_point_max_layer + 1, +# # )[1] +# # activations_list = [ +# # layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index] +# # for act_name in act_names +# # ] +# # else: +# # layerwise_activations = self.model.run_with_cache( +# # batch_tokens, +# # names_filter=act_names, +# # stop_at_layer=hook_point_max_layer + 1, +# # )[1] +# # activations_list = [ +# # layerwise_activations[act_name] for act_name in act_names +# # ] + +# # # Stack along a new dimension to keep separate layers distinct +# # stacked_activations = torch.stack(activations_list, dim=2) + +# # return stacked_activations + +# # def get_buffer(self, n_batches_in_buffer: 
int): +# # context_size = self.cfg.context_size +# # batch_size = self.cfg.store_batch_size +# # d_in = self.cfg.d_in +# # total_size = batch_size * n_batches_in_buffer +# # num_layers = ( +# # len(self.cfg.hook_point_layer) +# # if isinstance(self.cfg.hook_point_layer, list) +# # else 1 +# # ) # Number of hook points or layers + +# # if self.cfg.use_cached_activations: +# # # Load the activations from disk +# # buffer_size = total_size * context_size +# # # Initialize an empty tensor with an additional dimension for layers +# # new_buffer = torch.zeros( +# # (buffer_size, num_layers, d_in), +# # dtype=self.cfg.dtype, +# # device=self.cfg.device, +# # ) +# # n_tokens_filled = 0 + +# # # Assume activations for different layers are stored separately and need to be combined +# # while n_tokens_filled < buffer_size: +# # if not os.path.exists( +# # f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt" +# # ): +# # print( +# # "\n\nWarning: Ran out of cached activation files earlier than expected." +# # ) +# # print( +# # f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}." +# # ) +# # if buffer_size % self.cfg.total_training_tokens != 0: +# # print( +# # "This might just be a rounding error β€” your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens" +# # ) +# # print(f"Returning a buffer of size {n_tokens_filled} instead.") +# # print("\n\n") +# # new_buffer = new_buffer[:n_tokens_filled, ...] +# # return new_buffer + +# # activations = torch.load( +# # f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt" +# # ) +# # taking_subset_of_file = False +# # if n_tokens_filled + activations.shape[0] > buffer_size: +# # activations = activations[: buffer_size - n_tokens_filled, ...] +# # taking_subset_of_file = True + +# # new_buffer[ +# # n_tokens_filled : n_tokens_filled + activations.shape[0], ... +# # ] = activations + +# # if taking_subset_of_file: +# # self.next_idx_within_buffer = activations.shape[0] +# # else: +# # self.next_cache_idx += 1 +# # self.next_idx_within_buffer = 0 + +# # n_tokens_filled += activations.shape[0] + +# # return new_buffer + +# # refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size) +# # # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers +# # new_buffer = torch.zeros( +# # (total_size, context_size, num_layers, d_in), +# # dtype=self.cfg.dtype, +# # device=self.cfg.device, +# # ) + +# # for refill_batch_idx_start in refill_iterator: +# # refill_batch_tokens = self.get_batch_tokens() +# # refill_activations = self.get_activations(refill_batch_tokens) +# # new_buffer[ +# # refill_batch_idx_start : refill_batch_idx_start + batch_size, ... +# # ] = refill_activations + +# # # pbar.update(1) + +# # new_buffer = new_buffer.reshape(-1, num_layers, d_in) +# # new_buffer = new_buffer[torch.randperm(new_buffer.shape[0])] + +# # return new_buffer + +# # def get_data_loader( +# # self, +# # ) -> Iterator[Any]: +# # """ +# # Return a torch.utils.dataloader which you can get batches from. + +# # Should automatically refill the buffer when it gets to n % full. +# # (better mixing if you refill and shuffle regularly). +# # """ + +# # batch_size = self.cfg.train_batch_size + +# # # 1. 
# create new buffer by mixing stored and new buffer +# # mixing_buffer = torch.cat( +# # [self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer], +# # dim=0, +# # ) + +# # mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])] + +# # # 2. put 50 % in storage +# # self.storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2] + +# # # 3. put other 50 % in a dataloader +# # dataloader = iter( +# # DataLoader( +# # # TODO: seems like a typing bug? +# # cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]), +# # batch_size=batch_size, +# # shuffle=True, +# # ) +# # ) + +# # return dataloader + +# # def next_batch(self): +# # """ +# # Get the next batch from the current DataLoader. +# # If the DataLoader is exhausted, refill the buffer and create a new DataLoader. +# # """ +# # try: +# # # Try to get the next batch +# # return next(self.dataloader) +# # except StopIteration: +# # # If the DataLoader is exhausted, create a new one +# # self.dataloader = self.get_data_loader() +# # return next(self.dataloader) + + + + + + +# ############################################################################################################ +# ####### neel nanda's buffer idea for storing activations and using them for training the autoencoder ###### +# ############################################################################################################ + +# # def shuffle_data(all_tokens): +# # print("Shuffled data") +# # return all_tokens[torch.randperm(all_tokens.shape[0])] + +# # loading_data_first_time = False +# # if loading_data_first_time: +# # data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/") +# # data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf") +# # data.set_format(type="torch", columns=["tokens"]) +# # all_tokens = data["tokens"] +# # all_tokens.shape + +# # all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128) +# # all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id +# # all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])] +# # torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt") +# # else: +# # # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf") +# # all_tokens = torch.load("/workspace/data/c4_code_2b_tokens_reshaped.pt") +# # all_tokens = shuffle_data(all_tokens) + + + +# # class Buffer(): +# # """ +# # This defines a data buffer, to store a bunch of MLP acts that can be used to train the autoencoder. +# # It'll automatically run the model to generate more when it gets halfway empty. 
+# # requires a cfg dictionary with the following +# # buffer_size: int, the size of the buffer +# # act_size: int, the size of the activations +# # device: torch device, where to store the buffer +# # buffer_batches: int, how many batches to run to fill the buffer +# # model_batch_size: int, how many tokens to run at once +# # layer: int, which layer to stop at +# # act_name: str, the name of the activation to store +# # batch_size: int, how many activations to return at once +# # """ + +# # def __init__(self, cfg): +# # self.buffer = torch.zeros((cfg["buffer_size"], cfg["act_size"]), dtype=torch.bfloat16, requires_grad=False).to(cfg["device"]) +# # self.cfg = cfg +# # self.token_pointer = 0 +# # self.first = True +# # self.refresh() + +# # @torch.no_grad() +# # def refresh(self): +# # self.pointer = 0 +# # with torch.autocast("cuda", torch.bfloat16): +# # if self.first: +# # num_batches = self.cfg["buffer_batches"] +# # else: +# # num_batches = self.cfg["buffer_batches"]//2 +# # self.first = False +# # for _ in range(0, num_batches, self.cfg["model_batch_size"]): +# # tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]] +# # _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) +# # acts = cache[cfg["act_name"]].reshape(-1, self.cfg["act_size"]) + +# # # print(tokens.shape, acts.shape, self.pointer, self.token_pointer) +# # self.buffer[self.pointer: self.pointer+acts.shape[0]] = acts +# # self.pointer += acts.shape[0] +# # self.token_pointer += self.cfg["model_batch_size"] +# # # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]: +# # # self.token_pointer = 0 + +# # self.pointer = 0 +# # self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(cfg["device"])] + +# # @torch.no_grad() +# # def next(self): +# # out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]] +# # self.pointer += self.cfg["batch_size"] +# # if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]: +# # # print("Refreshing the buffer!") +# # self.refresh() +# # return out \ No newline at end of file diff --git a/src/eval.py b/src/eval.py new file mode 100644 index 0000000..b70faae --- /dev/null +++ b/src/eval.py @@ -0,0 +1,99 @@ +from typing import Any, Dict, List, Tuple + +import hydra +import rootutils +from lightning import LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." 
in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from src.utils import ( + RankedLogger, + extras, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Evaluates given checkpoint on a datamodule testset. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Tuple[dict, dict] with metrics and dict with all instantiated objects. + """ + assert cfg.ckpt_path + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + log.info("Starting testing!") + trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path) + + # for predictions use trainer.predict(...) + # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path) + + metric_dict = trainer.callback_metrics + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") +def main(cfg: DictConfig) -> None: + """Main entry point for evaluation. + + :param cfg: DictConfig configuration composed by Hydra. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
+ extras(cfg) + + evaluate(cfg) + + +if __name__ == "__main__": + main() diff --git a/src/evaluation/__init__.py b/src/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evaluation/gsm8k/__init__.py b/src/evaluation/gsm8k/__init__.py similarity index 100% rename from evaluation/gsm8k/__init__.py rename to src/evaluation/gsm8k/__init__.py diff --git a/evaluation/gsm8k/compare_with_reference.py b/src/evaluation/gsm8k/compare_with_reference.py similarity index 100% rename from evaluation/gsm8k/compare_with_reference.py rename to src/evaluation/gsm8k/compare_with_reference.py diff --git a/evaluation/gsm8k/test_compare_with_reference.py b/src/evaluation/gsm8k/test_compare_with_reference.py similarity index 100% rename from evaluation/gsm8k/test_compare_with_reference.py rename to src/evaluation/gsm8k/test_compare_with_reference.py diff --git a/thirdparty/openai/LICENSE b/src/openai/LICENSE similarity index 100% rename from thirdparty/openai/LICENSE rename to src/openai/LICENSE diff --git a/thirdparty/openai/README.md b/src/openai/README.md similarity index 100% rename from thirdparty/openai/README.md rename to src/openai/README.md diff --git a/thirdparty/openai/grade_school_math/calculator.py b/src/openai/grade_school_math/calculator.py similarity index 100% rename from thirdparty/openai/grade_school_math/calculator.py rename to src/openai/grade_school_math/calculator.py diff --git a/thirdparty/openai/grade_school_math/data/example_model_solutions.jsonl b/src/openai/grade_school_math/data/example_model_solutions.jsonl similarity index 100% rename from thirdparty/openai/grade_school_math/data/example_model_solutions.jsonl rename to src/openai/grade_school_math/data/example_model_solutions.jsonl diff --git a/thirdparty/openai/grade_school_math/data/test.jsonl b/src/openai/grade_school_math/data/test.jsonl similarity index 100% rename from thirdparty/openai/grade_school_math/data/test.jsonl rename to src/openai/grade_school_math/data/test.jsonl diff --git a/thirdparty/openai/grade_school_math/data/test_socratic.jsonl b/src/openai/grade_school_math/data/test_socratic.jsonl similarity index 100% rename from thirdparty/openai/grade_school_math/data/test_socratic.jsonl rename to src/openai/grade_school_math/data/test_socratic.jsonl diff --git a/thirdparty/openai/grade_school_math/data/train.jsonl b/src/openai/grade_school_math/data/train.jsonl similarity index 100% rename from thirdparty/openai/grade_school_math/data/train.jsonl rename to src/openai/grade_school_math/data/train.jsonl diff --git a/thirdparty/openai/grade_school_math/data/train_socratic.jsonl b/src/openai/grade_school_math/data/train_socratic.jsonl similarity index 100% rename from thirdparty/openai/grade_school_math/data/train_socratic.jsonl rename to src/openai/grade_school_math/data/train_socratic.jsonl diff --git a/thirdparty/openai/grade_school_math/dataset.py b/src/openai/grade_school_math/dataset.py similarity index 100% rename from thirdparty/openai/grade_school_math/dataset.py rename to src/openai/grade_school_math/dataset.py diff --git a/thirdparty/openai/grade_school_math/img/example_problems.png b/src/openai/grade_school_math/img/example_problems.png similarity index 100% rename from thirdparty/openai/grade_school_math/img/example_problems.png rename to src/openai/grade_school_math/img/example_problems.png diff --git a/thirdparty/openai/grade_school_math/sample.py b/src/openai/grade_school_math/sample.py similarity index 100% rename from thirdparty/openai/grade_school_math/sample.py rename to 
src/openai/grade_school_math/sample.py diff --git a/thirdparty/openai/grade_school_math/train.py b/src/openai/grade_school_math/train.py similarity index 100% rename from thirdparty/openai/grade_school_math/train.py rename to src/openai/grade_school_math/train.py diff --git a/thirdparty/openai/grade_school_math/view_model_solutions.py b/src/openai/grade_school_math/view_model_solutions.py similarity index 100% rename from thirdparty/openai/grade_school_math/view_model_solutions.py rename to src/openai/grade_school_math/view_model_solutions.py diff --git a/thirdparty/openai/setup.py b/src/openai/setup.py similarity index 100% rename from thirdparty/openai/setup.py rename to src/openai/setup.py diff --git a/src/Readme.md b/src/tariners/Readme.md similarity index 100% rename from src/Readme.md rename to src/tariners/Readme.md diff --git a/src/behavior_cloning.py b/src/tariners/behavior_cloning.py similarity index 100% rename from src/behavior_cloning.py rename to src/tariners/behavior_cloning.py diff --git a/src/callbacks.py b/src/tariners/callbacks.py similarity index 100% rename from src/callbacks.py rename to src/tariners/callbacks.py diff --git a/src/collators.py b/src/tariners/collators.py similarity index 100% rename from src/collators.py rename to src/tariners/collators.py diff --git a/src/constraints.py b/src/tariners/constraints.py similarity index 100% rename from src/constraints.py rename to src/tariners/constraints.py diff --git a/src/eval_script.py b/src/tariners/eval_script.py similarity index 100% rename from src/eval_script.py rename to src/tariners/eval_script.py diff --git a/src/pause_classifier_wrapper.py b/src/tariners/pause_classifier_wrapper.py similarity index 100% rename from src/pause_classifier_wrapper.py rename to src/tariners/pause_classifier_wrapper.py diff --git a/src/pretrain.py b/src/tariners/pretrain.py similarity index 100% rename from src/pretrain.py rename to src/tariners/pretrain.py diff --git a/src/reward_conditioned.py b/src/tariners/reward_conditioned.py similarity index 100% rename from src/reward_conditioned.py rename to src/tariners/reward_conditioned.py diff --git a/src/rewards.py b/src/tariners/rewards.py similarity index 100% rename from src/rewards.py rename to src/tariners/rewards.py diff --git a/src/run_inference.py b/src/tariners/run_inference.py similarity index 100% rename from src/run_inference.py rename to src/tariners/run_inference.py diff --git a/src/samplers.py b/src/tariners/samplers.py similarity index 100% rename from src/samplers.py rename to src/tariners/samplers.py diff --git a/src/sft.py b/src/tariners/sft.py similarity index 100% rename from src/sft.py rename to src/tariners/sft.py diff --git a/src/sft_pause.py b/src/tariners/sft_pause.py similarity index 100% rename from src/sft_pause.py rename to src/tariners/sft_pause.py diff --git a/src/testing_inv_modeling.py b/src/tariners/testing_inv_modeling.py similarity index 100% rename from src/testing_inv_modeling.py rename to src/tariners/testing_inv_modeling.py diff --git a/src/trainers.py b/src/tariners/trainers.py similarity index 100% rename from src/trainers.py rename to src/tariners/trainers.py diff --git a/src/utils.py b/src/tariners/utils.py similarity index 100% rename from src/utils.py rename to src/tariners/utils.py diff --git a/src/wsft.py b/src/tariners/wsft.py similarity index 100% rename from src/wsft.py rename to src/tariners/wsft.py diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..0f90842 --- /dev/null +++ b/src/train.py @@ -0,0 +1,132 
@@ +from typing import Any, Dict, List, Optional, Tuple + +import hydra +import lightning as L +import rootutils +import torch +from lightning import Callback, LightningDataModule, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + +from src.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +log = RankedLogger(__name__, rank_zero_only=True) + + +@task_wrapper +def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best weights obtained during + training. + + This method is wrapped in optional @task_wrapper decorator, that controls the behavior during + failure. Useful for multiruns, saving info about the crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. + """ + # set seed for random number generators in pytorch, numpy and python.random + if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + + log.info(f"Instantiating datamodule <{cfg.data._target_}>") + datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data, _recursive_=False) + + log.info(f"Instantiating model <{cfg.model._target_}>") + model: LightningModule = hydra.utils.instantiate(cfg.model, _recursive_=False) + + log.info("Instantiating callbacks...") + callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) + + log.info("Instantiating loggers...") + logger: List[Logger] = instantiate_loggers(cfg.get("logger")) + + log.info(f"Instantiating trainer <{cfg.trainer._target_}>") + trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + + object_dict = { + "cfg": cfg, + "datamodule": datamodule, + "model": model, + "callbacks": callbacks, + "logger": logger, + "trainer": trainer, + } + + if logger: + log.info("Logging hyperparameters!") + log_hyperparameters(object_dict) + + if cfg.get("train"): + log.info("Starting training!") + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + + train_metrics = trainer.callback_metrics + + if cfg.get("test"): + log.info("Starting testing!") + ckpt_path = trainer.checkpoint_callback.best_model_path + if ckpt_path == "": + log.warning("Best ckpt not found! 
Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + log.info(f"Best ckpt path: {ckpt_path}") + + test_metrics = trainer.callback_metrics + + # merge train and test metrics + metric_dict = {**train_metrics, **test_metrics} + + return metric_dict, object_dict + + +@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") +def main(cfg: DictConfig) -> Optional[float]: + """Main entry point for training. + + :param cfg: DictConfig configuration composed by Hydra. + :return: Optional[float] with optimized metric value. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + # train the model + metric_dict, _ = train(cfg) + + # safely retrieve metric value for hydra-based hyperparameter optimization + metric_value = get_metric_value( + metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") + ) + + # return optimized metric + return metric_value + + +if __name__ == "__main__": + main() diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..5b0707c --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,5 @@ +from src.utils.instantiators import instantiate_callbacks, instantiate_loggers +from src.utils.logging_utils import log_hyperparameters +from src.utils.pylogger import RankedLogger +from src.utils.rich_utils import enforce_tags, print_config_tree +from src.utils.utils import extras, get_metric_value, task_wrapper diff --git a/src/utils/generate_activation_data.py b/src/utils/generate_activation_data.py new file mode 100644 index 0000000..9579301 --- /dev/null +++ b/src/utils/generate_activation_data.py @@ -0,0 +1,456 @@ +from typing import Any, Dict, Optional, Tuple + +import multiprocessing as mp +import torch +from lightning import LightningDataModule +from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split +# from torchvision.transforms import transforms +import os +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, Literal + +import torch +from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets +import pandas as pd +from transformer_lens import HookedTransformer +from transformers import PreTrainedTokenizerBase +from tqdm import tqdm + + +T = TypeVar("T", bound=Union[Dataset, DatasetDict]) + +MODEL_BATCH_SIZE = 4 +CHUNK_SIZE_GB = 2.0 +MAX_SENTENCE_LEN = 256 +DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + + +class GenerateActivationData(LightningDataModule): + """ + LightningDataModule` for the activation data of a transformer model. 
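+
+    Runs a frozen `HookedTransformer` over a text corpus and stores subsampled activations
+    to disk in chunks, under `<dataset_folder>/<model_name>_<tensor_name>/<split>/chunk_<i>/`.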
+ """ + + def __init__( + self, + data_dir: str = "data/", + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + print(self.hparams) + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + + self.batch_size_per_device = batch_size + + self.probed_model = HookedTransformer.from_pretrained(self.hparams["model_name"]).to(self.hparams["model_device"]) + self.probed_model.name = self.hparams["model_name"] + self.probed_model.eval() + self.probed_model.requires_grad_(False) + self.probed_model_conf = { + 'n_layers': self.probed_model.cfg.n_layers, + 'd_model': self.probed_model.cfg.d_model, + 'n_heads': self.probed_model.cfg.n_heads, + 'd_head': self.probed_model.cfg.d_head, + 'd_mlp': self.probed_model.cfg.d_mlp, + 'd_vocab': self.probed_model.cfg.d_vocab + } + + # def load_text_dataset(self, dataset_path: str, split: str = "train", streaming: bool = True): + # """ + # Load a text dataset from Hugging Face's datasets library. + # """ + # data = load_dataset(dataset_path, split=split, streaming=streaming) + # return + + def make_sentence_dataset(self, dataset_name: str, data_files: List, data_path: str = None, max_lines: int = 20_000, start_line: int = 0): + """Returns a dataset from the Huggingface Datasets library.""" + # if dataset_name == "EleutherAI/pile": + # if not os.path.exists(os.path.join(data_path, "pile0.zst")): + # print("Downloading shard 0 of the Pile dataset (requires 50GB of disk space).") + # if not os.path.exists(os.path.join(data_path, "pile0.zst")): + # os.system(f"curl -o {data_path}/pile0.zst https://the-eye.eu/public/AI/pile/train/00.jsonl.zst ") + # os.system(f"unzstd {data_path}/pile0.zst") + # dataset = Dataset.from_list(list(self.read_from_pile("{data_path}/pile0", max_lines=max_lines, start_line=start_line))) + # else: + # os.environ["HF_DATASETS_CACHE"] = data_path + # data_path = "/dlabdata1/masani/symbolic_probing/data" + # data_files = {"train": ["train/00.jsonl.zst", "train/01.jsonl.zst"], "validation": "val.jsonl.zst", "test": "test.jsonl.zst"} + self.dataset = load_dataset("monology/pile-uncopyrighted", data_files=data_files, cache_dir=data_path) + return self.dataset + + def read_from_pile(self, address: str, max_lines: int = 100_000, start_line: int = 0): + """Reads a file from the Pile dataset. 
Returns a generator.""" + with open(address, "r") as f: + for i, line in enumerate(f): + if i < start_line: + continue + if i >= max_lines + start_line: + break + yield json.loads(line) + + def save_activation_chunk(self, dataset, n_saved_chunks, dataset_folder): + dataset_t = torch.cat(dataset, dim=0).to("cpu") + os.makedirs(dataset_folder, exist_ok=True) + with open(dataset_folder + "/" + str(n_saved_chunks) + ".pt", "wb") as f: + torch.save(dataset_t, f) + + def make_activation_dataset( + self, + sentence_dataset: DataLoader = None, + model: HookedTransformer = None, + tensor_name: str = 'blocks.1.hook_mlp_out', + activation_width: int = 512, + dataset_folder: str = "activations", + baukit: bool = False, + chunk_size_gb: float = 2, + layer: int = 2, + n_chunks: int = 1, + max_length: int = 1024, + model_batch_size: int = 4, + center_dataset: bool = False + ) -> pd.DataFrame: + if sentence_dataset is None: + sentence_dataset = self.dataset + if model is None: + model = self.probed_model + tokenizer = model.tokenizer + device = next(model.parameters()).device + dataset_folder = os.path.join(dataset_folder, model.name+"_"+tensor_name) + # max_length = min(max_length, tokenizer.model_max_length, model.cfg.n_ctx) # model.pos_embed.W_pos.shape[0] perhaps? + activities_per_input = 256 + n_saved_chunks = 0 + if os.path.exists(dataset_folder): + # removing the folder and its contents and remaking it + os.system(f"rm -r {dataset_folder}") + os.makedirs(os.path.join(dataset_folder), exist_ok=True) + + generator = torch.Generator(device=device).manual_seed(42) + + + data = {'activations': []} # Store activations here, add logits if you want to save them as well + data_size = 0 # Keep track of the data size for periodic saving + + with torch.no_grad(): + for split, split_dataset in sentence_dataset.items(): + split_dataloader = DataLoader(split_dataset, batch_size=model_batch_size, shuffle=False) + for batch_idx, batch in tqdm(enumerate(split_dataloader)): + tokenized_batch = model.to_tokens(batch['text']) + # tokenized_batch = tokenizer(batch['text'], padding=True, max_length=max_length, truncation=True, return_tensors="pt") + (logits, loss), cache = model.run_with_cache(tokenized_batch, return_type='both') + # model.tokenizer.decode(logits.argmax(dim=-1)[0]) # to decode the logits and see how the predictions look like + + labels = tokenized_batch[:, 1:] + # logits = logits[:, :-1, :] + activations = cache[tensor_name][:, :-1, :] + # subsample 256 activations from each of the 1024 context window without replacement + # compute NLL loss for each of the 256 activations + labels_not_pad_mask = (labels != tokenizer.pad_token_id) + labels = labels[labels_not_pad_mask] + # logits = logits[labels_not_pad_mask] + activations = activations[labels_not_pad_mask] + perm = torch.randperm(labels.shape[0], generator=generator, device=device) + + activations_subsampled = activations[perm[:activities_per_input * model_batch_size]].detach().to(torch.float16).reshape(-1, activations.shape[-1]).cpu() + data_size += activations_subsampled.nelement() * activations_subsampled.element_size() + # data_size += activations_subsampled.nbytes #for numpy + # logits_subsampled = logits[perm[:activities_per_input * model_batch_size]].detach().to(torch.float16).reshape(-1, logits.shape[-1]).cpu().numpy() + + # activation_data.append(activations_subsampled) + # logit_data.append(logits_subsampled) + # data_size += activations_subsampled.nelement() * activations_subsampled.element_size() + logits_subsampled.nelement() * 
logits_subsampled.element_size() + # data = {'activations': activations_subsampled, 'logits': logits_subsampled} # use this if you want to save the logits as well + # data = {'activations': activations_subsampled} # use this if you want to save only the activations + data['activations'].extend(activations_subsampled) + + # for key in data: + # data[key].nelement() * data[key].element_size() + # for key in data: + # data_size += data[key].nbytes() + # data_size += data[key].nelement() * data[key].element_size() + + + if data_size >= 2 * (1024 ** 3): # Assuming 2GB before saving and resetting 2 * (1024 ** 3) + self.save_dataset_chunk(data, dataset_folder, split) + for key in data: + data[key] = [] + data_size = 0 + + # # for debugging + # if batch_idx > 10: + # break + + # Save any remaining data that didn't reach the threshold + if data_size > 0: + self.save_dataset_chunk(data, dataset_folder, split) + + # Concatenate all chunks and make one final dataset + final_datasets = self.concatenate_datasets(sentence_dataset, dataset_folder, tokenizer, max_length) + return final_datasets + + # def tensor_to_list(self, tensor): + # """Converts a tensor to a list of values, ensuring compatibility with the datasets library.""" + # return tensor.detach().cpu().tolist() + + def save_dataset_chunk(self, data, dataset_folder, split_name): + """Saves a chunk of the dataset to disk using the datasets library, tailored for PyTorch tensors.""" + # Prepare the data by converting tensors to lists of values + # activations_lists = [self.tensor_to_list(a) for a in activation_data] + # logits_lists = [self.tensor_to_list(l) for l in logit_data] + + # for key in data: + # data[key] = self.tensor_to_list(data[key]) + + # Prepare a list of dictionaries (each dictionary corresponds to a single example in the dataset) + # data = [{'activations': act, 'logits': log} for act, log in zip(activations_lists, logits_lists)] + + # Create a Hugging Face dataset from the list of dictionaries + dataset_chunk = Dataset.from_dict(data) + + if not os.path.exists(os.path.join(dataset_folder, f"{split_name}")): + os.makedirs(os.path.join(dataset_folder, f"{split_name}")) + last_chunk_id = len(os.listdir(os.path.join(dataset_folder, f"{split_name}"))) + save_path = os.path.join(dataset_folder, f"{split_name}", f"chunk_{last_chunk_id}") + + dataset_chunk.save_to_disk(save_path) + print(f"chunk {last_chunk_id} saved to {save_path}") + + + # if os.path.exists(dataset_path): + # dataset_dict = DatasetDict.load_from_disk(dataset_path) + # if split_name in dataset_dict: + # # Concatenate the new chunk to the existing dataset split + # # dataset_dict[split_name] = DatasetDict({split_name: dataset_dict[split_name].concatenate(dataset_chunk)}) + # dataset_dict[split_name] = concatenate_datasets([dataset_dict[split_name], dataset_chunk]) + # else: + # # Add new split with the chunk + # dataset_dict[split_name] = dataset_chunk + # else: + # # Initialize a new dataset with the chunk + # dataset_dict = DatasetDict({split_name: dataset_chunk}) + + # Save the dataset + # save_path = os.path.join(dataset_path, f"{split}_chunk_{num_chunks}") + # dataset_dict.save_to_disk(dataset_path) + # print(f"chunk {num_chunks} saved to {save_path}") + # print(f"Updated dataset for {split_name} split at {dataset_path}") + + def concatenate_datasets(self, sentence_dataset, dataset_folder, tokenizer, max_length): + # Load and concatenate all chunks for each split + # making a final dataset with all the chunks and splits + final_dataset = DatasetDict() + for split in 
sentence_dataset.keys(): + data_split_path = os.path.join(dataset_folder, split) + num_chunks = len(os.listdir(data_split_path)) + for chunk_num in range(num_chunks): + chunk_path = os.path.join(data_split_path, f"chunk_{chunk_num}") + if chunk_num == 0: + loaded_dset = Dataset.load_from_disk(chunk_path) + else: + loaded_dset = concatenate_datasets([loaded_dset, Dataset.load_from_disk(chunk_path)]) + final_dataset[split] = loaded_dset + + # Optionally, save the concatenated final datasets to disk + final_path = os.path.join(dataset_folder, "final") + final_dataset.save_to_disk(final_path) + + return final_datasets + + + # tokenized_dataset = self.chunk_and_tokenize(split_dataset, tokenizer, max_length=max_length) + # output = model.run_with_cache(tokenized_dataset["input_ids"].to(device), stop_at_layer=layer + 1) + # for sentence_idx, sentence in tqdm(enumerate(sentence_dataset)): + # tokenized_sentence = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt") + # batch = tokenized_sentence["input_ids"].to(device) + + # _, cache = model.run_with_cache(batch, stop_at_layer=layer + 1) + # mlp_activation_data = ( + # cache[tensor_name].to(device).to(torch.float16) + # ) # NOTE: could do all layers at once, but currently just doing 1 layer + # mlp_activation_data = rearrange(mlp_activation_data, "b s n -> (b s) n") + + # dataset.append(mlp_activation_data) + # if len(dataset) >= actives_per_chunk: + # if center_dataset: + # if n_saved_chunks == 0: + # chunk_mean = torch.mean(torch.cat(dataset), dim=0) + # dataset = [x - chunk_mean for x in dataset] + + # # Need to save, restart the list + # save_activation_chunk(dataset, n_saved_chunks, dataset_folder) + # n_saved_chunks += 1 + # print(f"Saved chunk {n_saved_chunks} of activations, total size: {batch_idx * activation_size} ") + # dataset = [] + # if n_saved_chunks == n_chunks: + # break + + # if n_saved_chunks < n_chunks: + # save_activation_chunk(dataset, n_saved_chunks, dataset_folder) + # print(f"Saved undersized chunk {n_saved_chunks} of activations, total size: {batch_idx * activation_size} ") + + # # Nora's Code from https://github.com/AlignmentResearch/tuned-lens/blob/main/tuned_lens/data.py + # def chunk_and_tokenize(self, + # data: T, + # tokenizer: PreTrainedTokenizerBase, + # *, + # format: str = "torch", + # num_proc: int = min(mp.cpu_count() // 2, 8), + # text_key: str = "text", + # max_length: int = 2048, + # return_final_batch: bool = False, + # load_from_cache_file: bool = True, + # )-> Tuple[T, float]: + # """Perform GPT-style chunking and tokenization on a dataset. + + # The resulting dataset will consist entirely of chunks exactly `max_length` tokens + # long. Long sequences will be split into multiple chunks, and short sequences will + # be merged with their neighbors, using `eos_token` as a separator. The fist token + # will also always be an `eos_token`. + + # Args: + # data: The dataset to chunk and tokenize. + # tokenizer: The tokenizer to use. + # format: The format to return the dataset in, passed to `Dataset.with_format`. + # num_proc: The number of processes to use for tokenization. + # text_key: The key in the dataset to use as the text to tokenize. + # max_length: The maximum length of a batch of input ids. + # return_final_batch: Whether to return the final batch, which may be smaller + # than the others. + # load_from_cache_file: Whether to load from the cache file. + + # Returns: + # * The chunked and tokenized dataset. 
+ # * The ratio of nats to bits per byte see https://arxiv.org/pdf/2101.00027.pdf, + # section 3.1. + # """ + + # def _tokenize_fn(x: Dict[str, list]): + # chunk_size = min(tokenizer.model_max_length, max_length) # tokenizer max length is 1024 for gpt2 + # sep = tokenizer.eos_token or "<|endoftext|>" + # joined_text = sep.join([""] + x[text_key]) + # output = tokenizer( + # # Concatenate all the samples together, separated by the EOS token. + # joined_text, # start with an eos token + # max_length=chunk_size, + # return_attention_mask=False, + # return_overflowing_tokens=True, + # truncation=True, + # ) + + # if overflow := output.pop("overflowing_tokens", None): + # # Slow Tokenizers return unnested lists of ints + # assert isinstance(output["input_ids"][0], int) + + # # Chunk the overflow into batches of size `chunk_size` + # chunks = [output["input_ids"]] + [ + # overflow[i * chunk_size : (i + 1) * chunk_size] for i in range(math.ceil(len(overflow) / chunk_size)) + # ] + # output = {"input_ids": chunks} + + # total_tokens = sum(len(ids) for ids in output["input_ids"]) + # total_bytes = len(joined_text.encode("utf-8")) + + # if not return_final_batch: + # # We know that the last sample will almost always be less than the max + # # number of tokens, and we don't want to pad, so we just drop it. + # output = {k: v[:-1] for k, v in output.items()} + + # output_batch_size = len(output["input_ids"]) + + # if output_batch_size == 0: + # raise ValueError( + # "Not enough data to create a single batch complete batch." + # " Either allow the final batch to be returned," + # " or supply more data." + # ) + + # # We need to output this in order to compute the number of bits per byte + # div, rem = divmod(total_tokens, output_batch_size) + # output["length"] = [div] * output_batch_size + # output["length"][-1] += rem + + # div, rem = divmod(total_bytes, output_batch_size) + # output["bytes"] = [div] * output_batch_size + # output["bytes"][-1] += rem + + # return output + + # def get_columns_all_equal(dataset: Union[Dataset, DatasetDict]) -> List[str]: + # """Get a single list of columns in a `Dataset` or `DatasetDict`. + + # We assert the columms are the same across splits if it's a `DatasetDict`. + + # Args: + # dataset: The dataset to get the columns from. + + # Returns: + # A list of columns. 
+ # """ + # if isinstance(dataset, DatasetDict): + # cols_by_split = dataset.column_names.values() + # columns = next(iter(cols_by_split)) + # if not all(cols == columns for cols in cols_by_split): + # raise ValueError("All splits must have the same columns") + + # return columns + + # return dataset.column_names + + # # take 2048 texts from Pile, concat them together by appending eos token, tokenize them, + # # chunck them into chunk_size=256 tokens, and return the tokenized texts as a much larger than 2048 batch of + # # equal length tokenized texts + + # data = data.map( + # _tokenize_fn, + # # Batching is important for ensuring that we don't waste tokens + # # since we always throw away the last element of the batch we + # # want to keep the batch size as large as possible + # batched=True, + # batch_size=2048, + # num_proc=num_proc, + # remove_columns=get_columns_all_equal(data), + # load_from_cache_file=load_from_cache_file, + # ) + # total_bytes: float = sum(data["bytes"]) + # total_tokens: float = sum(data["length"]) + # return data.with_format(format, columns=["input_ids"]), (total_tokens / total_bytes) / math.log(2) + + + + + # # End Nora's Code from https://github.com/AlignmentResearch/tuned-lens/blob/main/tuned_lens/data.py + + + + +if __name__ == "__main__": + + generate_activation_data_hparams= { + 'model_name': "gelu-2l", + 'model_device': 'cuda:1', + } + + generate_activation_data_class = GenerateActivationData(**generate_activation_data_hparams) + + # loading text dataset + data_path = os.path.join(os.getcwd(), "symbolic_probing", "data") + data_files = {"train": ["train/00.jsonl.zst", "train/01.jsonl.zst"], "validation": "val.jsonl.zst", "test": "test.jsonl.zst"} + dataset = generate_activation_data_class.make_sentence_dataset("monology/pile-uncopyrighted", data_files=data_files, data_path=data_path, max_lines=20_000, start_line=0) + print(dataset) + + # make activation dataset + dataset_folder = os.path.join(os.getcwd(), "symbolic_probing", "data", "activation_data") + generate_activation_data_class.make_activation_dataset(dataset, dataset_folder=dataset_folder) + + + + + + diff --git a/src/utils/instantiators.py b/src/utils/instantiators.py new file mode 100644 index 0000000..82b9278 --- /dev/null +++ b/src/utils/instantiators.py @@ -0,0 +1,56 @@ +from typing import List + +import hydra +from lightning import Callback +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig + +from src.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: + """Instantiates callbacks from config. + + :param callbacks_cfg: A DictConfig object containing callback configurations. + :return: A list of instantiated callbacks. + """ + callbacks: List[Callback] = [] + + if not callbacks_cfg: + log.warning("No callback configs found! Skipping..") + return callbacks + + if not isinstance(callbacks_cfg, DictConfig): + raise TypeError("Callbacks config must be a DictConfig!") + + for _, cb_conf in callbacks_cfg.items(): + if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: + log.info(f"Instantiating callback <{cb_conf._target_}>") + callbacks.append(hydra.utils.instantiate(cb_conf)) + + return callbacks + + +def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: + """Instantiates loggers from config. + + :param logger_cfg: A DictConfig object containing logger configurations. + :return: A list of instantiated loggers. 
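+
+    Example (a minimal sketch; the CSV logger entry is illustrative and not a
+    config shipped with this template):
+    ```
+    cfg = DictConfig({"csv": {"_target_": "lightning.pytorch.loggers.CSVLogger",
+                              "save_dir": "logs/"}})
+    loggers = instantiate_loggers(cfg)
+    ```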
+ """ + logger: List[Logger] = [] + + if not logger_cfg: + log.warning("No logger configs found! Skipping...") + return logger + + if not isinstance(logger_cfg, DictConfig): + raise TypeError("Logger config must be a DictConfig!") + + for _, lg_conf in logger_cfg.items(): + if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: + log.info(f"Instantiating logger <{lg_conf._target_}>") + logger.append(hydra.utils.instantiate(lg_conf)) + + return logger diff --git a/src/utils/logging_utils.py b/src/utils/logging_utils.py new file mode 100644 index 0000000..360abcd --- /dev/null +++ b/src/utils/logging_utils.py @@ -0,0 +1,57 @@ +from typing import Any, Dict + +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import OmegaConf + +from src.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def log_hyperparameters(object_dict: Dict[str, Any]) -> None: + """Controls which config parts are saved by Lightning loggers. + + Additionally saves: + - Number of model parameters + + :param object_dict: A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. + """ + hparams = {} + + cfg = OmegaConf.to_container(object_dict["cfg"]) + model = object_dict["model"] + trainer = object_dict["trainer"] + + if not trainer.logger: + log.warning("Logger not found! Skipping hyperparameter logging...") + return + + hparams["model"] = cfg["model"] + + # save number of model parameters + hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) + hparams["model/params/trainable"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + hparams["model/params/non_trainable"] = sum( + p.numel() for p in model.parameters() if not p.requires_grad + ) + + hparams["data"] = cfg["data"] + hparams["trainer"] = cfg["trainer"] + + hparams["callbacks"] = cfg.get("callbacks") + hparams["extras"] = cfg.get("extras") + + hparams["task_name"] = cfg.get("task_name") + hparams["tags"] = cfg.get("tags") + hparams["ckpt_path"] = cfg.get("ckpt_path") + hparams["seed"] = cfg.get("seed") + + # send hparams to all loggers + for logger in trainer.loggers: + logger.log_hyperparams(hparams) diff --git a/src/utils/pylogger.py b/src/utils/pylogger.py new file mode 100644 index 0000000..c4ee867 --- /dev/null +++ b/src/utils/pylogger.py @@ -0,0 +1,51 @@ +import logging +from typing import Mapping, Optional + +from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only + + +class RankedLogger(logging.LoggerAdapter): + """A multi-GPU-friendly python command line logger.""" + + def __init__( + self, + name: str = __name__, + rank_zero_only: bool = False, + extra: Optional[Mapping[str, object]] = None, + ) -> None: + """Initializes a multi-GPU-friendly python command line logger that logs on all processes + with their rank prefixed in the log message. + + :param name: The name of the logger. Default is ``__name__``. + :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. + :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. 
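+
+        Example (a sketch; assumes Lightning has already set `rank_zero_only.rank`):
+        ```
+        log = RankedLogger(__name__, rank_zero_only=True)
+        log.info("Only the rank-zero process will emit this message.")
+        ```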
+ """ + logger = logging.getLogger(name) + super().__init__(logger=logger, extra=extra) + self.rank_zero_only = rank_zero_only + + def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None: + """Delegate a log call to the underlying logger, after prefixing its message with the rank + of the process it's being logged from. If `'rank'` is provided, then the log will only + occur on that rank/process. + + :param level: The level to log at. Look at `logging.__init__.py` for more information. + :param msg: The message to log. + :param rank: The rank to log at. + :param args: Additional args to pass to the underlying logging function. + :param kwargs: Any additional keyword args to pass to the underlying logging function. + """ + if self.isEnabledFor(level): + msg, kwargs = self.process(msg, kwargs) + current_rank = getattr(rank_zero_only, "rank", None) + if current_rank is None: + raise RuntimeError("The `rank_zero_only.rank` needs to be set before use") + msg = rank_prefixed_message(msg, current_rank) + if self.rank_zero_only: + if current_rank == 0: + self.logger.log(level, msg, *args, **kwargs) + else: + if rank is None: + self.logger.log(level, msg, *args, **kwargs) + elif current_rank == rank: + self.logger.log(level, msg, *args, **kwargs) diff --git a/src/utils/rich_utils.py b/src/utils/rich_utils.py new file mode 100644 index 0000000..aeec680 --- /dev/null +++ b/src/utils/rich_utils.py @@ -0,0 +1,99 @@ +from pathlib import Path +from typing import Sequence + +import rich +import rich.syntax +import rich.tree +from hydra.core.hydra_config import HydraConfig +from lightning_utilities.core.rank_zero import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict +from rich.prompt import Prompt + +from src.utils import pylogger + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +@rank_zero_only +def print_config_tree( + cfg: DictConfig, + print_order: Sequence[str] = ( + "data", + "model", + "callbacks", + "logger", + "trainer", + "paths", + "extras", + ), + resolve: bool = False, + save_to_file: bool = False, +) -> None: + """Prints the contents of a DictConfig as a tree structure using the Rich library. + + :param cfg: A DictConfig composed by Hydra. + :param print_order: Determines in what order config components are printed. Default is ``("data", "model", + "callbacks", "logger", "trainer", "paths", "extras")``. + :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. + :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + """ + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + queue = [] + + # add fields from `print_order` to queue + for field in print_order: + queue.append(field) if field in cfg else log.warning( + f"Field '{field}' not found in config. Skipping '{field}' config printing..." 
+ ) + + # add all the other fields to queue (not specified in `print_order`) + for field in cfg: + if field not in queue: + queue.append(field) + + # generate config tree from queue + for field in queue: + branch = tree.add(field, style=style, guide_style=style) + + config_group = cfg[field] + if isinstance(config_group, DictConfig): + branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) + else: + branch_content = str(config_group) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + # print config tree + rich.print(tree) + + # save config tree to file + if save_to_file: + with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: + rich.print(tree, file=file) + + +@rank_zero_only +def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: + """Prompts user to input tags from command line if no tags are provided in config. + + :param cfg: A DictConfig composed by Hydra. + :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + """ + if not cfg.get("tags"): + if "id" in HydraConfig().cfg.hydra.job: + raise ValueError("Specify tags before launching a multirun!") + + log.warning("No tags provided in config. Prompting user to input tags...") + tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + tags = [t.strip() for t in tags.split(",") if t != ""] + + with open_dict(cfg): + cfg.tags = tags + + log.info(f"Tags: {cfg.tags}") + + if save_to_file: + with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: + rich.print(cfg.tags, file=file) diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000..02b5576 --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,119 @@ +import warnings +from importlib.util import find_spec +from typing import Any, Callable, Dict, Optional, Tuple + +from omegaconf import DictConfig + +from src.utils import pylogger, rich_utils + +log = pylogger.RankedLogger(__name__, rank_zero_only=True) + + +def extras(cfg: DictConfig) -> None: + """Applies optional utilities before the task is started. + + Utilities: + - Ignoring python warnings + - Setting tags from command line + - Rich config printing + + :param cfg: A DictConfig object containing the config tree. + """ + # return if no `extras` config + if not cfg.get("extras"): + log.warning("Extras config not found! ") + return + + # disable python warnings + if cfg.extras.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # prompt user to input tags from command line if none are provided in the config + if cfg.extras.get("enforce_tags"): + log.info("Enforcing tags! ") + rich_utils.enforce_tags(cfg, save_to_file=True) + + # pretty print config tree using Rich library + if cfg.extras.get("print_config"): + log.info("Printing config tree with Rich! ") + rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) + + +def task_wrapper(task_func: Callable) -> Callable: + """Optional decorator that controls the failure behavior when executing the task function. + + This wrapper can be used to: + - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) + - save the exception to a `.log` file + - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) + - etc. (adjust depending on your needs) + + Example: + ``` + @utils.task_wrapper + def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + ... 
+ return metric_dict, object_dict + ``` + + :param task_func: The task function to be wrapped. + + :return: The wrapped task function. + """ + + def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: + # execute the task + try: + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs + except Exception as ex: + # save exception to `.log` file + log.exception("") + + # some hyperparameter combinations might be invalid or cause out-of-memory errors + # so when using hparam search plugins like Optuna, you might want to disable + # raising the below exception to avoid multirun failure + raise ex + + # things to always do after either success or exception + finally: + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # always close wandb run (even if exception occurs so multirun won't fail) + if find_spec("wandb"): # check if wandb is installed + import wandb + + if wandb.run: + log.info("Closing wandb!") + wandb.finish() + + return metric_dict, object_dict + + return wrap + + +def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: + """Safely retrieves value of the metric logged in LightningModule. + + :param metric_dict: A dict containing metric values. + :param metric_name: If provided, the name of the metric to retrieve. + :return: If a metric name was provided, the value of the metric. + """ + if not metric_name: + log.info("Metric name is None! Skipping metric value retrieval...") + return None + + if metric_name not in metric_dict: + raise Exception( + f"Metric value not found! \n" + "Make sure metric name logged in LightningModule is correct!\n" + "Make sure `optimized_metric` name in `hparams_search` config is correct!" + ) + + metric_value = metric_dict[metric_name].item() + log.info(f"Retrieved metric value! <{metric_name}={metric_value}>") + + return metric_value diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b5dea33 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,107 @@ +"""This file prepares config fixtures for other tests.""" + +from pathlib import Path + +import pytest +import rootutils +from hydra import compose, initialize +from hydra.core.global_hydra import GlobalHydra +from omegaconf import DictConfig, open_dict + + +@pytest.fixture(scope="package") +def cfg_train_global() -> DictConfig: + """A pytest fixture for setting up a default Hydra DictConfig for training. + + :return: A DictConfig object containing a default Hydra configuration for training. + """ + with initialize(version_base="1.3", config_path="../configs"): + cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[]) + + # set defaults for all tests + with open_dict(cfg): + cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root")) + cfg.trainer.max_epochs = 1 + cfg.trainer.limit_train_batches = 0.01 + cfg.trainer.limit_val_batches = 0.1 + cfg.trainer.limit_test_batches = 0.1 + cfg.trainer.accelerator = "cpu" + cfg.trainer.devices = 1 + cfg.data.num_workers = 0 + cfg.data.pin_memory = False + cfg.extras.print_config = False + cfg.extras.enforce_tags = False + cfg.logger = None + + return cfg + + +@pytest.fixture(scope="package") +def cfg_eval_global() -> DictConfig: + """A pytest fixture for setting up a default Hydra DictConfig for evaluation. 
+ + :return: A DictConfig containing a default Hydra configuration for evaluation. + """ + with initialize(version_base="1.3", config_path="../configs"): + cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=["ckpt_path=."]) + + # set defaults for all tests + with open_dict(cfg): + cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root")) + cfg.trainer.max_epochs = 1 + cfg.trainer.limit_test_batches = 0.1 + cfg.trainer.accelerator = "cpu" + cfg.trainer.devices = 1 + cfg.data.num_workers = 0 + cfg.data.pin_memory = False + cfg.extras.print_config = False + cfg.extras.enforce_tags = False + cfg.logger = None + + return cfg + + +@pytest.fixture(scope="function") +def cfg_train(cfg_train_global: DictConfig, tmp_path: Path) -> DictConfig: + """A pytest fixture built on top of the `cfg_train_global()` fixture, which accepts a temporary + logging path `tmp_path` for generating a temporary logging path. + + This is called by each test which uses the `cfg_train` arg. Each test generates its own temporary logging path. + + :param cfg_train_global: The input DictConfig object to be modified. + :param tmp_path: The temporary logging path. + + :return: A DictConfig with updated output and log directories corresponding to `tmp_path`. + """ + cfg = cfg_train_global.copy() + + with open_dict(cfg): + cfg.paths.output_dir = str(tmp_path) + cfg.paths.log_dir = str(tmp_path) + + yield cfg + + GlobalHydra.instance().clear() + + +@pytest.fixture(scope="function") +def cfg_eval(cfg_eval_global: DictConfig, tmp_path: Path) -> DictConfig: + """A pytest fixture built on top of the `cfg_eval_global()` fixture, which accepts a temporary + logging path `tmp_path` for generating a temporary logging path. + + This is called by each test which uses the `cfg_eval` arg. Each test generates its own temporary logging path. + + :param cfg_train_global: The input DictConfig object to be modified. + :param tmp_path: The temporary logging path. + + :return: A DictConfig with updated output and log directories corresponding to `tmp_path`. + """ + cfg = cfg_eval_global.copy() + + with open_dict(cfg): + cfg.paths.output_dir = str(tmp_path) + cfg.paths.log_dir = str(tmp_path) + + yield cfg + + GlobalHydra.instance().clear() diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/package_available.py b/tests/helpers/package_available.py new file mode 100644 index 0000000..0afdba8 --- /dev/null +++ b/tests/helpers/package_available.py @@ -0,0 +1,32 @@ +import platform + +import pkg_resources +from lightning.fabric.accelerators import TPUAccelerator + + +def _package_available(package_name: str) -> bool: + """Check if a package is available in your environment. + + :param package_name: The name of the package to be checked. + + :return: `True` if the package is available. `False` otherwise. 
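+
+    Example (a minimal sketch):
+    ```
+    if _package_available("wandb"):
+        import wandb
+    ```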
+ """ + try: + return pkg_resources.require(package_name) is not None + except pkg_resources.DistributionNotFound: + return False + + +_TPU_AVAILABLE = TPUAccelerator.is_available() + +_IS_WINDOWS = platform.system() == "Windows" + +_SH_AVAILABLE = not _IS_WINDOWS and _package_available("sh") + +_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _package_available("deepspeed") +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _package_available("fairscale") + +_WANDB_AVAILABLE = _package_available("wandb") +_NEPTUNE_AVAILABLE = _package_available("neptune") +_COMET_AVAILABLE = _package_available("comet_ml") +_MLFLOW_AVAILABLE = _package_available("mlflow") diff --git a/tests/helpers/run_if.py b/tests/helpers/run_if.py new file mode 100644 index 0000000..9703af4 --- /dev/null +++ b/tests/helpers/run_if.py @@ -0,0 +1,142 @@ +"""Adapted from: + +https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py +""" + +import sys +from typing import Any, Dict, Optional + +import pytest +import torch +from packaging.version import Version +from pkg_resources import get_distribution +from pytest import MarkDecorator + +from tests.helpers.package_available import ( + _COMET_AVAILABLE, + _DEEPSPEED_AVAILABLE, + _FAIRSCALE_AVAILABLE, + _IS_WINDOWS, + _MLFLOW_AVAILABLE, + _NEPTUNE_AVAILABLE, + _SH_AVAILABLE, + _TPU_AVAILABLE, + _WANDB_AVAILABLE, +) + + +class RunIf: + """RunIf wrapper for conditional skipping of tests. + + Fully compatible with `@pytest.mark`. + + Example: + + ```python + @RunIf(min_torch="1.8") + @pytest.mark.parametrize("arg1", [1.0, 2.0]) + def test_wrapper(arg1): + assert arg1 > 0 + ``` + """ + + def __new__( + cls, + min_gpus: int = 0, + min_torch: Optional[str] = None, + max_torch: Optional[str] = None, + min_python: Optional[str] = None, + skip_windows: bool = False, + sh: bool = False, + tpu: bool = False, + fairscale: bool = False, + deepspeed: bool = False, + wandb: bool = False, + neptune: bool = False, + comet: bool = False, + mlflow: bool = False, + **kwargs: Dict[Any, Any], + ) -> MarkDecorator: + """Creates a new `@RunIf` `MarkDecorator` decorator. + + :param min_gpus: Min number of GPUs required to run test. + :param min_torch: Minimum pytorch version to run test. + :param max_torch: Maximum pytorch version to run test. + :param min_python: Minimum python version required to run test. + :param skip_windows: Skip test for Windows platform. + :param tpu: If TPU is available. + :param sh: If `sh` module is required to run the test. + :param fairscale: If `fairscale` module is required to run the test. + :param deepspeed: If `deepspeed` module is required to run the test. + :param wandb: If `wandb` module is required to run the test. + :param neptune: If `neptune` module is required to run the test. + :param comet: If `comet` module is required to run the test. + :param mlflow: If `mlflow` module is required to run the test. + :param kwargs: Native `pytest.mark.skipif` keyword arguments. 
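+
+        Example (a sketch combining several conditions; the test is skipped unless
+        a GPU and the `wandb` package are both available):
+        ```python
+        @RunIf(min_gpus=1, wandb=True)
+        def test_gpu_wandb_logging():
+            ...
+        ```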
+ """ + conditions = [] + reasons = [] + + if min_gpus: + conditions.append(torch.cuda.device_count() < min_gpus) + reasons.append(f"GPUs>={min_gpus}") + + if min_torch: + torch_version = get_distribution("torch").version + conditions.append(Version(torch_version) < Version(min_torch)) + reasons.append(f"torch>={min_torch}") + + if max_torch: + torch_version = get_distribution("torch").version + conditions.append(Version(torch_version) >= Version(max_torch)) + reasons.append(f"torch<{max_torch}") + + if min_python: + py_version = ( + f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" + ) + conditions.append(Version(py_version) < Version(min_python)) + reasons.append(f"python>={min_python}") + + if skip_windows: + conditions.append(_IS_WINDOWS) + reasons.append("does not run on Windows") + + if tpu: + conditions.append(not _TPU_AVAILABLE) + reasons.append("TPU") + + if sh: + conditions.append(not _SH_AVAILABLE) + reasons.append("sh") + + if fairscale: + conditions.append(not _FAIRSCALE_AVAILABLE) + reasons.append("fairscale") + + if deepspeed: + conditions.append(not _DEEPSPEED_AVAILABLE) + reasons.append("deepspeed") + + if wandb: + conditions.append(not _WANDB_AVAILABLE) + reasons.append("wandb") + + if neptune: + conditions.append(not _NEPTUNE_AVAILABLE) + reasons.append("neptune") + + if comet: + conditions.append(not _COMET_AVAILABLE) + reasons.append("comet") + + if mlflow: + conditions.append(not _MLFLOW_AVAILABLE) + reasons.append("mlflow") + + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] + return pytest.mark.skipif( + condition=any(conditions), + reason=f"Requires: [{' + '.join(reasons)}]", + **kwargs, + ) diff --git a/tests/helpers/run_sh_command.py b/tests/helpers/run_sh_command.py new file mode 100644 index 0000000..fdd2ed6 --- /dev/null +++ b/tests/helpers/run_sh_command.py @@ -0,0 +1,22 @@ +from typing import List + +import pytest + +from tests.helpers.package_available import _SH_AVAILABLE + +if _SH_AVAILABLE: + import sh + + +def run_sh_command(command: List[str]) -> None: + """Default method for executing shell commands with `pytest` and `sh` package. + + :param command: A list of shell commands as strings. + """ + msg = None + try: + sh.python(command) + except sh.ErrorReturnCode as e: + msg = e.stderr.decode() + if msg: + pytest.fail(msg=msg) diff --git a/tests/test_configs.py b/tests/test_configs.py new file mode 100644 index 0000000..d7041dc --- /dev/null +++ b/tests/test_configs.py @@ -0,0 +1,37 @@ +import hydra +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig + + +def test_train_config(cfg_train: DictConfig) -> None: + """Tests the training configuration provided by the `cfg_train` pytest fixture. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + assert cfg_train + assert cfg_train.data + assert cfg_train.model + assert cfg_train.trainer + + HydraConfig().set_config(cfg_train) + + hydra.utils.instantiate(cfg_train.data) + hydra.utils.instantiate(cfg_train.model) + hydra.utils.instantiate(cfg_train.trainer) + + +def test_eval_config(cfg_eval: DictConfig) -> None: + """Tests the evaluation configuration provided by the `cfg_eval` pytest fixture. + + :param cfg_train: A DictConfig containing a valid evaluation configuration. 
+ """ + assert cfg_eval + assert cfg_eval.data + assert cfg_eval.model + assert cfg_eval.trainer + + HydraConfig().set_config(cfg_eval) + + hydra.utils.instantiate(cfg_eval.data) + hydra.utils.instantiate(cfg_eval.model) + hydra.utils.instantiate(cfg_eval.trainer) diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py new file mode 100644 index 0000000..901f3d6 --- /dev/null +++ b/tests/test_datamodules.py @@ -0,0 +1,38 @@ +from pathlib import Path + +import pytest +import torch + +from src.data.mnist_datamodule import MNISTDataModule + + +@pytest.mark.parametrize("batch_size", [32, 128]) +def test_mnist_datamodule(batch_size: int) -> None: + """Tests `MNISTDataModule` to verify that it can be downloaded correctly, that the necessary + attributes were created (e.g., the dataloader objects), and that dtypes and batch sizes + correctly match. + + :param batch_size: Batch size of the data to be loaded by the dataloader. + """ + data_dir = "data/" + + dm = MNISTDataModule(data_dir=data_dir, batch_size=batch_size) + dm.prepare_data() + + assert not dm.data_train and not dm.data_val and not dm.data_test + assert Path(data_dir, "MNIST").exists() + assert Path(data_dir, "MNIST", "raw").exists() + + dm.setup() + assert dm.data_train and dm.data_val and dm.data_test + assert dm.train_dataloader() and dm.val_dataloader() and dm.test_dataloader() + + num_datapoints = len(dm.data_train) + len(dm.data_val) + len(dm.data_test) + assert num_datapoints == 70_000 + + batch = next(iter(dm.train_dataloader())) + x, y = batch + assert len(x) == batch_size + assert len(y) == batch_size + assert x.dtype == torch.float32 + assert y.dtype == torch.int64 diff --git a/tests/test_eval.py b/tests/test_eval.py new file mode 100644 index 0000000..423c9d2 --- /dev/null +++ b/tests/test_eval.py @@ -0,0 +1,39 @@ +import os +from pathlib import Path + +import pytest +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, open_dict + +from src.eval import evaluate +from src.train import train + + +@pytest.mark.slow +def test_train_eval(tmp_path: Path, cfg_train: DictConfig, cfg_eval: DictConfig) -> None: + """Tests training and evaluation by training for 1 epoch with `train.py` then evaluating with + `eval.py`. + + :param tmp_path: The temporary logging path. + :param cfg_train: A DictConfig containing a valid training configuration. + :param cfg_eval: A DictConfig containing a valid evaluation configuration. 
+ """ + assert str(tmp_path) == cfg_train.paths.output_dir == cfg_eval.paths.output_dir + + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + cfg_train.test = True + + HydraConfig().set_config(cfg_train) + train_metric_dict, _ = train(cfg_train) + + assert "last.ckpt" in os.listdir(tmp_path / "checkpoints") + + with open_dict(cfg_eval): + cfg_eval.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + + HydraConfig().set_config(cfg_eval) + test_metric_dict, _ = evaluate(cfg_eval) + + assert test_metric_dict["test/acc"] > 0.0 + assert abs(train_metric_dict["test/acc"].item() - test_metric_dict["test/acc"].item()) < 0.001 diff --git a/tests/test_sweeps.py b/tests/test_sweeps.py new file mode 100644 index 0000000..7856b15 --- /dev/null +++ b/tests/test_sweeps.py @@ -0,0 +1,107 @@ +from pathlib import Path + +import pytest + +from tests.helpers.run_if import RunIf +from tests.helpers.run_sh_command import run_sh_command + +startfile = "src/train.py" +overrides = ["logger=[]"] + + +@RunIf(sh=True) +@pytest.mark.slow +def test_experiments(tmp_path: Path) -> None: + """Test running all available experiment configs with `fast_dev_run=True.` + + :param tmp_path: The temporary logging path. + """ + command = [ + startfile, + "-m", + "experiment=glob(*)", + "hydra.sweep.dir=" + str(tmp_path), + "++trainer.fast_dev_run=true", + ] + overrides + run_sh_command(command) + + +@RunIf(sh=True) +@pytest.mark.slow +def test_hydra_sweep(tmp_path: Path) -> None: + """Test default hydra sweep. + + :param tmp_path: The temporary logging path. + """ + command = [ + startfile, + "-m", + "hydra.sweep.dir=" + str(tmp_path), + "model.optimizer.lr=0.005,0.01", + "++trainer.fast_dev_run=true", + ] + overrides + + run_sh_command(command) + + +@RunIf(sh=True) +@pytest.mark.slow +def test_hydra_sweep_ddp_sim(tmp_path: Path) -> None: + """Test default hydra sweep with ddp sim. + + :param tmp_path: The temporary logging path. + """ + command = [ + startfile, + "-m", + "hydra.sweep.dir=" + str(tmp_path), + "trainer=ddp_sim", + "trainer.max_epochs=3", + "+trainer.limit_train_batches=0.01", + "+trainer.limit_val_batches=0.1", + "+trainer.limit_test_batches=0.1", + "model.optimizer.lr=0.005,0.01,0.02", + ] + overrides + run_sh_command(command) + + +@RunIf(sh=True) +@pytest.mark.slow +def test_optuna_sweep(tmp_path: Path) -> None: + """Test Optuna hyperparam sweeping. + + :param tmp_path: The temporary logging path. + """ + command = [ + startfile, + "-m", + "hparams_search=mnist_optuna", + "hydra.sweep.dir=" + str(tmp_path), + "hydra.sweeper.n_trials=10", + "hydra.sweeper.sampler.n_startup_trials=5", + "++trainer.fast_dev_run=true", + ] + overrides + run_sh_command(command) + + +@RunIf(wandb=True, sh=True) +@pytest.mark.slow +def test_optuna_sweep_ddp_sim_wandb(tmp_path: Path) -> None: + """Test Optuna sweep with wandb logging and ddp sim. + + :param tmp_path: The temporary logging path. 
+ """ + command = [ + startfile, + "-m", + "hparams_search=mnist_optuna", + "hydra.sweep.dir=" + str(tmp_path), + "hydra.sweeper.n_trials=5", + "trainer=ddp_sim", + "trainer.max_epochs=3", + "+trainer.limit_train_batches=0.01", + "+trainer.limit_val_batches=0.1", + "+trainer.limit_test_batches=0.1", + "logger=wandb", + ] + run_sh_command(command) diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 0000000..c13ae02 --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,108 @@ +import os +from pathlib import Path + +import pytest +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, open_dict + +from src.train import train +from tests.helpers.run_if import RunIf + + +def test_train_fast_dev_run(cfg_train: DictConfig) -> None: + """Run for 1 train, val and test step. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.fast_dev_run = True + cfg_train.trainer.accelerator = "cpu" + train(cfg_train) + + +@RunIf(min_gpus=1) +def test_train_fast_dev_run_gpu(cfg_train: DictConfig) -> None: + """Run for 1 train, val and test step on GPU. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.fast_dev_run = True + cfg_train.trainer.accelerator = "gpu" + train(cfg_train) + + +@RunIf(min_gpus=1) +@pytest.mark.slow +def test_train_epoch_gpu_amp(cfg_train: DictConfig) -> None: + """Train 1 epoch on GPU with mixed-precision. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + cfg_train.trainer.accelerator = "gpu" + cfg_train.trainer.precision = 16 + train(cfg_train) + + +@pytest.mark.slow +def test_train_epoch_double_val_loop(cfg_train: DictConfig) -> None: + """Train 1 epoch with validation loop twice per epoch. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + cfg_train.trainer.val_check_interval = 0.5 + train(cfg_train) + + +@pytest.mark.slow +def test_train_ddp_sim(cfg_train: DictConfig) -> None: + """Simulate DDP (Distributed Data Parallel) on 2 CPU processes. + + :param cfg_train: A DictConfig containing a valid training configuration. + """ + HydraConfig().set_config(cfg_train) + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 2 + cfg_train.trainer.accelerator = "cpu" + cfg_train.trainer.devices = 2 + cfg_train.trainer.strategy = "ddp_spawn" + train(cfg_train) + + +@pytest.mark.slow +def test_train_resume(tmp_path: Path, cfg_train: DictConfig) -> None: + """Run 1 epoch, finish, and resume for another epoch. + + :param tmp_path: The temporary logging path. + :param cfg_train: A DictConfig containing a valid training configuration. 
+ """ + with open_dict(cfg_train): + cfg_train.trainer.max_epochs = 1 + + HydraConfig().set_config(cfg_train) + metric_dict_1, _ = train(cfg_train) + + files = os.listdir(tmp_path / "checkpoints") + assert "last.ckpt" in files + assert "epoch_000.ckpt" in files + + with open_dict(cfg_train): + cfg_train.ckpt_path = str(tmp_path / "checkpoints" / "last.ckpt") + cfg_train.trainer.max_epochs = 2 + + metric_dict_2, _ = train(cfg_train) + + files = os.listdir(tmp_path / "checkpoints") + assert "epoch_001.ckpt" in files + assert "epoch_002.ckpt" not in files + + assert metric_dict_1["train/acc"] < metric_dict_2["train/acc"] + assert metric_dict_1["val/acc"] < metric_dict_2["val/acc"]
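+
+
+# A small additional check (a sketch, not part of the template's original test suite):
+# `get_metric_value` returns `None` and skips retrieval when no metric name is given,
+# which is the code path taken when `optimized_metric` is absent from the config.
+def test_get_metric_value_without_name() -> None:
+    from src.utils import get_metric_value
+
+    assert get_metric_value(metric_dict={}, metric_name=None) is None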