Commit 04b9ec9 (parent a34b690): 291 changed files with 128,183 additions and 4 deletions.

@@ -1,15 +1,141 @@
# See or Guess: Counterfactually Regularized Image Captioning
[ACM MM 2024]

This is the official repo of the paper **See or Guess: Counterfactually Regularized Image Captioning**, accepted by ACM MM 2024.

> This repository includes the original implementation of the paper **See or Guess: Counterfactually Regularized Image Captioning** (ACM MM 2024) by Qian Cao et al.

# Abstract
In this work, we present a generic image captioning framework that employs causal inference to make existing models more capable of interventional tasks and counterfactually explainable.
<details>
<summary> More (Click me) </summary>
Our approach includes two variants leveraging either the total effect or the natural direct effect.
Integrating them into the training process enables models to handle counterfactual scenarios, increasing their generalizability.
Extensive experiments on various datasets show that our method effectively reduces hallucinations and improves the model's faithfulness to images, demonstrating high portability across both small-scale and large-scale image-to-text models.
</details><br>

If you find our work useful, please cite the paper:
```
@inproceedings{cao2024see,
  title={See or Guess: Counterfactually Regularized Image Captioning},
  author={Cao, Qian and Chen, Xu and Song, Ruihua and Wang, Xiting and Huang, Xinting and Ren, Yuchen},
  booktitle={ACM Multimedia 2024}
}
```

# Content
0. [Before Start](#before-start)
1. [Setup](#setup)
2. [Prepare the Data](#prepare-the-data)
3. [Usage](#usage)
4. [Contact](#contact)

# Before Start
- Our method can be applied to many small- or large-scale image-to-text models. We provide an implementation based on BLIP2 in this repo.
- The implementation of BLIP2 is based on the original one in [LAVIS](https://github.com/salesforce/LAVIS). You can find more details at [LAVIS/BLIP2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2).

# Setup
Clone the repo and create a new virtual environment:
```
git clone https://github.com/Aman-4-Real/See-or-Guess
cd See-or-Guess/
conda create -n cfric python=3.9
conda activate cfric
```
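Then install the dependencies pinned in the `requirements.txt` added by this commit (shown further down this page), e.g. with `pip install -r requirements.txt`.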

# Prepare the Data
- Please organize your data in the following format:
```
{
  'img_path': 'val2014/COCO_val2014_000000522418.jpg', # relative image path in the dataset dir
  'caption': 'A woman wearing a net on her head cutting a cake. ', # string
  'phrases': [
    {
      'phrase': 'cake', # the noun phrase
      'boxes': [[x1, y1, x2, y2], ...], # bounding boxes of the noun phrase
      'first_word_index': 10 # the index of the first word of the noun phrase in the caption
    },
    ...
  ],
  'img_id': '522418' # unique image id
}
```
- Save each split as a `.pkl` file and put it in `YOUR_DATASET_DIR` (a serialization sketch follows this list).
- Change the `url` and `storage` fields in the file `src/lavis/configs/datasets/mscoco/defaults.yaml` to `YOUR_DATASET_DIR/{train,valid,test}.pkl`, respectively.
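As a concrete illustration, here is a minimal sketch of writing one split in this layout. The sample values are illustrative, and the assumption that each split file is a pickled list of such dicts is ours:
```python
# Minimal sketch: serialize one split in the layout shown above.
# Assumption: each split file is a pickled list of sample dicts.
import pickle

samples = [
    {
        "img_path": "val2014/COCO_val2014_000000522418.jpg",  # relative to the dataset dir
        "caption": "A woman wearing a net on her head cutting a cake. ",
        "phrases": [
            {
                "phrase": "cake",
                "boxes": [[52.0, 205.0, 447.0, 512.0]],  # illustrative coordinates
                "first_word_index": 10,
            }
        ],
        "img_id": "522418",
    },
]

with open("YOUR_DATASET_DIR/train.pkl", "wb") as f:
    pickle.dump(samples, f)
```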

# Usage

### Quick Check Out
For the key implementation, refer to [cfr_caption_datasets.py](src/lavis/datasets/datasets/cfr_caption_datasets.py), [modeling_opt.py](src/lavis/transformers_v4p34_local/models/opt/modeling_opt.py), and [CFRLoss.py](src/lavis/transformers_v4p34_local/models/opt/CFRLoss.py).

### Prepare pretrained checkpoint
Download a pretrained BLIP2 checkpoint (e.g., [blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)) to the `ckpt/` folder.
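As an optional sanity check that the environment and checkpoint wiring work, you can try loading a model through LAVIS's loader. This is a minimal sketch; the `name`/`model_type` strings follow upstream LAVIS registry names and are assumptions here:
```python
# Optional sanity check that the code and weights load.
# The registry names below are assumptions based on upstream LAVIS;
# this repo's local copy may register different ones.
import torch
from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="caption_coco_opt2.7b", is_eval=True, device=device
)
print(type(model).__name__)
```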

### Workflow
Generally speaking, our work builds on a trained image captioning model (the "initial model" in the paper). Follow these steps:

1. Prepare or train the initial model.
If you need to train BLIP2 on your dataset, you can follow the instructions in the [LAVIS](https://github.com/salesforce/LAVIS) repo, or you can run
```
cd src/run_scripts/
```
and then, using the config `src/lavis/run_cfgs/caption_coco_ft.yaml`, run
```
bash train_caption.sh
```
2. Use the trained initial model to generate counterfactual captions on the training set. Use the config `src/lavis/run_cfgs/caption_eval_gen_on_train.yaml` and run
```
bash eval_caption.sh
```
3. Use the total effect (TE) loss or the natural direct effect (NDE) loss to regularize the training.
Use the config `src/lavis/run_cfgs/caption_coco_ft_te0999.yaml` for TE:
```
bash train_caption.sh
```
and `src/lavis/run_cfgs/caption_coco_ft_nde0999.yaml` for NDE. Remember to set both `do_NDE` and `do_TE` to `True` when doing NDE training. Also set a suitable value for the hyperparameter α in the config.
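To make the regularization step concrete, here is a conceptual sketch of one way a blended factual/counterfactual objective can look. It is not the repo's `CFRLoss.py`: the combination rule, the tensor shapes, and the default α (suggested by the `te0999`/`nde0999` config names) are all assumptions.
```python
# Conceptual sketch only; see CFRLoss.py for the actual implementation.
import torch.nn.functional as F

def blended_caption_loss(factual_logits, cf_logits, labels, cf_labels, alpha=0.999):
    """Blend the loss on real captions (full image) with the loss on the
    counterfactual captions from step 2 (region-masked image).
    alpha=0.999 mirrors the `te0999` config name and is an assumption."""
    # Ordinary captioning cross-entropy on the factual pass: (B, T, V) -> (B*T, V).
    loss_factual = F.cross_entropy(
        factual_logits.flatten(0, 1), labels.flatten(), ignore_index=-100
    )
    # Cross-entropy of the masked-image pass against counterfactual captions.
    loss_cf = F.cross_entropy(
        cf_logits.flatten(0, 1), cf_labels.flatten(), ignore_index=-100
    )
    # Weighted blend: keep fitting the data while nudging the model to stay
    # consistent under the intervention.
    return alpha * loss_factual + (1.0 - alpha) * loss_cf
```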

### Evaluation
For evaluation, use the config `caption_coco_ft.yaml` (for factual image captioning) or `caption_coco_eval_mask_gen.yaml` (for counterfactual image captioning) with `src/run_scripts/eval_caption.sh`.
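In practice this means pointing `eval_caption.sh` at the chosen config before running `bash eval_caption.sh`, mirroring the workflow steps above; the exact variable to edit inside the script is not shown here.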

### Case Study

# Contact
For any questions, please feel free to reach me at caoqian4real[at]ruc.edu.cn.
@@ -0,0 +1,30 @@
contexttimer==0.3.3
decord==0.6.0
diffusers==0.16.0
einops==0.7.0
fairscale==0.4.4
ftfy
iopath
ipython
omegaconf
opencv-python-headless==4.5.5.64
opendatasets
packaging
pandas
plotly
pre-commit
pycocoevalcap
pycocotools
python-magic
scikit-image
sentencepiece
spacy
streamlit
tensorboard==2.16.2
timm==0.4.12
torch==2.2.1
torchvision==0.17.1
tqdm
transformers==4.34.0
webdataset
wheel
@@ -0,0 +1,92 @@
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import argparse | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.backends.cudnn as cudnn | ||
|
||
import lavis.tasks as tasks | ||
from lavis.common.config import Config | ||
from lavis.common.dist_utils import get_rank, init_distributed_mode | ||
from lavis.common.logger import setup_logger | ||
from lavis.common.optims import ( | ||
LinearWarmupCosineLRScheduler, | ||
LinearWarmupStepLRScheduler, | ||
) | ||
from lavis.common.utils import now | ||
|
||
# imports modules for registration | ||
from lavis.datasets.builders import * | ||
from lavis.models import * | ||
from lavis.processors import * | ||
from lavis.runners.runner_base import RunnerBase | ||
from lavis.tasks import * | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description="Training") | ||
|
||
parser.add_argument("--cfg-path", required=True, help="path to configuration file.") | ||
parser.add_argument( | ||
"--options", | ||
nargs="+", | ||
help="override some settings in the used config, the key-value pair " | ||
"in xxx=yyy format will be merged into config file (deprecate), " | ||
"change to --cfg-options instead.", | ||
) | ||
|
||
args = parser.parse_args() | ||
# if 'LOCAL_RANK' not in os.environ: | ||
# os.environ['LOCAL_RANK'] = str(args.local_rank) | ||
|
||
return args | ||
|
||
|
||
def setup_seeds(config): | ||
seed = config.run_cfg.seed + get_rank() | ||
|
||
random.seed(seed) | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
|
||
cudnn.benchmark = False | ||
cudnn.deterministic = True | ||
|
||
|
||
def main(): | ||
# allow auto-dl completes on main process without timeout when using NCCL backend. | ||
# os.environ["NCCL_BLOCKING_WAIT"] = "1" | ||
|
||
# set before init_distributed_mode() to ensure the same job_id shared across all ranks. | ||
job_id = now() | ||
|
||
cfg = Config(parse_args()) | ||
|
||
init_distributed_mode(cfg.run_cfg) | ||
|
||
setup_seeds(cfg) | ||
|
||
# set after init_distributed_mode() to only log on master. | ||
setup_logger() | ||
|
||
cfg.pretty_print() | ||
|
||
task = tasks.setup_task(cfg) | ||
datasets = task.build_datasets(cfg) | ||
model = task.build_model(cfg) | ||
|
||
runner = RunnerBase( | ||
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets | ||
) | ||
runner.evaluate(skip_reload=True) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
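For reference, this entry point is driven by a run config, i.e. something like `python evaluate.py --cfg-path <run_config.yaml>`; presumably `eval_caption.sh` wraps a call of this shape with one of the `src/lavis/run_cfgs/*.yaml` files.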
@@ -0,0 +1,31 @@
""" | ||
Copyright (c) 2022, salesforce.com, inc. | ||
All rights reserved. | ||
SPDX-License-Identifier: BSD-3-Clause | ||
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
from omegaconf import OmegaConf | ||
|
||
from lavis.common.registry import registry | ||
|
||
from lavis.datasets.builders import * | ||
from lavis.models import * | ||
from lavis.processors import * | ||
from lavis.tasks import * | ||
|
||
|
||
root_dir = os.path.dirname(os.path.abspath(__file__)) | ||
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) | ||
|
||
registry.register_path("library_root", root_dir) | ||
repo_root = os.path.join(root_dir, "..") | ||
registry.register_path("repo_root", repo_root) | ||
cache_root = os.path.join(repo_root, default_cfg.env.cache_root) | ||
registry.register_path("cache_root", cache_root) | ||
|
||
registry.register("MAX_INT", sys.maxsize) | ||
registry.register("SPLIT_NAMES", ["train", "val", "test"]) |