Beyond Uniform Query Distribution: Key-Driven Grouped Query Attention

Repository: zohaib-khan5040/key-driven-gqa



Official implementation of Key-Driven GQA as presented in our paper: Beyond Uniform Query Distribution: Key-Driven Grouped Query Attention
Zohaib Khan*, Muhammad Khaquan*, Omer Tafveez, Burhanuddin Samiwala, Agha Ali Raza (* indicates equal contribution)
Lahore University of Management Sciences

@misc{khan2024uniformquerydistributionkeydriven,
      title={Beyond Uniform Query Distribution: Key-Driven Grouped Query Attention}, 
      author={Zohaib Khan and Muhammad Khaquan and Omer Tafveez and Agha Ali Raza},
      year={2024},
      eprint={2408.08454},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2408.08454}, 
}
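For background, grouped-query attention (GQA) shares each key/value head across a group of query heads. The sketch below is a minimal NumPy illustration of standard GQA only, not the key-driven variant from the paper or the repository's actual implementation; all names are our own.

```python
import numpy as np

def grouped_query_attention(q, k, v, num_kv_heads):
    """Minimal GQA sketch (illustrative, not the paper's key-driven variant).

    q: (num_heads, seq, dim); k, v: (num_kv_heads, seq, dim).
    Each group of num_heads // num_kv_heads query heads shares one KV head.
    """
    num_heads, seq, dim = q.shape
    group = num_heads // num_kv_heads
    # Repeat each KV head so it is shared by every query head in its group.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    # Scaled dot-product attention per head.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(dim)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

With 12 query heads and `num_kv_heads=6` (as in the base config below), each KV head serves a group of 2 query heads; the paper's contribution is to make that grouping key-driven rather than uniform.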

Setup

Run pip install -r requirements.txt

  • Download tiny-imagenet-200 from here
  • Download CINIC-10 from here
  • To use ImageNet-1k, you must log in to Hugging Face and provide a token; the simplest way is to run huggingface-cli login in your terminal.

Defining an experiment

We use YAML files for our configurations, with the following structure:

dataset: 'cifar10'			# one of {cifar10, cifar100, food101, tiny-imagenet-200}
in_chans: 3
size: 'b'				# one of {s, b, l}
att_scheme: 'gqa'			# one of {gqa, kdgqa, dgqa_diff, dgqa_ema, pgqa}
num_classes: 10
pretrained: True
window_size: 300
num_kv_heads: 6				# depends on the model size and its number of heads
out_dir: "cifar10-base-gqa-pretrained"	# directory to which outputs are saved
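A quick way to catch config typos before launching a run is to check the enumerated fields against their allowed values. The field names and option sets below come from the config structure above; the checker itself is our own hypothetical sketch, not part of the repository.

```python
# Allowed values taken from the config structure documented above.
ALLOWED = {
    "dataset": {"cifar10", "cifar100", "food101", "tiny-imagenet-200"},
    "size": {"s", "b", "l"},
    "att_scheme": {"gqa", "kdgqa", "dgqa_diff", "dgqa_ema", "pgqa"},
}

def validate_config(cfg):
    """Raise ValueError if an enumerated field holds an unknown value."""
    for key, allowed in ALLOWED.items():
        if cfg.get(key) not in allowed:
            raise ValueError(f"{key}={cfg.get(key)!r} not in {sorted(allowed)}")
    return cfg
```

For example, `validate_config({"dataset": "cifar10", "size": "b", "att_scheme": "gqa"})` passes, while a misspelled scheme such as `"kd-gqa"` is rejected before training starts.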

Running

Run train.py with the following arguments:

  • --config: path to the configuration file (must be yaml)
  • --out_dir: path to the directory where to save outputs
  • --save_model: whether to save the model checkpoints in the output directory
  • --pretrained_ckpt: path to the checkpoint, if any, to use for the training (could be for uptraining or fine-tuning)

Example usage: python train.py --config path/to/config.yaml --out_dir output_dir/ --save_model True --pretrained_ckpt path/to/ckpt.pth
