20240829: first version

Aman-4-Real committed Aug 28, 2024
1 parent a34b690 commit 04b9ec9
Showing 291 changed files with 128,183 additions and 4 deletions.
134 changes: 130 additions & 4 deletions README.md
@@ -1,15 +1,141 @@
# See or Guess: Counterfactually Regularized Image Captioning
[ACM MM 2024]

This is the official repo of the paper **See or Guess: Counterfactually Regularized Image Captioning** accepted by ACM MM 2024.
[![Python](https://img.shields.io/badge/Python-3.9-blue.svg)](#Python)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.2.1-green.svg)](#PyTorch)
[![Transformers](https://img.shields.io/badge/Transformers-4.34.0-orange.svg)](#Transformers)
![](https://img.shields.io/github/last-commit/Aman-4-Real/See-or-Guess?color=white)

The code will be available soon! (before the end of August)

Thanks for your attention and stars! :)
> This repository includes the original implementation of the paper **[See or Guess: Counterfactually Regularized Image Captioning]()** (ACM MM 2024) by Qian Cao et al.



# Abstract
In this work, we present a generic image captioning framework that employs causal inference to make existing models more capable of interventional tasks, and counterfactually explainable.
<details>
<summary> More (Click me) </summary>
Our approach includes two variants leveraging either total effect or natural direct effect.
Integrating them into the training process enables models to handle counterfactual scenarios, increasing their generalizability.
Extensive experiments on various datasets show that our method effectively reduces hallucinations and improves the model's faithfulness to images, demonstrating high portability across both small-scale and large-scale image-to-text models.
</details><br>

![](assets/framework.png)
<br>


If you find our work useful, please cite the paper:
```
@inproceedings{cao2024see,
  title={See or Guess: Counterfactually Regularized Image Captioning},
  author={Cao, Qian and Chen, Xu and Song, Ruihua and Wang, Xiting and Huang, Xinting and Ren, Yuchen},
  booktitle={ACM Multimedia 2024},
  year={2024}
}
```




# Content
0. [Before Start](#before-start)
1. [Setup](#setup)
2. [Prepare the Data](#prepare-the-data)
3. [Usage](#usage)
4. [Contact](#contact)



# Before Start
- Our method can be applied to many small- or large-scale image-to-text models. We provide an implementation of BLIP2 in this repo.
- The implementation of BLIP2 is based on the original one in [LAVIS](https://github.com/salesforce/LAVIS). You can find more details at [LAVIS/BLIP2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2).



# Setup
Clone the repo and create a new virtual environment:
```
git clone https://github.com/Aman-4-Real/See-or-Guess
cd See-or-Guess/
conda create -n cfric python=3.9
conda activate cfric
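# assumed install step: the dependencies are pinned in requirements.txt (added in this commit)
pip install -r requirements.txt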
```



# Prepare the Data
- Please organize your data in the following format:
```
{
    'img_path': 'val2014/COCO_val2014_000000522418.jpg', # relative image path in the dataset dir
    'caption': 'A woman wearing a net on her head cutting a cake. ', # string
    'phrases': [
        {
            'phrase': 'cake', # the noun phrase
            'boxes': [[x1, y1, x2, y2], ...], # bounding boxes of the noun phrase
            'first_word_index': 10 # index of the first word of the noun phrase in the caption
        },
        ...
    ],
    'img_id': '522418' # unique image id
}
```
- Save each split into a `.pkl` file and put it in `YOUR_DATASET_DIR` (see the sketch after this list).
- Change the `url` and `storage` fields in the file `src/lavis/configs/datasets/mscoco/defaults.yaml` to `YOUR_DATASET_DIR/{train,valid,test}.pkl`, respectively.
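
As a minimal sketch (hypothetical values; it assumes each split is simply a pickled Python list of dicts in the format above), the data could be assembled and saved like this:
```
import pickle

# Hypothetical sample following the format above; box coordinates are illustrative only.
sample = {
    "img_path": "val2014/COCO_val2014_000000522418.jpg",
    "caption": "A woman wearing a net on her head cutting a cake. ",
    "phrases": [
        {
            "phrase": "cake",
            "boxes": [[120.0, 85.0, 310.0, 260.0]],  # [x1, y1, x2, y2]
            "first_word_index": 10,
        }
    ],
    "img_id": "522418",
}

# Assumption: each split is a plain list of such dicts.
with open("YOUR_DATASET_DIR/train.pkl", "wb") as f:
    pickle.dump([sample], f)
```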



# Usage

### Quick Check Out
For the key implementation, refer to [cfr_caption_datasets.py](src/lavis/datasets/datasets/cfr_caption_datasets.py), [modeling_opt.py](src/lavis/transformers_v4p34_local/models/opt/modeling_opt.py), and [CFRLoss.py](src/lavis/transformers_v4p34_local/models/opt/CFRLoss.py).




### Prepare pretrained checkpoint
Download a pretrained BLIP2 checkpoint (e.g., [blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)) to the `ckpt/` folder.
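
One possible way to fetch it is via `huggingface_hub` (a sketch under the assumption that `huggingface_hub` is installed and that a `ckpt/blip2-opt-2.7b` subfolder is an acceptable target; downloading manually works just as well):
```
from huggingface_hub import snapshot_download

# Download the BLIP2 OPT-2.7B weights into the ckpt/ folder.
# The subfolder name is a placeholder; point your config at whatever path you use.
snapshot_download(repo_id="Salesforce/blip2-opt-2.7b", local_dir="ckpt/blip2-opt-2.7b")
```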




### Workflow
Generally speaking, our work builds on a trained image captioning model (the "initial model" in the paper). You can proceed with the following steps:

1. Prepare or train the initial model.
If you need to train BLIP2 on your dataset, you can follow the instructions in the [LAVIS](https://github.com/salesforce/LAVIS) repo, or use the config `src/lavis/run_cfgs/caption_coco_ft.yaml` and run
```
cd src/run_scripts/
bash train_caption.sh
```

2. Use the trained initial model to generate counterfactual captions on the training set. Use the config `src/lavis/run_cfgs/caption_eval_gen_on_train.yaml` and run
```
bash eval_caption.sh
```

3. Use the total effect (TE) loss or the natural direct effect (NDE) loss to regularize the training.
Use the config `src/lavis/run_cfgs/caption_coco_ft_te0999.yaml` for TE or `src/lavis/run_cfgs/caption_coco_ft_nde0999.yaml` for NDE, then run
```
bash train_caption.sh
```
For NDE training, remember to set both `do_NDE` and `do_TE` to `True`. Also set a proper value for the hyperparameter α in the config. (A toy sketch of a counterfactual, region-masked input is shown right after these steps.)
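
For intuition about steps 2 and 3, a counterfactual input can be thought of as the image with a noun phrase's regions masked out. The following is a purely illustrative sketch (not the repository's actual preprocessing, which lives in `cfr_caption_datasets.py`), assuming a `(C, H, W)` image tensor and pixel-coordinate boxes:
```
import torch

def mask_phrase_regions(image, boxes, fill_value=0.0):
    """Illustrative only: black out the bounding boxes of a noun phrase to
    form a counterfactual image. image is a (C, H, W) tensor; boxes is a list
    of [x1, y1, x2, y2] in pixel coordinates."""
    masked = image.clone()
    for x1, y1, x2, y2 in boxes:
        masked[:, int(y1):int(y2), int(x1):int(x2)] = fill_value
    return masked
```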



### Evaluation
For evaluation, use the config `caption_coco_ft.yaml` (for factual image captioning) or `caption_coco_eval_mask_gen.yaml` (for counterfactual image captioning) with `src/run_scripts/eval_caption.sh`.
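
If you want to score generated captions offline, outside the repository's own evaluation pipeline, `pycocoevalcap` (already pinned in `requirements.txt`) can be used directly. A minimal sketch with made-up captions:
```
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider

# Keys are image ids; values are lists of (whitespace-tokenized) caption strings.
gts = {"522418": ["a woman wearing a net on her head cutting a cake"]}
res = {"522418": ["a woman cutting a cake"]}

cider, _ = Cider().compute_score(gts, res)
bleu, _ = Bleu(4).compute_score(gts, res)
print("CIDEr:", cider, "BLEU-1..4:", bleu)
```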


### Case Studies
![](assets/cases.png)
<br>




# Contact
For any questions, please feel free to reach me at caoqian4real[at]ruc.edu.cn.

Binary file added assets/cases.png
Binary file added assets/framework.png
30 changes: 30 additions & 0 deletions requirements.txt
@@ -0,0 +1,30 @@
contexttimer==0.3.3
decord==0.6.0
diffusers==0.16.0
einops==0.7.0
fairscale==0.4.4
ftfy
iopath
ipython
omegaconf
opencv-python-headless==4.5.5.64
opendatasets
packaging
pandas
plotly
pre-commit
pycocoevalcap
pycocotools
python-magic
scikit-image
sentencepiece
spacy
streamlit
tensorboard==2.16.2
timm==0.4.12
torch==2.2.1
torchvision==0.17.1
tqdm
transformers==4.34.0
webdataset
wheel
92 changes: 92 additions & 0 deletions src/evaluate.py
@@ -0,0 +1,92 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import argparse
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

import lavis.tasks as tasks
from lavis.common.config import Config
from lavis.common.dist_utils import get_rank, init_distributed_mode
from lavis.common.logger import setup_logger
from lavis.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from lavis.common.utils import now

# imports modules for registration
from lavis.datasets.builders import *
from lavis.models import *
from lavis.processors import *
from lavis.runners.runner_base import RunnerBase
from lavis.tasks import *


def parse_args():
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )

    args = parser.parse_args()
    # if 'LOCAL_RANK' not in os.environ:
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


def main():
    # allow auto-dl completes on main process without timeout when using NCCL backend.
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"

    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
    job_id = now()

    cfg = Config(parse_args())

    init_distributed_mode(cfg.run_cfg)

    setup_seeds(cfg)

    # set after init_distributed_mode() to only log on master.
    setup_logger()

    cfg.pretty_print()

    task = tasks.setup_task(cfg)
    datasets = task.build_datasets(cfg)
    model = task.build_model(cfg)

    runner = RunnerBase(
        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
    )
    runner.evaluate(skip_reload=True)


if __name__ == "__main__":
    main()
31 changes: 31 additions & 0 deletions src/lavis/__init__.py
@@ -0,0 +1,31 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
import sys

from omegaconf import OmegaConf

from lavis.common.registry import registry

from lavis.datasets.builders import *
from lavis.models import *
from lavis.processors import *
from lavis.tasks import *


root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])