This repo contains the official JAX implementation of our paper "CHAMPAGNE: Learning Real-world Conversation from Large-Scale Web Videos".
- [Update: 2023/03/29] The TPU checkpoints and datasets are now released. Please check the Project Page!
Note that all the commands in this document should be run in the command line of the TPU VM instance unless otherwise stated.
1. Follow the instructions to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU API.

   Note: While T5X works with GPUs as well, we haven't heavily tested GPU usage.
2. Create a Cloud TPU VM instance following this instruction. We recommend developing your workflow on a single v3-8 TPU (i.e., `--accelerator-type=v3-8`) and scaling up to pod slices once the pipeline is ready; a sketch of the creation command appears after this list. In this README, we focus on using a single v3-8 TPU. See here to learn more about TPU architectures.
3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. You can install packages, run your code, etc. on the host machine. Once the TPU instance is created, ssh into it with

   ```sh
   gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}
   ```

   where `TPU_NAME` and `ZONE` are the name and the zone used in step 2.
4. Install T5X and its dependencies.

   ```sh
   git clone --branch=main https://github.com/google-research/t5x
   cd t5x
   python3 -m pip install -e '.[tpu]' -f \
     https://storage.googleapis.com/jax-releases/libtpu_releases.html
   ```
5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model checkpoints. To create a GCS bucket, see these instructions; a sketch of the command also appears below.
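For step 2, here is a minimal sketch of the TPU VM creation command. The `--version` runtime below is an assumption; check the Cloud TPU docs for the value appropriate to your project:

```sh
# Create a single v3-8 TPU VM (sketch; the --version runtime is assumed
# and may differ for your project).
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
  --zone=${ZONE} \
  --accelerator-type=v3-8 \
  --version=v2-alpha
```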
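For step 5, a sketch of bucket creation with `gsutil`; the bucket name and location are placeholders, and keeping the bucket in the same region as the TPU avoids cross-region traffic:

```sh
# Create a GCS bucket; ${BUCKET_NAME} and the location are placeholders.
gsutil mb -l europe-west4 gs://${BUCKET_NAME}
```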
First, you need to fill in `settings.py` with your own config.
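As a purely hypothetical sketch of such a config, assuming it holds your GCP project, zone, and bucket; the field names below are illustrative, not the repo's actual schema:

```python
# settings.py -- hypothetical sketch; the actual field names in this repo
# may differ. Fill in values from your own GCP setup.
PROJECT_NAME = "my-gcp-project"        # GCP project ID (placeholder)
ZONE = "europe-west4-a"                # zone of your TPU VM (placeholder)
BUCKET = "gs://my-champagne-bucket"    # GCS bucket from step 5 (placeholder)
```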
```sh
# set up a TPU instance with 8 TPU cores
python tpu_install.py --tpu-pod-name ${TPU_NAME} --tpu-size 8

# run a script on a TPU with 256 TPU cores
python tpu_run.py --tpu-pod-name ${TPU_NAME} --tpu-size 256 --run-sh ${SCRIPT_SH}
```
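`${SCRIPT_SH}` points to the shell script to run on the TPU. As a hedged illustration only, such a script might launch T5X training; the gin file and model dir below are placeholders, not this repo's actual configs:

```sh
# example_run.sh -- illustrative only; the gin file and paths are placeholders.
# The escaped quotes pass a gin string literal through the shell, following
# the usual T5X convention.
python3 -m t5x.train \
  --gin_file=champagne_finetune.gin \
  --gin.MODEL_DIR=\"gs://my-champagne-bucket/models/run1\"
```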
If the paper inspires you, please cite us:
```bibtex
@article{han2023champagne,
  title={CHAMPAGNE: Learning Real-world Conversation from Large-Scale Web Videos},
  author={Han, Seungju and Hessel, Jack and Dziri, Nouha and Choi, Yejin and Yu, Youngjae},
  journal={arXiv preprint arXiv:2303.09713},
  year={2023}
}
```