Commit

integrate into transformers

zwhong714 committed Jun 2, 2024
1 parent bb5f394 commit 64a9a1e
Showing 4,221 changed files with 1,600,993 additions and 159 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
102 changes: 70 additions & 32 deletions README.md
# Adaptive decoding [[paper]](https://arxiv.org/abs/2402.18223)

**TL;DR:** Our new decoding algorithm, Adaptive Decoding, balances the diversity and coherence of open-ended text generation.

## Update
- [2024/5/31] We integrated our method into transformers.
- [2024/5/1] Our paper was accepted by ICML 2024.
- [2024/2/15] We first released our code and our paper.

## Background
<center>
<img src="./img/Background.png" alt="generation2 (1)" style="zoom:50%;" />
</center>

During the generation process, the distribution predicted by the language model (LM) generally falls into two categories. The first is a flattened distribution, indicating that the LM has multiple potential choices for the next token. The second is a sharp distribution, suggesting that the model's choices are more limited. Ensuring that the model dynamically understands the current state is crucial for generating sentences with high diversity and high coherence.
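
To make the two regimes concrete, the short sketch below (our illustration, not part of the released code) contrasts the Shannon entropy of a flat next-token distribution with that of a sharp one; high entropy means many viable next tokens, low entropy means few:

```python
import torch

def entropy(probs: torch.Tensor) -> float:
    """Shannon entropy in nats; 0 for a one-hot distribution."""
    p = probs.clamp_min(1e-12)
    return float(-(p * p.log()).sum())

flat = torch.full((100,), 1 / 100)               # many plausible next tokens
sharp = torch.tensor([0.97] + [0.03 / 99] * 99)  # one dominant next token

print(entropy(flat))   # ~4.61 (= ln 100), the maximum over 100 outcomes
print(entropy(sharp))  # ~0.27, close to deterministic
```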

## Abstract
Current language models decode text token by token according to a probability distribution, and determining appropriate candidates for the next token is crucial for generation quality. This study introduces adaptive decoding, a mechanism that lets a language model dynamically determine a sensible candidate set during generation. Specifically, we introduce an entropy-based metric called confidence and conceptualize determining the optimal candidate set as a confidence-increasing process. The rationality of including a token in the candidate set is assessed by the increment in confidence it brings.
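
The exact definition is given by the equation image below; as one plausible formalization (our notation and reading, not necessarily the paper's exact equation), the confidence of a candidate set **A** and the increment from admitting a token $t^*$ can be written as:

```math
\mathrm{conf}(A) = 1-\frac{H(A)}{\log|\mathcal{V}|},\qquad
H(A) = -P_A\log P_A-\sum_{t\in\mathcal{V}\setminus A}p_t\log p_t,\qquad
\Delta_{t^*} = \mathrm{conf}(A\cup\{t^*\})-\mathrm{conf}(A)
```

where $P_A$ is the total probability mass of **A**; a token is admitted only while $\Delta_{t^*}$ stays above a small threshold $\varepsilon$ (the `ada` hyperparameter below).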

**Method**: Each distribution predicted by the language model can be conceptualized as a state comprising two sets: the candidate set **A** and the ordered set **B**, wherein tokens are arranged by their probabilities. By iteratively selecting the token with the highest probability from **B** and adding it to **A**, we can gauge the increment in confidence, which reflects the rationality of incorporating this token into the candidate set.

<center>
<img src="./img/equation.png" alt="generation2 (1)" style="zoom:100%;" />
<img src="./img/overview.png" alt="generation2 (1)" style="zoom:25%;" />
</center>
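
The following is a minimal, self-contained sketch of the candidate-set construction under the confidence reading given above (variable names and the exact stopping rule are our illustration; the patched `transformers` bundled in this repository is the reference implementation):

```python
import math
import torch

def adaptive_candidate_set(logits: torch.Tensor, epsilon: float = 1e-3) -> torch.Tensor:
    """Grow candidate set A token by token from the ordered set B, stopping
    once the confidence gain of the next token drops below epsilon."""
    probs = torch.softmax(logits.float(), dim=-1)
    p, idx = torch.sort(probs, descending=True)   # ordered set B
    log_v = math.log(p.numel())

    def plogp(x: float) -> float:
        return x * math.log(x) if x > 0.0 else 0.0

    kept, mass = 1, p[0].item()                   # always keep the argmax token
    for k in range(1, p.numel()):
        pk = p[k].item()
        # confidence gain of moving token k from B into A
        gain = (plogp(mass + pk) - plogp(mass) - plogp(pk)) / log_v
        if gain < epsilon:
            break
        kept += 1
        mass += pk
    return idx[:kept]

# Sample the next token from the renormalized candidate set
logits = torch.randn(32000)                       # stand-in for one decoding step
keep = adaptive_candidate_set(logits)
next_token = keep[torch.multinomial(torch.softmax(logits[keep], dim=-1), 1)]
```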

**Results**: Experimental results show that our method balances diversity and coherence well. Human evaluation shows that it generates human-preferred text, and it can potentially improve the reasoning ability of language models.

Detailed information can be found in our [paper](https://arxiv.org/abs/2402.18223).

## Installation
```
pip install -e transformers-main
```
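
This installs the patched copy of `transformers` bundled in this repository (the `transformers-main` directory); the patch is what exposes the `ada` argument to `model.generate` used below.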


## Usage

**Hyperparameter**: There is only one hyperparameter to tune for good generation, `ada` (the threshold ε); the recommended values are 0.001 or 0.0005.


<center>
<img src="./img/hyperparameter.png" alt="generation2 (1)" style="zoom:50%;" />
</center>
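
Since `ada` is the only knob, a quick side-by-side over the recommended values is an easy way to pick one for your model (a hedged sketch assuming the patched `transformers` above; the prompt and lengths are arbitrary):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

name = 'meta-llama/Llama-2-7b-chat-hf'
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, device_map='auto', torch_dtype=torch.float16)

prompt = "The city's growth has reflected the push and pull of many social and economic factors."
ids = tok(prompt, return_tensors='pt').input_ids.to(model.device)

for ada in (0.001, 0.0005):
    out = model.generate(ids, max_new_tokens=64, do_sample=True, ada=ada,
                         pad_token_id=tok.eos_token_id)
    print(f"ada={ada}:", tok.decode(out[0], skip_special_tokens=True))
```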




```python
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pprint import pprint

os.environ['CUDA_VISIBLE_DEVICES'] = "0"
if torch.cuda.is_available():
    print('Cuda is available.')
device = 'auto'

# Load your model
model_name = 'meta-llama/Llama-2-7b-chat-hf'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.float16)
model.eval()

# Resolve special token ids; fall back to eos when no pad token is defined
bos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id else eos_token_id

device = model.device

sentence = "Paige had 11 songs on her mp3 player. If she deleted 9 old songs from it and then added 8 new songs, how many songs does she have on her mp3 player? "

# Llama-2 chat template with a system prompt
prefix = f'''<s>[INST] <<SYS>>You are a help assistant and a math expert. Please solve the following question and directly return me the answer.<</SYS>>
Problem: {sentence}
Let's think step by step\n[/INST]
'''
tokens = tokenizer.tokenize(prefix)
prefix_id_list = tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.tensor(prefix_id_list).to(device).repeat(1, 1)

# `ada` is the adaptive decoding threshold (0.001 or 0.0005 recommended)
input_ids = model.generate(input_ids, max_new_tokens=512, do_sample=True, ada=5e-4,
                           bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=eos_token_id)
generated_results = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
pprint(generated_results)
```
**Generation:**
```
["Of course! I'd be happy to help you solve this problem. Here's the step-by-step calculation:\n\n1. Paige had 11 songs on her mp3 player initially.\n2. She deleted 9 old songs from her mp3 player, so she has 11 - 9 = 2 songs left.\n3. Then, she added 8 new songs to her mp3 player, so she has 2 + 8 = 10 songs on her mp3 player now.\n\nTherefore, Paige has 10 songs on her mp3 player after deleting 9 old songs and adding 8 new ones."
```


## Citing our paper
If adaptive decoding or this repository is useful in your own research, you can use the following BibTeX entry. Thanks!🤗🤗🤗
```
@misc{zhu2024improving,
title={Improving Open-Ended Text Generation via Adaptive Decoding},
year={2024},
eprint={2402.18223},
archivePrefix={arXiv}
}
```
59 changes: 0 additions & 59 deletions adaptive.py

This file was deleted.

Binary file added img/Background.png
Binary file added img/hyperparameter.png
Binary file added img/overview.png
1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/adaptive/ada_13B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/adaptive/ada_70B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/adaptive/ada_7B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/greedy/greedy_13B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/greedy/greedy_70B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/greedy/greedy_7B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/topp/top95_13B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/topp/top95_70B_gsm8k.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions reasoning_evaluation/gsm8k/topp/top95_7B_gsm8k.json

Large diffs are not rendered by default.

134 changes: 66 additions & 68 deletions test.ipynb
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/wenhongzhu/.conda/envs/zwh/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cuda is available.\n",
"[2024-02-27 22:53:35,700] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
"Cuda is available.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.37s/it]\n"
]
},
{
"data": {
"text/plain": [
"GPT2LMHeadModel(\n",
" (transformer): GPT2Model(\n",
" (wte): Embedding(50257, 1600)\n",
" (wpe): Embedding(1024, 1600)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0-47): 48 x GPT2Block(\n",
" (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2Attention(\n",
" (c_attn): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
"LlamaForCausalLM(\n",
" (model): LlamaModel(\n",
" (embed_tokens): Embedding(32000, 4096)\n",
" (layers): ModuleList(\n",
" (0-31): 32 x LlamaDecoderLayer(\n",
" (self_attn): LlamaSdpaAttention(\n",
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (rotary_emb): LlamaRotaryEmbedding()\n",
" )\n",
" (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D()\n",
" (c_proj): Conv1D()\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (mlp): LlamaMLP(\n",
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
" (act_fn): SiLU()\n",
" )\n",
" (input_layernorm): LlamaRMSNorm()\n",
" (post_attention_layernorm): LlamaRMSNorm()\n",
" )\n",
" )\n",
" (ln_f): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)\n",
" (norm): LlamaRMSNorm()\n",
" )\n",
" (lm_head): Linear(in_features=1600, out_features=50257, bias=False)\n",
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
")"
]
},
"execution_count": 3,
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"import os \n",
"os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"\n",
"if torch.cuda.is_available():\n",
" print ('Cuda is available.')\n",
"cuda_available = torch.cuda.is_available()\n",
"device = torch.device('cuda')\n",
"device = 'auto'\n",
"\n",
"model_name = 'gpt2-xl'\n",
"model_name = 'meta-llama/Llama-2-7b-chat-hf'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.float32)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.float16)\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The city's growth has reflected the push and pull of many social and economic factors. Some places are flourishing, while others are struggling.\n",
"\n",
"When the Great Recession began in 2008, downtown Charlotte started losing people, and soon downtown lost nearly 3,000 residents since 2010, Census estimates show.\n",
"\n",
"As new development has come to the heart of Charlotte, the city's once-thriving downtown area has lost nearly one-quarter of its population from 2010 to 2015, more than any other area in the city, U.S. Census data shows. (Downtown's population has since been climbing.)\n",
"\n",
"While some people may now say the city is thriving, there are signs that downtown is becoming increasingly unaffordable. For instance, rent rates in the city as a whole increased by 28 percent from 2002 to 2015, according to a report from real estate analytics firm Axiometrics.\n",
"\n",
"At the same time, many downtown jobs have been displaced, as well.\n",
"\n",
"The jobs displaced in the Charlotte area include retail management, retail sales and service, home furnishings, restaurant management, general office support, administrative support, and travel and hospitality.\n",
"\n",
"So, what's going on? And what can people do to keep downtown growing?\n",
"\n",
"What's going on in the Charlotte area\n",
"\n",
"There are several reasons why Charlotte is\n"
"[\"[INST] <<SYS>>You are a help assistant and a math expert. Please solve the following question and directly return me the answer.<</SYS>>\\nProblem: Paige had 11 songs on her mp3 player. If she deleted 9 old songs from it and then added 8 new songs, how many songs does she have on her mp3 player? \\nLet's think step by step\\n[/INST]\\nOf course! I'd be happy to help you solve this problem. Here's the step-by-step calculation:\\n\\n1. Paige had 11 songs on her mp3 player initially.\\n2. She deleted 9 old songs from her mp3 player, so she has 11 - 9 = 2 songs left.\\n3. Then, she added 8 new songs to her mp3 player, so she has 2 + 8 = 10 songs on her mp3 player now.\\n\\nTherefore, Paige has 10 songs on her mp3 player after deleting 9 old songs and adding 8 new ones.\"]\n"
]
}
],
"source": [
"print(results[0])"
"from tqdm import tqdm\n",
"from pprint import pprint \n",
"bos_token_id = tokenizer.bos_token_id\n",
"eos_token_id = tokenizer.eos_token_id\n",
"pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id else eos_token_id\n",
"\n",
"device = model.device\n",
"\n",
"sentence = \"Paige had 11 songs on her mp3 player. If she deleted 9 old songs from it and then added 8 new songs, how many songs does she have on her mp3 player? \"\n",
"\n",
"prefix = f'''<s>[INST] <<SYS>>You are a help assistant and a math expert. Please solve the following question and directly return me the answer.<</SYS>>\n",
"Problem: {sentence} \n",
"Let's think step by step\\n[/INST]\n",
"'''\n",
"tokens = tokenizer.tokenize(prefix)\n",
"prefix_id_list = tokenizer.convert_tokens_to_ids(tokens)\n",
"input_ids = torch.tensor(prefix_id_list).to(device).repeat(1, 1)\n",
"\n",
"input_ids = model.generate(input_ids, max_new_tokens=512, do_sample=True, ada=5e-4, bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=eos_token_id)\n",
"generated_results = tokenizer.batch_decode(input_ids, skip_special_tokens=True)\n",
"print(generated_results)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {