feat(examples): release an example implementation of T-V reward model. #2

Merged: 4 commits, Jul 4, 2024
3 changes: 1 addition & 2 deletions .flake8
Original file line number Diff line number Diff line change
@@ -9,11 +9,10 @@ ignore =
# W503: line break before binary operator
# W504: line break after binary operator
# format by black
E203,E241,E704,W503,W504,
E203,E241,E704,W503,W504,E501,W505,
# E501: line too long
# W505: doc line too long
# too long docstring due to long example blocks
E501,W505,
per-file-ignores =
# F401: module imported but unused
# intentionally unused imports
1 change: 0 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -24,7 +24,6 @@ What types of changes does your code introduce? Put an `x` in all the boxes that
Go over all the following points, and put an `x` in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

<!-- - [ ] I have read the [CONTRIBUTION](https://safe-sora.readthedocs.io/en/latest/developer/contributing.html) guide. (**required**) -->
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly. (*required for a bug fix or a new feature*)
- [ ] I have updated the documentation accordingly.
16 changes: 0 additions & 16 deletions .github/workflows/lint.yml
@@ -44,22 +44,6 @@ jobs:
run: |
make pre-commit

- name: ruff
run: |
make ruff

- name: flake8
run: |
make flake8

- name: pylint
run: |
make pylint

- name: isort and black
run: |
make py-format

- name: addlicense
run: |
make addlicense
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,9 +1,10 @@
##### Project Specification #####
dataset/
outputs/
wandb/
test/
data/
checkpoints/
cache_dir

##### Python.gitignore #####
# Byte-compiled / optimized / DLL files
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -33,6 +33,11 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
exclude: |
(?x)(
^safe_sora/models/multimodal_encoder/|
^safe_sora/models/multimodal_projector/
)
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
@@ -68,6 +73,8 @@ repos:
- repo: local
hooks:
- id: pylint
args:
- --disable=R0801
name: pylint
entry: pylint
language: system
@@ -78,5 +85,8 @@
^examples/|
^tests/|
^setup.py$|
^safe_sora/models/multimodal_encoder/|
^safe_sora/models/multimodal_projector/|
^safe_sora/models/video_llava.py|
^docs/source/conf.py$
)
4 changes: 2 additions & 2 deletions Makefile
@@ -3,7 +3,7 @@ PROJECT_NAME = safe-sora
COPYRIGHT = "PKU-Alignment Team. All Rights Reserved."
PROJECT_PATH = safe_sora
SHELL = /bin/bash
SOURCE_FOLDERS = $(PROJECT_PATH) examples tests docs
SOURCE_FOLDERS = $(PROJECT_PATH) examples docs
PYTHON_FILES = $(shell find $(SOURCE_FOLDERS) -type f -name "*.py" -o -name "*.pyi")
COMMIT_HASH = $(shell git log -1 --format=%h)
PATH := $(HOME)/go/bin:$(PATH)
@@ -130,7 +130,7 @@ pre-commit: pre-commit-install
# Documentation

addlicense: addlicense-install
addlicense -c $(COPYRIGHT) -ignore tests/coverage.xml -l apache -y 2022-$(shell date +"%Y") -check $(SOURCE_FOLDERS)
addlicense -c $(COPYRIGHT) -ignore **/multimodal_encoder/** -ignore **/multimodal_projector/** -l apache -y 2022-$(shell date +"%Y") -check $(SOURCE_FOLDERS)

docstyle: docs-install
make -C docs clean
56 changes: 56 additions & 0 deletions conda-recipe.yaml
@@ -0,0 +1,56 @@
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Create virtual environment with command:
#
# $ CONDA_OVERRIDE_CUDA=11.8 conda env create --file conda-recipe.yaml
#

name: safe-sora
channels:
- huggingface
- pytorch
- nvidia/label/cuda-12.1.0
- defaults
- conda-forge
dependencies:
- python = 3.11
- pip

- pytorch::pytorch >= 2.0
- pytorch::pytorch-mutex =*=*cuda*
- pytorch::torchvision
- transformers >= 4.42
- datasets
- tokenizers >= 0.19
- sentencepiece
- tensorboard
- wandb
- pip:
- accelerate
- deepspeed
- decord
- opencv-python

- nvidia/label/cuda-12.1.0::cuda-toolkit = 12.1

- matplotlib-base
- rich
- tqdm
- typing-extensions
- bitsandbytes
- av
- einops
- peft
Binary file added docs/images/win_rate.png
63 changes: 63 additions & 0 deletions examples/README.md
@@ -0,0 +1,63 @@
<!-- markdownlint-disable html -->

# Preference Model

In this directory, we provide an example implementation of training a preference predictor reward model on our dataset.

## Preference Modeling

To model human preferences, it is common to use a preference predictor adhering to the Bradley-Terry model. A preference pair is denoted $y_w \succ y_l \mid x$, where $y_w$ is the video preferred over $y_l$ for the prompt $x$.
The log-likelihood loss used to train a parameterized predictor $R_\phi$ on the dataset $\mathcal{D}$ is:

$$\mathcal{L} (\phi; \mathcal{D}) = -\mathbb E_{{(x,y_w,y_l)\sim \mathcal{D}}} \left[\log \sigma (R_{\phi} (y_w,x) - R_{\phi} (y_l,x))\right]$$
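
This loss can be sketched in a few lines of plain Python (a minimal illustration, not the repository's training code; the reward scores here are hypothetical scalars standing in for $R_\phi(y, x)$):

```python
import math

def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))

def preference_loss(rewards_w: list, rewards_l: list) -> float:
    """Bradley-Terry negative log-likelihood over a batch of pairs:
    -mean(log sigmoid(R(y_w, x) - R(y_l, x)))."""
    losses = [-log_sigmoid(w - l) for w, l in zip(rewards_w, rewards_l)]
    return sum(losses) / len(losses)

# Hypothetical scalar scores for three preference pairs.
loss = preference_loss([1.2, 0.8, 2.0], [0.3, 1.0, 1.5])
```

Minimizing this loss pushes the predictor to assign a higher score to the preferred video of each pair.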


Leveraging a multimodal architecture adapted from [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA) and training on preference data from the [SafeSora Dataset](https://huggingface.co/datasets/PKU-Alignment/SafeSora), we have developed a T-V (text-video) reward model.
The language head of the vision-language model is replaced with a score regression head, which predicts the preference score of a video given the prompt.

This model translates abstract human values into quantifiable and optimizable scalar metrics.
Consequently, the reward model can partially replace human evaluators in assessing outputs from video generation models and act as a supervisory signal to enhance the performance of these models.
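The head swap described above can be sketched as follows in PyTorch (an illustrative stand-in, not the repository's actual module; the class name, hidden size, and pooling choice are assumptions):

```python
import torch
import torch.nn as nn

class ScoreHead(nn.Module):
    """Hypothetical scalar-score head replacing the language head of a
    vision-language backbone. Maps hidden states to one preference score."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.score = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        # Score the final token's hidden state: (batch, seq, hidden) -> (batch,)
        return self.score(last_hidden_state[:, -1, :]).squeeze(-1)

head = ScoreHead(hidden_size=4096)
hidden = torch.randn(2, 16, 4096)  # dummy batch of backbone hidden states
scores = head(hidden)              # one scalar score per example
```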

## Alignment Evaluation of Different Models

The SafeSora dataset includes annotations across multiple dimensions of human preference. We have developed several distinct models that focus on different aspects of human preference, such as helpfulness, harmlessness, and four specific sub-dimensions of helpfulness. Our models achieve an agreement ratio of 65.29% for predicting helpfulness preference and 72.41% for predicting harmlessness preference when compared with crowdworker assessments.

Furthermore, we utilize these models to evaluate four open-source models on our [Evaluation Dataset](https://huggingface.co/datasets/PKU-Alignment/SafeSora-Eval). The win-rate relationships among these models, assessed across the two alignment dimensions, are depicted in the figure below.

<div align="center">
<img src="../docs/images/win_rate.png" alt="win_rate" width="85%"/>
</div>
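Given per-prompt reward scores for two models, the pairwise win rate can be computed with a small helper like the one below (an illustrative sketch with made-up scores, counting ties as half a win; the actual evaluation protocol is described in the paper):

```python
def win_rate(scores_a: list, scores_b: list) -> float:
    """Fraction of prompts on which model A's video outscores model B's,
    counting ties as half a win."""
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0
               for a, b in zip(scores_a, scores_b))
    return wins / len(scores_a)

# Hypothetical reward-model scores for the same four prompts.
rate = win_rate([0.9, 0.2, 0.7, 0.4], [0.5, 0.6, 0.1, 0.4])
```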

## Training

First, you need to [download our dataset](../README.md#data-access) locally and prepare the training environment using:

```bash
conda env create -f conda-recipe.yaml # mamba env create -f conda-recipe.yaml
conda activate safe-sora
```

Then, you need to download the Video-LLaVA model and the MM-MLP adapter from the Hugging Face model hub. For example, you can download them using the following commands:

```bash
huggingface-cli download --resume-download LanguageBind/Video-LLaVA-7B --local-dir ./LanguageBind/Video-LLaVA-7B
huggingface-cli download --resume-download LanguageBind/Video-LLaVA-Pretrain-7B --local-dir ./LanguageBind/Video-LLaVA-Pretrain-7B
```

Finally, run the following script to train the reward model on the SafeSora dataset:

```bash
bash examples/scripts/finetune_reward_model.sh \
--model_name_or_path <your-model-name-or-checkpoint-path> \
--mm_mlp_adapter_path <your-mm_mlp_adapter_path> \
--dimension <the-target-dimension-to-train> \
--output_dir examples/outputs/reward-model
```

where `<your-model-name-or-checkpoint-path>` is the name of the Video-LLaVA model or the path to the checkpoint directory, `<your-mm_mlp_adapter_path>` is the path to the `mm_projector.bin` file, and `<the-target-dimension-to-train>` is the preference dimension that the reward model will predict.

**NOTE:** The `--dimension` parameter specifies the preference dimension that the reward model will predict. The SafeSora dataset currently supports the following dimensions: `helpfulness`, `harmlessness`, `instruction_following`, `correctness`, `informativeness`, and `aesthetics`. For detailed information about these dimensions, please refer to our [paper](https://arxiv.org/abs/2406.14477).

## Acknowledgements

This implementation benefits from [DeepSpeed](https://github.com/microsoft/DeepSpeed), [Transformers](https://github.com/huggingface/transformers), [LLaVA](https://github.com/haotian-liu/LLaVA), and [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA). Thanks for their wonderful works and their efforts for democratizing the LLM research.