
Commit e9623fd: cleaning
jasonwu0731 committed Oct 1, 2020
1 parent 9e9f340 commit e9623fd
Showing 9 changed files with 191 additions and 375 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,2 +1,4 @@
*pyc
data/
*.ipynb_checkpoints
runs/
17 changes: 10 additions & 7 deletions README.md
@@ -1,11 +1,13 @@
# ToD-BERT: Pre-trained Natural Language Understanding for Task-Oriented Dialogues
# TOD-BERT: Pre-trained Natural Language Understanding for Task-Oriented Dialogues

Authors: [Chien-Sheng Wu](https://jasonwu0731.github.io/), [Steven Hoi](http://mysmu.edu.sg/faculty/chhoi/), [Richard Socher](https://www.socher.org/) and [Caiming Xiong](http://cmxiong.com/).
Authors: [Chien-Sheng Wu](https://jasonwu0731.github.io/), [Steven Hoi](http://mysmu.edu.sg/faculty/chhoi/), [Richard Socher](https://www.socher.org/) and [Caiming Xiong](http://cmxiong.com/).

EMNLP 2020. Paper: https://arxiv.org/abs/2004.06871

Paper: https://arxiv.org/abs/2004.06871

## Introduction
The underlying difference of linguistic patterns between general text and task-oriented dialogue makes existing pre-trained language models less effective in practice. In this work, we unify nine human-human and multi-turn task-oriented dialogue datasets for language modeling. To better model dialogue behavior during pre-training, we incorporate user and system special tokens into the masked language modeling, and we add a contrastive objective function with a simulated response selection task. Our pre-trained task-oriented dialogue BERT (TOD-BERT) outperforms strong baselines like BERT in four downstream task-oriented dialogue applications, including intention detection, dialogue state tracking, dialogue act prediction, and response selection. We also show that TOD-BERT has stronger few-shot ability that can mitigate the data scarcity problem for task-oriented dialogue.
The underlying difference of linguistic patterns between general text and task-oriented dialogue makes existing pre-trained language models less useful in practice. In this work, we unify nine human-human and multi-turn task-oriented dialogue datasets for language modeling. To better model dialogue behavior during pre-training, we incorporate user and system tokens into the masked language modeling. We propose a contrastive objective function to simulate the response selection task. Our pre-trained task-oriented dialogue BERT (TOD-BERT) outperforms strong baselines like BERT on four downstream task-oriented dialogue applications, including intention recognition, dialogue state tracking, dialogue act prediction, and response selection. We also show that TOD-BERT has a stronger few-shot ability that can mitigate the data scarcity problem for task-oriented dialogue.


## Citation
If you use any source codes, pretrained models or datasets included in this repo in your work, please cite the following paper. The bibtex is listed below:
@@ -20,11 +22,13 @@ If you use any source codes, pretrained models or datasets included in this repo


## Update
* (2020.10.01) Added more training and inference information. Released TOD-DistilBERT.
* (2020.07.10) Loading model from [Huggingface](https://huggingface.co/) is now supported.
* (2020.04.26) Pre-trained models are available.


## Pretrained Models
You can easily load the pre-trained model using huggingface [Transformer](https://github.com/huggingface/transformers) library using the AutoModel function. Several pre-trained versions are supported:
You can easily load the pre-trained model with the huggingface [Transformers](https://github.com/huggingface/transformers) library using the AutoModel function. Several pre-trained versions are supported:
* TODBERT/TOD-BERT-MLM-V1: TOD-BERT pre-trained only using the MLM objective
* TODBERT/TOD-BERT-JNT-V1: TOD-BERT pre-trained using both the MLM and RCL objectives
* TODBERT/TOD-DistilBERT-JNT-V1: TOD-DistilBERT pre-trained using both the MLM and RCL objectives
@@ -45,9 +49,8 @@ tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
tod_bert = model_class.from_pretrained(model_name_or_path)
```
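For example, a minimal loading sketch with the AutoModel/AutoTokenizer API, using one of the checkpoint names listed above (swap in whichever version you need):

```
from transformers import AutoModel, AutoTokenizer

model_name = "TODBERT/TOD-BERT-JNT-V1"  # or any of the other versions listed above
tokenizer = AutoTokenizer.from_pretrained(model_name)
tod_bert = AutoModel.from_pretrained(model_name)
```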


## Direct Usage
Please refer to the following guide how to use our pre-trained ToD-BERT models. Full training and evaluation code will be released soon. Our model is built on top of the [PyTorch](https://pytorch.org/) library and huggingface [Transformer](https://github.com/huggingface/transformers) library. Let's do a very quick overview of the model architecture and code. Detailed examples for model architecturecan be found in the paper.
Please refer to the following guide on how to use our pre-trained TOD-BERT models. Our model is built on top of the [PyTorch](https://pytorch.org/) library and the huggingface [Transformers](https://github.com/huggingface/transformers) library. Let's do a very quick overview of the model architecture and code. Detailed examples of the model architecture can be found in the paper.

```
# Encode text
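# A minimal sketch of how encoding could proceed from here (illustrative only):
# it assumes torch is imported, the tokenizer/model are loaded as shown above, and
# the "[SYS]"/"[USR]" special tokens mark system and user turns as described in the paper.
input_text = "[CLS] [SYS] Hello, how can I help you today? [USR] Find me a cheap restaurant."
input_tokens = tokenizer.tokenize(input_text)
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(input_tokens)])

with torch.no_grad():
    outputs = tod_bert(input_ids=input_ids, attention_mask=(input_ids > 0).long())

hidden_states = outputs[0]           # (batch, seq_len, hidden_dim)
cls_vector = hidden_states[:, 0, :]  # dialogue-level representation from [CLS]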
8 changes: 4 additions & 4 deletions main.py
@@ -38,7 +38,6 @@
"distilbert": (DistilBertModel, DistilBertTokenizer, DistilBertConfig),
"electra": (ElectraModel, ElectraTokenizer, ElectraConfig)}


## Fix torch random seed
if args["fix_rand_seed"]:
torch.manual_seed(args["rand_seed"])
@@ -232,19 +231,21 @@
model.load_state_dict(torch.load(args["load_path"]))
else:
model.load_state_dict(torch.load(args["load_path"], lambda storage, loc: storage))
else:
print("[WARNING] No trained model is loaded...")

if torch.cuda.is_available():
model = model.cuda()

print("[Info] Start Evaluation on dev and test set...")
#if MY_MODEL:
dev_loader = get_loader(args, "dev" , tokenizer, datasets, unified_meta)
tst_loader = get_loader(args, "test" , tokenizer, datasets, unified_meta, shuffle=args["task_name"]=="rs")
model.eval()

for d_eval in ["tst"]: #["dev", "tst"]:
f_w = open(os.path.join(args["output_dir"], "{}_results.txt".format(d_eval)), "w")

# Start evaluating on the test set
## Start evaluating on the test set
test_loss = 0
preds, labels = [], []
pbar = tqdm(locals()["{}_loader".format(d_eval)])
@@ -254,7 +255,6 @@
test_loss += outputs["loss"]
preds += [item for item in outputs["pred"]]
labels += [item for item in outputs["label"]]
#break

test_loss = test_loss / len(tst_loader)
results = model.evaluation(preds, labels)
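A note on the evaluation loop above: `locals()["{}_loader".format(d_eval)]` selects the dev or test DataLoader by variable name. A small self-contained sketch of the same pattern with an explicit mapping (the placeholder loaders below stand in for the real `get_loader` output):

```
from tqdm import tqdm

dev_loader = [{"split": "dev"}]    # placeholder batches, standing in for get_loader(...)
tst_loader = [{"split": "test"}]
eval_loaders = {"dev": dev_loader, "tst": tst_loader}

for d_eval in ["tst"]:             # extend to ["dev", "tst"] to also score the dev set
    for batch in tqdm(eval_loaders[d_eval]):
        pass                       # forward pass and metric accumulation go here
```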
93 changes: 22 additions & 71 deletions models/BERT_DST_Picklist.py
@@ -189,77 +189,28 @@ def forward(self, data):#input_ids, input_len, labels, gate_label, n_gpu=1, targ
loss = 0
pred_slot = []

if self.args["oracle_domain"]:

for slot_id in range(self.num_slots):
pred_slot_local = []
for bsz_i in range(batch_size):
hidden_bsz = hidden[bsz_i, :, :]

if slot_id in data["triggered_ds_idx"][bsz_i]:

temp = [i for i, idx in enumerate(data["triggered_ds_idx"][bsz_i]) if idx == slot_id]
assert len(temp) == 1
ds_pos = data["triggered_ds_pos"][bsz_i][temp[0]]

hid_label = self.value_lookup[slot_id].weight # v * d
hidden_ds = hidden_bsz[ds_pos, :].unsqueeze(1) # d * 1
hidden_ds = torch.cat([hidden_ds, hidden_bsz[0, :].unsqueeze(1)], 0) # 2d * 1
hidden_ds = self.project_W_2[0](hidden_ds.transpose(1, 0)).transpose(1, 0) # d * 1

_dist = torch.mm(hid_label, hidden_ds).transpose(1, 0) # 1 * v, 51.6%

_, pred = torch.max(_dist, -1)
pred_item = pred.item()

if labels is not None:

if (self.args["gate_supervision_for_dst"] and labels[bsz_i, slot_id] != 0) or\
(not self.args["gate_supervision_for_dst"]):
_loss = self.nll(_dist, labels[bsz_i, slot_id].unsqueeze(0))
loss += _loss

if self.args["gate_supervision_for_dst"]:
_dist_gate = self.gate_classifier(hidden_ds.transpose(1, 0))
_loss_gate = self.nll(_dist_gate, data["slot_gate"][bsz_i, slot_id].unsqueeze(0))
loss += _loss_gate

if torch.max(_dist_gate, -1)[1].item() == 0:
pred_item = 0

pred_slot_local.append(pred_item)
else:
#print("slot_id Not Found")
pred_slot_local.append(0)

pred_slot.append(torch.tensor(pred_slot_local).unsqueeze(1))

predictions = torch.cat(pred_slot, 1).numpy()
labels = labels.detach().cpu().numpy()

else:
for slot_id in range(self.num_slots): ## note: target_slots are successive
# loss calculation
hid_label = self.value_lookup[slot_id].weight # v * d
num_slot_labels = hid_label.size(0)

_hidden = _gelu(self.project_W_1[slot_id](hidden_rep))
_hidden = torch.cat([hid_label.unsqueeze(0).repeat(batch_size, 1, 1), _hidden.unsqueeze(1).repeat(1, num_slot_labels, 1)], dim=2)
_hidden = _gelu(self.project_W_2[slot_id](_hidden))
_hidden = self.project_W_3[slot_id](_hidden)
_dist = _hidden.squeeze(2) # b * 1 * num_slot_labels

_, pred = torch.max(_dist, -1)
pred_slot.append(pred.unsqueeze(1))
#output.append(_dist)

if labels is not None:
_loss = self.nll(_dist, labels[:, slot_id])
#loss_slot.append(_loss.item())
loss += _loss

predictions = torch.cat(pred_slot, 1).detach().cpu().numpy()
labels = labels.detach().cpu().numpy()
for slot_id in range(self.num_slots): ## note: target_slots are successive
# loss calculation
hid_label = self.value_lookup[slot_id].weight # v * d
num_slot_labels = hid_label.size(0)

_hidden = _gelu(self.project_W_1[slot_id](hidden_rep))
_hidden = torch.cat([hid_label.unsqueeze(0).repeat(batch_size, 1, 1), _hidden.unsqueeze(1).repeat(1, num_slot_labels, 1)], dim=2)
_hidden = _gelu(self.project_W_2[slot_id](_hidden))
_hidden = self.project_W_3[slot_id](_hidden)
_dist = _hidden.squeeze(2) # b * 1 * num_slot_labels

_, pred = torch.max(_dist, -1)
pred_slot.append(pred.unsqueeze(1))
#output.append(_dist)

if labels is not None:
_loss = self.nll(_dist, labels[:, slot_id])
#loss_slot.append(_loss.item())
loss += _loss

predictions = torch.cat(pred_slot, 1).detach().cpu().numpy()
labels = labels.detach().cpu().numpy()

if self.training:
self.loss_grad = loss
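The branch kept above scores each slot by matching the dialogue representation against that slot's candidate-value embeddings. A self-contained sketch of the scoring step, with made-up sizes and randomly initialized projection layers (illustrative shapes only, not the trained model or its exact loss module):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, hid, num_values = 4, 768, 7           # assumed sizes for illustration
hidden_rep = torch.randn(batch, hid)         # dialogue representation (e.g. [CLS])
value_emb = torch.randn(num_values, hid)     # candidate value embeddings for one slot

W1, W2, W3 = nn.Linear(hid, hid), nn.Linear(2 * hid, hid), nn.Linear(hid, 1)

h = F.gelu(W1(hidden_rep))                                           # (b, d)
pair = torch.cat([value_emb.unsqueeze(0).expand(batch, -1, -1),      # (b, v, d)
                  h.unsqueeze(1).expand(-1, num_values, -1)], dim=2) # (b, v, 2d)
scores = W3(F.gelu(W2(pair))).squeeze(2)                             # (b, v) score per value
pred = scores.argmax(dim=-1)                                         # predicted value index
gold = torch.randint(num_values, (batch,))                           # fake gold value indices
loss = F.cross_entropy(scores, gold)                                 # stands in for the model's own loss
```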
1 change: 0 additions & 1 deletion models/multi_class_classifier.py
@@ -22,7 +22,6 @@ def __init__(self, args): #, num_labels, device):
self.rnn_num_layers = args["num_rnn_layers"]
self.num_labels = args["num_labels"]
self.xeloss = nn.CrossEntropyLoss()
#self.sigmoid = nn.Sigmoid()
self.n_gpu = args["n_gpu"]

### Utterance Encoder