diff --git a/evaluation_pipeline.sh b/evaluation_pipeline.sh new file mode 100755 index 0000000..f3c67f1 --- /dev/null +++ b/evaluation_pipeline.sh @@ -0,0 +1,137 @@ +gpu=$1 +model=$2 +bert_dir=$3 +output_dir=$4 +add1=$5 +add2=$6 +add3=$7 + +# ./evaluation_pipeline.sh 0 bert bert-base-uncased save/BERT + +# Intent +# for bsz in 8 16 32 +# do +# CUDA_VISIBLE_DEVICES=$gpu python main.py \ +# --my_model=multi_class_classifier \ +# --dataset='["oos_intent"]' \ +# --task_name="intent" \ +# --earlystop="acc" \ +# --output_dir=${output_dir}/Intent/OOS/BSZ${bsz} \ +# --do_train \ +# --task=nlu \ +# --example_type=turn \ +# --model_type=${model} \ +# --model_name_or_path=${bert_dir} \ +# --batch_size=${bsz} \ +# --usr_token=[USR] --sys_token=[SYS] \ +# --epoch=50 --eval_by_step=500 --warmup_steps=250 \ +# $add1 $add2 $add3 +# done + +# DST +# CUDA_VISIBLE_DEVICES=$gpu python main.py \ +# --my_model=BeliefTracker \ +# --model_type=${model} \ +# --dataset='["multiwoz"]' \ +# --task_name="dst" \ +# --earlystop="joint_acc" \ +# --output_dir=${output_dir}/DST/MWOZ \ +# --do_train \ +# --task=dst \ +# --example_type=turn \ +# --model_name_or_path=${bert_dir} \ +# --batch_size=6 --eval_batch_size=6 \ +# --usr_token=[USR] --sys_token=[SYS] \ +# --eval_by_step=4000 \ +# $add1 $add2 $add3 + +# DA +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train --dataset='["multiwoz"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/MWOZ/BSZ8 \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=1000 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train \ + --dataset='["universal_act_dstc2"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/DSTC2/BSZ8 \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=500 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train \ + --dataset='["universal_act_sim_joint"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/SIM_JOINT/BSZ8 \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=500 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + $add1 $add2 $add3 + +# Response Selection +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --task=nlg \ + --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/MWOZ/ \ + --batch_size=25 --eval_batch_size=100 \ + --usr_token=[USR] --sys_token=[SYS] \ + --fix_rand_seed \ + --eval_by_step=1000 \ + --max_seq_length=256 \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --dataset='["universal_act_dstc2"]' \ + --task=nlg --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/DSTC2/ \ + --batch_size=25 --eval_batch_size=100 \ + --max_seq_length=256\ + --fix_rand_seed \ + 
$add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --dataset='["universal_act_sim_joint"]' \ + --task=nlg --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/SIM_JOINT/ \ + --batch_size=25 --eval_batch_size=100 \ + --max_seq_length=256 \ + --fix_rand_seed \ + $add1 $add2 $add3 \ No newline at end of file diff --git a/evaluation_ratio_pipeline.sh b/evaluation_ratio_pipeline.sh new file mode 100755 index 0000000..e1edd7f --- /dev/null +++ b/evaluation_ratio_pipeline.sh @@ -0,0 +1,275 @@ +gpu=$1 +model=$2 +bert_dir=$3 +output_dir=$4 +add1=$5 +add2=$6 +add3=$7 + +# ./evaluation_ratio_pipeline.sh 0 bert bert-base-uncased save/BERT --nb_runs=3 + +# Intent +for ratio in 0.01 0.1 +do + CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_class_classifier \ + --dataset='["oos_intent"]' \ + --task_name="intent" \ + --earlystop="acc" \ + --output_dir=${output_dir}/Intent/OOS-Ratio/R${ratio} \ + --do_train \ + --task=nlu \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --batch_size=16 \ + --usr_token=[USR] --sys_token=[SYS] \ + --epoch=500 --eval_by_step=100 --warmup_steps=100 \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 +done + +# DST +for ratio in 0.01 0.05 +do +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=BeliefTracker \ + --model_type=${model} \ + --dataset='["multiwoz"]' \ + --task_name="dst" \ + --earlystop="joint_acc" \ + --output_dir=${output_dir}/DST/MWOZ-Ratio/R${ratio} \ + --do_train \ + --task=dst \ + --example_type=turn \ + --model_name_or_path=${bert_dir} \ + --batch_size=8 --eval_batch_size=8 \ + --usr_token=[USR] --sys_token=[SYS] \ + --eval_by_step=200 \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 +done +for ratio in 0.1 0.25 +do +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=BeliefTracker \ + --model_type=${model} \ + --dataset='["multiwoz"]' \ + --task_name="dst" \ + --earlystop="joint_acc" \ + --output_dir=${output_dir}/DST/MWOZ-Ratio/R${ratio} \ + --do_train \ + --task=dst \ + --example_type=turn \ + --model_name_or_path=${bert_dir} \ + --batch_size=8 --eval_batch_size=8 \ + --usr_token=[USR] --sys_token=[SYS] \ + --eval_by_step=500 \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 +done + +# DA +for ratio in 0.01 +do +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train --dataset='["multiwoz"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/MWOZ-Ratio/R${ratio} \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=200 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train \ + --dataset='["universal_act_dstc2"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/DSTC2-Ratio/R${ratio} \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=100 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train \ + 
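# NOTE on the ratio sweeps in this script: --train_data_ratio simulates a
# few-shot setting by subsampling the training data (0.01 presumably keeps
# roughly 1% of the training examples; the dev/test splits are left as-is),
# and smaller ratios are paired with smaller --eval_by_step values so early
# stopping still sees frequent dev evaluations. Since the per-dataset blocks
# differ only in dataset name and --eval_by_step, an equivalent compact form
# would be a nested sweep (sketch only; flag values as in the surrounding blocks):
#
#   for ds in multiwoz universal_act_dstc2 universal_act_sim_joint; do
#     for ratio in 0.01 0.1; do
#       CUDA_VISIBLE_DEVICES=$gpu python main.py --dataset="[\"$ds\"]" \
#         --train_data_ratio=$ratio ...   # remaining flags unchanged
#     done
#   done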
--dataset='["universal_act_sim_joint"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/SIM_JOINT-Ratio/R${ratio} \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=100 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 +done + +for ratio in 0.1 +do +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train --dataset='["multiwoz"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/MWOZ-Ratio/R${ratio} \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=500 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train \ + --dataset='["universal_act_dstc2"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/DSTC2-Ratio/R${ratio} \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=500 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=multi_label_classifier \ + --do_train \ + --dataset='["universal_act_sim_joint"]' \ + --task=dm --task_name=sysact --example_type=turn \ + --model_type=${model} --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/DA/SIM_JOINT-Ratio/R${ratio} \ + --batch_size=8 \ + --eval_batch_size=4 \ + --learning_rate=5e-5 \ + --eval_by_step=500 \ + --usr_token=[USR] --sys_token=[SYS] \ + --earlystop=f1_weighted \ + --train_data_ratio=${ratio} \ + $add1 $add2 $add3 +done + + +# RS +for ratio in 0.01 +do +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --task=nlg \ + --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/MWOZ-Ratio/R${ratio} \ + --batch_size=25 --eval_batch_size=100 \ + --usr_token=[USR] --sys_token=[SYS] \ + --fix_rand_seed \ + --eval_by_step=200 \ + --train_data_ratio=${ratio} \ + --max_seq_length=256 \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --dataset='["universal_act_dstc2"]' \ + --task=nlg --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/DSTC2-Ratio/R${ratio} \ + --batch_size=25 --eval_batch_size=100 \ + --eval_by_step=100 \ + --train_data_ratio=${ratio} \ + --max_seq_length=256 \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --dataset='["universal_act_sim_joint"]' \ + --task=nlg --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/SIM_JOINT-Ratio/R${ratio} \ + --batch_size=25 --eval_batch_size=100 \ + --eval_by_step=100 \ + --train_data_ratio=${ratio} \ + --max_seq_length=256 \ + $add1 $add2 $add3 +done +for ratio in 0.1 +do +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + 
--task=nlg \ + --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/MWOZ-Ratio/R${ratio} \ + --batch_size=25 --eval_batch_size=100 \ + --usr_token=[USR] --sys_token=[SYS] \ + --fix_rand_seed \ + --eval_by_step=500 \ + --train_data_ratio=${ratio} \ + --max_seq_length=256 \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --dataset='["universal_act_dstc2"]' \ + --task=nlg --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/DSTC2-Ratio/R${ratio} \ + --batch_size=25 --eval_batch_size=100 \ + --eval_by_step=250 \ + --train_data_ratio=${ratio} \ + --max_seq_length=256 \ + $add1 $add2 $add3 + +CUDA_VISIBLE_DEVICES=$gpu python main.py \ + --my_model=dual_encoder_ranking \ + --do_train \ + --dataset='["universal_act_sim_joint"]' \ + --task=nlg --task_name=rs \ + --example_type=turn \ + --model_type=${model} \ + --model_name_or_path=${bert_dir} \ + --output_dir=${output_dir}/RS/SIM_JOINT-Ratio/R${ratio} \ + --batch_size=25 --eval_batch_size=100 \ + --eval_by_step=250 \ + --train_data_ratio=${ratio} \ + --max_seq_length=256 \ + $add1 $add2 $add3 +done + diff --git a/main.py b/main.py new file mode 100644 index 0000000..e253484 --- /dev/null +++ b/main.py @@ -0,0 +1,263 @@ +from tqdm import tqdm +import torch.nn as nn +import logging +import ast +import glob +import numpy as np +import copy + +# utils +from utils.config import * +from utils.utils_general import * +from utils.utils_multiwoz import * +from utils.utils_oos_intent import * +from utils.utils_universal_act import * + +# models +from models.multi_label_classifier import * +from models.multi_class_classifier import * +from models.BERT_DST_Picklist import * +from models.dual_encoder_ranking import * + +# hugging face models +from transformers import * + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + from tensorboardX import SummaryWriter + +## model selection +MODELS = {"bert": (BertModel, BertTokenizer, BertConfig), + "todbert": (BertModel, BertTokenizer, BertConfig), + "gpt2": (GPT2Model, GPT2Tokenizer, GPT2Config), + "todgpt2": (GPT2Model, GPT2Tokenizer, GPT2Config), + "dialogpt": (AutoModelWithLMHead, AutoTokenizer, GPT2Config), + "albert": (AlbertModel, AlbertTokenizer, AlbertConfig), + "roberta": (RobertaModel, RobertaTokenizer, RobertaConfig), + "distilbert": (DistilBertModel, DistilBertTokenizer, DistilBertConfig), + "electra": (ElectraModel, ElectraTokenizer, ElectraConfig)} + + +## Fix torch random seed +if args["fix_rand_seed"]: + torch.manual_seed(args["rand_seed"]) + + +## Reading data and create data loaders +datasets = {} +for ds_name in ast.literal_eval(args["dataset"]): + data_trn, data_dev, data_tst, data_meta = globals()["prepare_data_{}".format(ds_name)](args) + datasets[ds_name] = {"train": data_trn, "dev":data_dev, "test": data_tst, "meta":data_meta} +unified_meta = get_unified_meta(datasets) +if "resp_cand_trn" not in unified_meta.keys(): unified_meta["resp_cand_trn"] = {} +args["unified_meta"] = unified_meta + + +## Create vocab and model class +args["model_type"] = args["model_type"].lower() +model_class, tokenizer_class, config_class = MODELS[args["model_type"]] +tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], cache_dir=args["cache_dir"]) +args["model_class"] = model_class +args["tokenizer"] = tokenizer +if 
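# For reference: --dataset is passed as a quoted Python list literal and parsed
# with ast.literal_eval above, so one run can mix corpora, e.g.
#   --dataset='["multiwoz", "oos_intent"]'
# dispatches to prepare_data_multiwoz(args) and prepare_data_oos_intent(args)
# via globals()["prepare_data_{}".format(ds_name)], and get_unified_meta then
# merges their label metadata into a single unified_meta dict.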
args["model_name_or_path"]: + config = config_class.from_pretrained(args["model_name_or_path"], cache_dir=args["cache_dir"]) +else: + config = config_class() +args["config"] = config +args["num_labels"] = unified_meta["num_labels"] + + +## Training and Testing Loop +if args["do_train"]: + result_runs = [] + output_dir_origin = str(args["output_dir"]) + + ## Setup logger + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', + datefmt='%m-%d %H:%M', + filename=os.path.join(args["output_dir"], "train.log"), + filemode='w') + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + console.setFormatter(formatter) + logging.getLogger('').addHandler(console) + + ## training loop + for run in range(args["nb_runs"]): + + ## Setup random seed and output dir + rand_seed = SEEDS[run] + if args["fix_rand_seed"]: + torch.manual_seed(rand_seed) + args["rand_seed"] = rand_seed + args["output_dir"] = os.path.join(output_dir_origin, "run{}".format(run)) + os.makedirs(args["output_dir"], exist_ok=False) + logging.info("Running Random Seed: {}".format(rand_seed)) + + ## Loading model + model = globals()[args['my_model']](args) + if torch.cuda.is_available(): model = model.cuda() + + ## Create Dataloader + trn_loader = get_loader(args, "train", tokenizer, datasets, unified_meta) + dev_loader = get_loader(args, "dev" , tokenizer, datasets, unified_meta, shuffle=args["task_name"]=="rs") + tst_loader = get_loader(args, "test" , tokenizer, datasets, unified_meta, shuffle=args["task_name"]=="rs") + + ## Create TF Writer + tb_writer = SummaryWriter(comment=args["output_dir"].replace("/", "-")) + + # Start training process with early stopping + loss_best, acc_best, cnt, train_step = 1e10, -1, 0, 0 + + try: + for epoch in range(args["epoch"]): + logging.info("Epoch:{}".format(epoch+1)) + train_loss = 0 + pbar = tqdm(trn_loader) + for i, d in enumerate(pbar): + model.train() + outputs = model(d) + train_loss += outputs["loss"] + train_step += 1 + pbar.set_description("Training Loss: {:.4f}".format(train_loss/(i+1))) + + ## Dev Evaluation + if (train_step % args["eval_by_step"] == 0 and args["eval_by_step"] != -1) or \ + (i == len(pbar)-1 and args["eval_by_step"] == -1): + model.eval() + dev_loss = 0 + preds, labels = [], [] + ppbar = tqdm(dev_loader) + for d in ppbar: + with torch.no_grad(): + outputs = model(d) + #print(outputs) + dev_loss += outputs["loss"] + preds += [item for item in outputs["pred"]] + labels += [item for item in outputs["label"]] + + dev_loss = dev_loss / len(dev_loader) + results = model.evaluation(preds, labels) + dev_acc = results[args["earlystop"]] if args["earlystop"] != "loss" else dev_loss + + ## write to tensorboard + tb_writer.add_scalar("train_loss", train_loss/(i+1), train_step) + tb_writer.add_scalar("eval_loss", dev_loss, train_step) + tb_writer.add_scalar("eval_{}".format(args["earlystop"]), dev_acc, train_step) + + if (dev_loss < loss_best and args["earlystop"] == "loss") or \ + (dev_acc > acc_best and args["earlystop"] != "loss"): + loss_best = dev_loss + acc_best = dev_acc + cnt = 0 # reset + + if args["not_save_model"]: + model_clone = globals()[args['my_model']](args) + model_clone.load_state_dict(copy.deepcopy(model.state_dict())) + else: + output_model_file = os.path.join(args["output_dir"], "pytorch_model.bin") + if args["n_gpu"] == 1: + torch.save(model.state_dict(), output_model_file) + else: + torch.save(model.module.state_dict(), 
output_model_file) + logging.info("[Info] Model saved at epoch {} step {}".format(epoch, train_step)) + else: + cnt += 1 + logging.info("[Info] Early stop count: {}/{}...".format(cnt, args["patience"])) + + if cnt > args["patience"]: + logging.info("Ran out of patience, early stopping...") + break + + logging.info("Trn loss {:.4f}, Dev loss {:.4f}, Dev {} {:.4f}".format(train_loss/(i+1), + dev_loss, + args["earlystop"], + dev_acc)) + + if cnt > args["patience"]: + tb_writer.close() + break + + except KeyboardInterrupt: + logging.info("[Warning] Early stopped by KeyboardInterrupt") + + ## Load the best model + if args["not_save_model"]: + model.load_state_dict(copy.deepcopy(model_clone.state_dict())) + else: + # Reload the best checkpoint before evaluating on the test set + if torch.cuda.is_available(): + model.load_state_dict(torch.load(output_model_file)) + else: + model.load_state_dict(torch.load(output_model_file, lambda storage, loc: storage)) + + ## Run test set evaluation + pbar = tqdm(tst_loader) + for nb_eval in range(args["nb_evals"]): + test_loss = 0 + preds, labels = [], [] + for d in pbar: + with torch.no_grad(): + outputs = model(d) + test_loss += outputs["loss"] + preds += [item for item in outputs["pred"]] + labels += [item for item in outputs["label"]] + + test_loss = test_loss / len(tst_loader) + results = model.evaluation(preds, labels) + result_runs.append(results) + logging.info("[{}] Test Results: ".format(nb_eval) + str(results)) + + ## Average results over runs + if args["nb_runs"] > 1: + f_out = open(os.path.join(output_dir_origin, "eval_results_multi-runs.txt"), "w") + f_out.write("Average over {} runs and {} evals \n".format(args["nb_runs"], args["nb_evals"])) + for key in results.keys(): + mean = np.mean([r[key] for r in result_runs]) + std = np.std([r[key] for r in result_runs]) + f_out.write("{}: mean {} std {} \n".format(key, mean, std)) + f_out.close() + +else: + + ## Load Model + print("[Info] Loading model from {}".format(args['my_model'])) + model = globals()[args['my_model']](args) + if args["load_path"]: + print("MODEL {} LOADED".format(args["load_path"])) + if torch.cuda.is_available(): + model.load_state_dict(torch.load(args["load_path"])) + else: + model.load_state_dict(torch.load(args["load_path"], lambda storage, loc: storage)) + if torch.cuda.is_available(): + model = model.cuda() + + print("[Info] Start Evaluation on dev and test set...") + #if MY_MODEL: + dev_loader = get_loader(args, "dev" , tokenizer, datasets, unified_meta) + tst_loader = get_loader(args, "test" , tokenizer, datasets, unified_meta, shuffle=args["task_name"]=="rs") + model.eval() + + for d_eval in ["tst"]: #["dev", "tst"]: + f_w = open(os.path.join(args["output_dir"], "{}_results.txt".format(d_eval)), "w") + + # Start evaluating on the test set + test_loss = 0 + preds, labels = [], [] + pbar = tqdm(locals()["{}_loader".format(d_eval)]) + for d in pbar: + with torch.no_grad(): + outputs = model(d) + test_loss += outputs["loss"] + preds += [item for item in outputs["pred"]] + labels += [item for item in outputs["label"]] + #break + + test_loss = test_loss / len(tst_loader) + results = model.evaluation(preds, labels) + print("{} Results: {}".format(d_eval, str(results))) + f_w.write(str(results)) + f_w.close() diff --git a/models/BERT_DST_Picklist.py b/models/BERT_DST_Picklist.py new file mode 100644 index 0000000..4b5ec21 --- /dev/null +++ b/models/BERT_DST_Picklist.py @@ -0,0 +1,321 @@ +import os.path +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import CrossEntropyLoss
+from torch.nn import CosineEmbeddingLoss +import numpy as np + +from transformers import * + +def _gelu(x): + """ Original Implementation of the gelu activation function in Google Bert repo when initially created. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +class BeliefTracker(nn.Module): + def __init__(self, args): + super(BeliefTracker, self).__init__() + + self.args = args + self.n_gpu = args["n_gpu"] + self.hidden_dim = args["hdd_size"] + self.rnn_num_layers = args["num_rnn_layers"] + self.zero_init_rnn = args["zero_init_rnn"] + self.num_direct = 2 if self.args["bidirect"] else 1 + self.num_labels = [len(v) for k, v in args["unified_meta"]["slots"].items()] + self.num_slots = len(self.num_labels) + self.tokenizer = args["tokenizer"] + + self.slots = [k for k, v in self.args["unified_meta"]["slots"].items()] + self.slot_value2id_dict = self.args["unified_meta"]["slots"] + self.slot_id2value_dict = {} + for k, v in self.slot_value2id_dict.items(): + self.slot_id2value_dict[k] = {vv: kk for kk, vv in v.items()} + + #print("self.num_slots", self.num_slots) + + ### Utterance Encoder + self.utterance_encoder = args["model_class"].from_pretrained(self.args["model_name_or_path"]) + + self.bert_output_dim = args["config"].hidden_size + #self.hidden_dropout_prob = self.utterance_encoder.config.hidden_dropout_prob + + if self.args["fix_encoder"]: + print("[Info] Utterance Encoder does not require grad...") + for p in self.utterance_encoder.parameters(): + p.requires_grad = False + + ### slot, slot-value Encoder (not trainable) + self.sv_encoder = args["model_class"].from_pretrained(self.args["model_name_or_path"]) + print("[Info] SV Encoder does not require grad...") + for p in self.sv_encoder.parameters(): + p.requires_grad = False + + #self.slot_lookup = nn.Embedding(self.num_slots, self.bert_output_dim) + self.value_lookup = nn.ModuleList([nn.Embedding(num_label, self.bert_output_dim) for num_label in self.num_labels]) + + ### RNN Belief Tracker + #self.nbt = None + #self.linear = nn.Linear(self.hidden_dim, self.bert_output_dim) + #self.layer_norm = nn.LayerNorm(self.bert_output_dim) + + ### Classifier + self.nll = CrossEntropyLoss(ignore_index=-1) + + ### Etc.
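# The per-slot projection stack defined below is the picklist scorer used in
# forward(): with h the encoder's [CLS] vector and e_v a frozen candidate-value
# embedding for slot s,
#   h'      = gelu(W1_s h)
#   pair_v  = [e_v ; h']                  (2 * bert_output_dim)
#   score_v = W3_s(gelu(W2_s(pair_v)))    (one logit per candidate value)
# i.e. a small per-slot MLP matcher over (value, context) pairs rather than a
# plain dot product; the CrossEntropyLoss above is then applied over the
# per-value scores of each slot.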
+ #self.dropout = nn.Dropout(self.hidden_dropout_prob) + + ### My Add + self.project_W_1 = nn.ModuleList([nn.Linear(self.bert_output_dim, self.bert_output_dim) \ + for _ in range(self.num_slots)]) + self.project_W_2 = nn.ModuleList([nn.Linear(2*self.bert_output_dim, self.bert_output_dim) \ + for _ in range(self.num_slots)]) + self.project_W_3 = nn.ModuleList([nn.Linear(self.bert_output_dim, 1) \ + for _ in range(self.num_slots)]) + + if self.args["gate_supervision_for_dst"]: + self.gate_classifier = nn.Linear(self.bert_output_dim, 2) + + self.start_token = self.tokenizer.cls_token if "bert" in self.args["model_type"] else self.tokenizer.bos_token + self.sep_token = self.tokenizer.sep_token if "bert" in self.args["model_type"] else self.tokenizer.eos_token + + ## Prepare Optimizer + def get_optimizer_grouped_parameters(model): + param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad] + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01, + 'lr': args["learning_rate"]}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, + 'lr': args["learning_rate"]}, + ] + return optimizer_grouped_parameters + + if self.n_gpu == 1: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self) + else: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.module) + + + self.optimizer = AdamW(optimizer_grouped_parameters, + lr=args["learning_rate"],) + #warmup=args["warmup_proportion"], + #t_total=t_total) + + self.initialize_slot_value_lookup() + + def optimize(self): + self.loss_grad.backward() + clip_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), self.args["grad_clip"]) + self.optimizer.step() + + + def initialize_slot_value_lookup(self, max_seq_length=32): + + self.sv_encoder.eval() + + label_ids = [] + for dslot, value_dict in self.args["unified_meta"]["slots"].items(): + label_id = [] + value_dict_rev = {v:k for k, v in value_dict.items()} + for i in range(len(value_dict)): + label = value_dict_rev[i] + label = " ".join([i for i in label.split(" ") if i != ""]) + + label_tokens = [self.start_token] + self.tokenizer.tokenize(label) + [self.sep_token] + label_token_ids = self.tokenizer.convert_tokens_to_ids(label_tokens) + label_len = len(label_token_ids) + + label_padding = [0] * (max_seq_length - len(label_token_ids)) + label_token_ids += label_padding + assert len(label_token_ids) == max_seq_length + label_id.append(label_token_ids) + + label_id = torch.tensor(label_id).long() + label_ids.append(label_id) + + for s, label_id in enumerate(label_ids): + inputs = {"input_ids":label_id, "attention_mask":(label_id > 0).long()} + + if self.args["sum_token_emb_for_value"]: + hid_label = self.utterance_encoder.embeddings(input_ids=label_id).sum(1) + else: + if "bert" in self.args["model_type"]: + hid_label = self.sv_encoder(**inputs)[0] + hid_label = hid_label[:, 0, :] + elif self.args["model_type"] == "gpt2": + hid_label = self.sv_encoder(**inputs)[0] + hid_label = hid_label.mean(1) + elif self.args["model_type"] == "dialogpt": + transformer_outputs = self.sv_encoder.transformer(**inputs)[0] + hid_label = transformer_outputs.mean(1) + + hid_label = hid_label.detach() + self.value_lookup[s] = nn.Embedding.from_pretrained(hid_label, freeze=True) + self.value_lookup[s].padding_idx = -1 + + print("Complete initialization of slot and value lookup") + + def 
forward(self, data):#input_ids, input_len, labels, gate_label, n_gpu=1, target_slot=None): + batch_size = data["context"].size(0) + labels = data["belief_ontology"] + + # Utterance encoding + inputs = {"input_ids": data["context"], "attention_mask":(data["context"] > 0).long()} + + if "bert" in self.args["model_type"]: + hidden = self.utterance_encoder(**inputs)[0] + hidden_rep = hidden[:, 0, :] + elif self.args["model_type"] == "gpt2": + hidden = self.utterance_encoder(**inputs)[0] + hidden_rep = hidden.mean(1) + elif self.args["model_type"] == "dialogpt": + #outputs = self.utterance_encoder(**inputs)[2] # 0 is vocab logits, 1 is a tuple of attn head + transformer_outputs = self.utterance_encoder.transformer( + data["context"], + attention_mask=(data["context"] > 0).long() + ) + hidden = transformer_outputs[0] + hidden_rep = hidden.mean(1) + + # Label (slot-value) encoding + loss = 0 + pred_slot = [] + + if self.args["oracle_domain"]: + + for slot_id in range(self.num_slots): + pred_slot_local = [] + for bsz_i in range(batch_size): + hidden_bsz = hidden[bsz_i, :, :] + + if slot_id in data["triggered_ds_idx"][bsz_i]: + + temp = [i for i, idx in enumerate(data["triggered_ds_idx"][bsz_i]) if idx == slot_id] + assert len(temp) == 1 + ds_pos = data["triggered_ds_pos"][bsz_i][temp[0]] + + hid_label = self.value_lookup[slot_id].weight # v * d + hidden_ds = hidden_bsz[ds_pos, :].unsqueeze(1) # d * 1 + hidden_ds = torch.cat([hidden_ds, hidden_bsz[0, :].unsqueeze(1)], 0) # 2d * 1 + hidden_ds = self.project_W_2[0](hidden_ds.transpose(1, 0)).transpose(1, 0) # d * 1 + + _dist = torch.mm(hid_label, hidden_ds).transpose(1, 0) # 1 * v, 51.6% + + _, pred = torch.max(_dist, -1) + pred_item = pred.item() + + if labels is not None: + + if (self.args["gate_supervision_for_dst"] and labels[bsz_i, slot_id] != 0) or\ + (not self.args["gate_supervision_for_dst"]): + _loss = self.nll(_dist, labels[bsz_i, slot_id].unsqueeze(0)) + loss += _loss + + if self.args["gate_supervision_for_dst"]: + _dist_gate = self.gate_classifier(hidden_ds.transpose(1, 0)) + _loss_gate = self.nll(_dist_gate, data["slot_gate"][bsz_i, slot_id].unsqueeze(0)) + loss += _loss_gate + + if torch.max(_dist_gate, -1)[1].item() == 0: + pred_item = 0 + + pred_slot_local.append(pred_item) + else: + #print("slot_id Not Found") + pred_slot_local.append(0) + + pred_slot.append(torch.tensor(pred_slot_local).unsqueeze(1)) + + predictions = torch.cat(pred_slot, 1).numpy() + labels = labels.detach().cpu().numpy() + + else: + for slot_id in range(self.num_slots): ## note: target_slots are successive + # loss calculation + hid_label = self.value_lookup[slot_id].weight # v * d + num_slot_labels = hid_label.size(0) + + _hidden = _gelu(self.project_W_1[slot_id](hidden_rep)) + _hidden = torch.cat([hid_label.unsqueeze(0).repeat(batch_size, 1, 1), _hidden.unsqueeze(1).repeat(1, num_slot_labels, 1)], dim=2) + _hidden = _gelu(self.project_W_2[slot_id](_hidden)) + _hidden = self.project_W_3[slot_id](_hidden) + _dist = _hidden.squeeze(2) # b * 1 * num_slot_labels + + _, pred = torch.max(_dist, -1) + pred_slot.append(pred.unsqueeze(1)) + #output.append(_dist) + + if labels is not None: + _loss = self.nll(_dist, labels[:, slot_id]) + #loss_slot.append(_loss.item()) + loss += _loss + + predictions = torch.cat(pred_slot, 1).detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + + if self.training: + self.loss_grad = loss + self.optimize() + + if self.args["error_analysis"]: + for bsz_i, (pred, label) in enumerate(zip(np.array(predictions), np.array(labels))): + 
assert len(pred) == len(label) + joint = 0 + pred_arr, gold_arr = [], [] + for i, p in enumerate(pred): + pred_str = self.slot_id2value_dict[self.slots[i]][p] + gold_str = self.slot_id2value_dict[self.slots[i]][label[i]] + pred_arr.append(self.slots[i]+"-"+pred_str) + gold_arr.append(self.slots[i]+"-"+gold_str) + if pred_str == gold_str or pred_str in gold_str.split("|"): + joint += 1 + if joint != len(pred): + print(data["context_plain"][bsz_i]) + print("Gold:", [s for s in gold_arr if s.split("-")[2] != "none"]) + print("Pred:", [s for s in pred_arr if s.split("-")[2] != "none"]) + print() + + + outputs = {"loss":loss.item(), "pred":predictions, "label":labels} + + return outputs + + def evaluation(self, preds, labels): + preds = np.array(preds) + labels = np.array(labels) + + slot_acc, joint_acc, slot_acc_total, joint_acc_total = 0, 0, 0, 0 + for pred, label in zip(preds, labels): + joint = 0 + + assert len(pred) == len(label) + + for i, p in enumerate(pred): + pred_str = self.slot_id2value_dict[self.slots[i]][p] + gold_str = self.slot_id2value_dict[self.slots[i]][label[i]] + + if pred_str == gold_str or pred_str in gold_str.split("|"): + slot_acc += 1 + joint += 1 + slot_acc_total += 1 + + if joint == len(pred): + joint_acc += 1 + + joint_acc_total += 1 + + joint_acc = joint_acc / joint_acc_total + slot_acc = slot_acc / slot_acc_total + results = {"joint_acc":joint_acc, "slot_acc":slot_acc} + print("Results 1: ", results) + + return results + diff --git a/models/__pycache__/BERT_DST_Picklist.cpython-36.pyc b/models/__pycache__/BERT_DST_Picklist.cpython-36.pyc new file mode 100644 index 0000000..06dc191 Binary files /dev/null and b/models/__pycache__/BERT_DST_Picklist.cpython-36.pyc differ diff --git a/models/__pycache__/PTEncoder_Dial.cpython-36.pyc b/models/__pycache__/PTEncoder_Dial.cpython-36.pyc new file mode 100644 index 0000000..cd66c98 Binary files /dev/null and b/models/__pycache__/PTEncoder_Dial.cpython-36.pyc differ diff --git a/models/__pycache__/PTEncoder_RNN.cpython-36.pyc b/models/__pycache__/PTEncoder_RNN.cpython-36.pyc new file mode 100644 index 0000000..478781a Binary files /dev/null and b/models/__pycache__/PTEncoder_RNN.cpython-36.pyc differ diff --git a/models/__pycache__/TRADE.cpython-36.pyc b/models/__pycache__/TRADE.cpython-36.pyc new file mode 100644 index 0000000..3a86e98 Binary files /dev/null and b/models/__pycache__/TRADE.cpython-36.pyc differ diff --git a/models/__pycache__/dual_encoder_ranking.cpython-36.pyc b/models/__pycache__/dual_encoder_ranking.cpython-36.pyc new file mode 100644 index 0000000..1fb209f Binary files /dev/null and b/models/__pycache__/dual_encoder_ranking.cpython-36.pyc differ diff --git a/models/__pycache__/dual_encoder_ranking_v2.cpython-36.pyc b/models/__pycache__/dual_encoder_ranking_v2.cpython-36.pyc new file mode 100644 index 0000000..c3aa00d Binary files /dev/null and b/models/__pycache__/dual_encoder_ranking_v2.cpython-36.pyc differ diff --git a/models/__pycache__/glue_type_train.cpython-36.pyc b/models/__pycache__/glue_type_train.cpython-36.pyc new file mode 100644 index 0000000..cc2a7ef Binary files /dev/null and b/models/__pycache__/glue_type_train.cpython-36.pyc differ diff --git a/models/__pycache__/multi_class_classifier.cpython-36.pyc b/models/__pycache__/multi_class_classifier.cpython-36.pyc new file mode 100644 index 0000000..82f605f Binary files /dev/null and b/models/__pycache__/multi_class_classifier.cpython-36.pyc differ diff --git a/models/__pycache__/multi_label_classifier.cpython-36.pyc 
b/models/__pycache__/multi_label_classifier.cpython-36.pyc new file mode 100644 index 0000000..3c26417 Binary files /dev/null and b/models/__pycache__/multi_label_classifier.cpython-36.pyc differ diff --git a/models/__pycache__/response_selection.cpython-36.pyc b/models/__pycache__/response_selection.cpython-36.pyc new file mode 100644 index 0000000..016f229 Binary files /dev/null and b/models/__pycache__/response_selection.cpython-36.pyc differ diff --git a/models/dual_encoder_ranking.py b/models/dual_encoder_ranking.py new file mode 100644 index 0000000..3249b4b --- /dev/null +++ b/models/dual_encoder_ranking.py @@ -0,0 +1,149 @@ +import os.path +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim + +from torch.nn import CrossEntropyLoss +from torch.nn import CosineEmbeddingLoss +from sklearn.metrics import f1_score #, average_precision_score +import numpy as np + + +from transformers import * + + +class dual_encoder_ranking(nn.Module): + def __init__(self, args): #, num_labels, device): + super(dual_encoder_ranking, self).__init__() + self.args = args + self.xeloss = nn.CrossEntropyLoss() + self.n_gpu = args["n_gpu"] + + ### Utterance Encoder + self.utterance_encoder = args["model_class"].from_pretrained(self.args["model_name_or_path"]) + + if self.args["fix_encoder"]: + for p in self.utterance_encoder.parameters(): + p.requires_grad = False + + ## Prepare Optimizer + def get_optimizer_grouped_parameters(model): + param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad] + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01, + 'lr': args["learning_rate"]}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, + 'lr': args["learning_rate"]}, + ] + return optimizer_grouped_parameters + + if self.n_gpu == 1: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self) + else: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.module) + + + self.optimizer = AdamW(optimizer_grouped_parameters, + lr=args["learning_rate"],) + #warmup=args["warmup_proportion"], + #t_total=t_total) + + def optimize(self): + self.loss_grad.backward() + clip_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), self.args["grad_clip"]) + self.optimizer.step() + + def forward(self, data): + #input_ids, input_len, labels=None, n_gpu=1, target_slot=None): + + self.optimizer.zero_grad() + + batch_size = data["context"].size(0) + #max_seq_len = 256 + + interval = 25 + start_list = list(np.arange(0, batch_size, interval)) + end_list = start_list[1:] + [None] + context_outputs, response_outputs = [], [] + + for start, end in zip(start_list, end_list): + + inputs_con = {"input_ids": data["context"][start:end], + "attention_mask": (data["context"][start:end] > 0).long()} + inputs_res = {"input_ids": data["response"][start:end], + "attention_mask": (data["response"][start:end] > 0).long()} + + if "bert" in self.args["model_type"]: + _, context_output = self.utterance_encoder(**inputs_con) + _, response_output = self.utterance_encoder(**inputs_res) + elif self.args["model_type"] == "gpt2": + context_output = self.utterance_encoder(**inputs_con)[0].mean(1) + response_output = self.utterance_encoder(**inputs_res)[0].mean(1) + elif self.args["model_type"] == "dialogpt": + transformer_outputs = 
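# In-batch negatives: for a batch of B (context, response) pairs the shared
# encoder produces C and R, logits = C @ R^T is a B x B similarity matrix, and
# the gold label of row i is column i (labels = arange(batch_size) below), so
# the other B-1 responses in the batch act as negatives. At eval time, batches
# smaller than 100 are topped up with cached responses from the previous batch
# (the k-to-100 block) so recall@{1,3,5,10} is always measured over 100
# candidates.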
self.utterance_encoder.transformer(**inputs_con) + context_output = transformer_outputs[0].mean(1) + transformer_outputs = self.utterance_encoder.transformer(**inputs_res) + response_output = transformer_outputs[0].mean(1) + + context_outputs.append(context_output.cpu()) + response_outputs.append(response_output.cpu()) + + # evaluation for k-to-100 + if (not self.training) and (batch_size < 100): + response_outputs.append(self.final_response_output[:100-batch_size, :]) + + final_context_output = torch.cat(context_outputs, 0) + final_response_output = torch.cat(response_outputs, 0) + + if torch.cuda.is_available(): + final_context_output = final_context_output.cuda() + final_response_output = final_response_output.cuda() + + if (not self.training): + self.final_response_output = final_response_output.cpu() + + # mat + logits = torch.matmul(final_context_output, final_response_output.transpose(1, 0)) + + # loss + labels = torch.tensor(np.arange(batch_size)) + if torch.cuda.is_available(): labels = labels.cuda() + loss = self.xeloss(logits, labels) + + if self.training: + self.loss_grad = loss + self.optimize() + + predictions = np.argsort(logits.detach().cpu().numpy(), axis=1) #torch.argmax(logits, -1) + + outputs = {"loss":loss.item(), + "pred":predictions, + "label":np.arange(batch_size)} + + return outputs + + def evaluation(self, preds, labels): + assert len(preds) == len(labels) + + preds = np.array(preds) + labels = np.array(labels) + + def _recall_topk(preds_top10, labels, k): + preds = preds_top10[:, -k:] + acc = 0 + for li, label in enumerate(labels): + if label in preds[li]: acc += 1 + acc = acc / len(labels) + return acc + + results = {"top-1": _recall_topk(preds, labels, 1), + "top-3": _recall_topk(preds, labels, 3), + "top-5": _recall_topk(preds, labels, 5), + "top-10": _recall_topk(preds, labels, 10)} + + print(results) + + return results diff --git a/models/multi_class_classifier.py b/models/multi_class_classifier.py new file mode 100644 index 0000000..71600e9 --- /dev/null +++ b/models/multi_class_classifier.py @@ -0,0 +1,159 @@ +import os.path +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim + +from torch.nn import CrossEntropyLoss +from torch.nn import CosineEmbeddingLoss +from sklearn.metrics import f1_score #, average_precision_score +import numpy as np + + +from transformers import * + + +class multi_class_classifier(nn.Module): + def __init__(self, args): #, num_labels, device): + super(multi_class_classifier, self).__init__() + self.args = args + self.hidden_dim = args["hdd_size"] + self.rnn_num_layers = args["num_rnn_layers"] + self.num_labels = args["num_labels"] + self.xeloss = nn.CrossEntropyLoss() + #self.sigmoid = nn.Sigmoid() + self.n_gpu = args["n_gpu"] + + ### Utterance Encoder + self.utterance_encoder = args["model_class"].from_pretrained(self.args["model_name_or_path"]) + + self.bert_output_dim = args["config"].hidden_size + #self.hidden_dropout_prob = self.utterance_encoder.config.hidden_dropout_prob + + if self.args["fix_encoder"]: + print("[Info] Fixing Encoder...") + for p in self.utterance_encoder.parameters(): + p.requires_grad = False + + if self.args["more_linear_mapping"]: + self.one_more_layer = nn.Linear(self.bert_output_dim, self.bert_output_dim) + + self.classifier = nn.Linear(self.bert_output_dim, self.num_labels) + + ## Prepare Optimizer + def get_optimizer_grouped_parameters(model): + param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad] + no_decay = 
['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01, + 'lr': args["learning_rate"]}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, + 'lr': args["learning_rate"]}, + ] + return optimizer_grouped_parameters + + if self.n_gpu == 1: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self) + else: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.module) + + + self.optimizer = AdamW(optimizer_grouped_parameters, + lr=args["learning_rate"],) + #warmup=args["warmup_proportion"], + #t_total=t_total) + + def optimize(self): + self.loss_grad.backward() + clip_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), self.args["grad_clip"]) + self.optimizer.step() + + def forward(self, data): + #input_ids, input_len, labels=None, n_gpu=1, target_slot=None): + + self.optimizer.zero_grad() + + inputs = {"input_ids": data[self.args["input_name"]], "attention_mask":(data[self.args["input_name"]] > 0).long()} + + if self.args["fix_encoder"]: + with torch.no_grad(): + if "gpt2" in self.args["model_type"]: + hidden = self.utterance_encoder(**inputs)[0] + hidden_head = hidden.mean(1) + elif self.args["model_type"] == "dialogpt": + transformer_outputs = self.utterance_encoder.transformer( + inputs["input_ids"], + attention_mask=(inputs["input_ids"] > 0).long())[0] + hidden_head = transformer_outputs.mean(1) + else: + hidden = self.utterance_encoder(**inputs)[0] + hidden_head = hidden[:, 0, :] + else: + if "gpt2" in self.args["model_type"]: + hidden = self.utterance_encoder(**inputs)[0] + hidden_head = hidden.mean(1) + elif self.args["model_type"] == "dialogpt": + transformer_outputs = self.utterance_encoder.transformer( + inputs["input_ids"], + attention_mask=(inputs["input_ids"] > 0).long())[0] + hidden_head = transformer_outputs.mean(1) + else: + hidden = self.utterance_encoder(**inputs)[0] + hidden_head = hidden[:, 0, :] + + # loss + if self.args["more_linear_mapping"]: + hidden_head = self.one_more_layer(hidden_head) + + logits = self.classifier(hidden_head) + loss = self.xeloss(logits, data[self.args["task_name"]]) + + if self.training: + self.loss_grad = loss + self.optimize() + + softmax = nn.Softmax(-1) + predictions = torch.argmax(logits, -1) + + outputs = {"loss":loss.item(), + "pred":predictions.detach().cpu().numpy(), + "label":data[self.args["task_name"]].detach().cpu().numpy(), + "prob":softmax(logits)} + + return outputs + + def evaluation(self, preds, labels): + preds = np.array(preds) + labels = np.array(labels) + + if self.args["task_name"] == "intent": + oos_idx = self.args["unified_meta"]["intent"]["oos"] + acc = (preds == labels).mean() + oos_labels, oos_preds = [], [] + ins_labels, ins_preds = [], [] + for i in range(len(preds)): + if labels[i] != oos_idx: + ins_preds.append(preds[i]) + ins_labels.append(labels[i]) + + oos_labels.append(int(labels[i] == oos_idx)) + oos_preds.append(int(preds[i] == oos_idx)) + + ins_preds = np.array(ins_preds) + ins_labels = np.array(ins_labels) + oos_preds = np.array(oos_preds) + oos_labels = np.array(oos_labels) + ins_acc = (ins_preds == ins_labels).mean() + oos_acc = (oos_preds == oos_labels).mean() + + # for oos samples recall = tp / (tp + fn) + TP = (oos_labels & oos_preds).sum() + FN = ((oos_labels - oos_preds) > 0).sum() + recall = TP / (TP+FN) + results = {"acc":acc, "ins_acc":ins_acc, "oos_acc":oos_acc, "oos_recall":recall} 
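# Note on the OOS intent metrics above: besides overall accuracy, evaluation()
# reports in-scope accuracy (oos examples removed), binary oos-detection
# accuracy, and oos recall = TP / (TP + FN), where labels and predictions are
# first binarized against the oos intent index.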
+ else: + acc = (preds == labels).mean() + results = {"acc":acc} + + return results \ No newline at end of file diff --git a/models/multi_label_classifier.py b/models/multi_label_classifier.py new file mode 100644 index 0000000..d0b7f89 --- /dev/null +++ b/models/multi_label_classifier.py @@ -0,0 +1,119 @@ +import os.path +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim + +from torch.nn import CrossEntropyLoss +from torch.nn import CosineEmbeddingLoss +from sklearn.metrics import f1_score #, average_precision_score +import numpy as np + + +from transformers import * + + +class multi_label_classifier(nn.Module): + def __init__(self, args): #, num_labels, device): + super(multi_label_classifier, self).__init__() + self.args = args + self.hidden_dim = args["hdd_size"] + self.rnn_num_layers = args["num_rnn_layers"] + self.num_labels = args["num_labels"] + self.bce = nn.BCELoss() + self.sigmoid = nn.Sigmoid() + self.n_gpu = args["n_gpu"] + + ### Utterance Encoder + self.utterance_encoder = args["model_class"].from_pretrained(self.args["model_name_or_path"]) + + self.bert_output_dim = args["config"].hidden_size + #self.hidden_dropout_prob = self.utterance_encoder.config.hidden_dropout_prob + + if self.args["fix_encoder"]: + print("[Info] fix_encoder") + for p in self.utterance_encoder.parameters(): + p.requires_grad = False + + if self.args["more_linear_mapping"]: + self.one_more_layer = nn.Linear(self.bert_output_dim, self.bert_output_dim) + + self.classifier = nn.Linear(self.bert_output_dim, self.num_labels) + print("self.classifier", self.bert_output_dim, self.num_labels) + + ## Prepare Optimizer + def get_optimizer_grouped_parameters(model): + param_optimizer = [(n, p) for n, p in model.named_parameters() if p.requires_grad] + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01, + 'lr': args["learning_rate"]}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, + 'lr': args["learning_rate"]}, + ] + return optimizer_grouped_parameters + + if self.n_gpu == 1: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self) + else: + optimizer_grouped_parameters = get_optimizer_grouped_parameters(self.module) + + + self.optimizer = AdamW(optimizer_grouped_parameters, + lr=args["learning_rate"],) + + def optimize(self): + self.loss_grad.backward() + clip_norm = torch.nn.utils.clip_grad_norm_(self.parameters(), self.args["grad_clip"]) + self.optimizer.step() + + def forward(self, data): + #input_ids, input_len, labels=None, n_gpu=1, target_slot=None): + + self.optimizer.zero_grad() + + inputs = {"input_ids": data[self.args["input_name"]], "attention_mask":(data[self.args["input_name"]] > 0).long()} + + if "gpt2" in self.args["model_type"]: + hidden = self.utterance_encoder(**inputs)[0] + hidden_head = hidden.mean(1) + elif self.args["model_type"] == "dialogpt": + transformer_outputs = self.utterance_encoder.transformer( + inputs["input_ids"], + attention_mask=(inputs["input_ids"] > 0).long())[0] + hidden_head = transformer_outputs.mean(1) + else: + hidden = self.utterance_encoder(**inputs)[0] + hidden_head = hidden[:, 0, :] + + # loss + if self.args["more_linear_mapping"]: + hidden_head = self.one_more_layer(hidden_head) + + logits = self.classifier(hidden_head) + prob = self.sigmoid(logits) + loss = self.bce(prob, 
data[self.args["task_name"]]) + + if self.training: + self.loss_grad = loss + self.optimize() + + predictions = (prob > 0.5) + + outputs = {"loss":loss.item(), + "pred":predictions.detach().cpu().numpy(), + "label":data[self.args["task_name"]].detach().cpu().numpy()} + + return outputs + + def evaluation(self, preds, labels): + preds = np.array(preds) + labels = np.array(labels) + results = {} + for avg_name in ['micro', 'macro', 'weighted', 'samples']: + my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name) + results["f1_{}".format(avg_name)] = my_f1_score + + return results + diff --git a/my_tod_pretraining.py b/my_tod_pretraining.py new file mode 100644 index 0000000..45a1674 --- /dev/null +++ b/my_tod_pretraining.py @@ -0,0 +1,1150 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa). +GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned +using a masked language modeling (MLM) loss. +""" + +import argparse +import glob +import logging +import os +import pickle +import random +import re +import shutil +from typing import Tuple +import gzip +import shelve +import json +import math +import faiss + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, TensorDataset +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm, trange +from concurrent.futures import ThreadPoolExecutor +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader, Sampler + +from utils.utils_general import * +from utils.utils_multiwoz import * +from utils.utils_camrest676 import * +from utils.utils_woz import * +from utils.utils_smd import * +from utils.utils_frames import * +from utils.utils_msre2e import * +from utils.utils_taskmaster import * +from utils.utils_metalwoz import * +from utils.utils_schema import * + +from transformers import ( + WEIGHTS_NAME, + AdamW, + BertConfig, + BertModel, + BertForMaskedLM, + BertTokenizer, + CamembertConfig, + CamembertForMaskedLM, + CamembertTokenizer, + DistilBertConfig, + DistilBertForMaskedLM, + DistilBertTokenizer, + GPT2Config, + GPT2LMHeadModel, + GPT2Tokenizer, + OpenAIGPTConfig, + OpenAIGPTLMHeadModel, + OpenAIGPTTokenizer, + PreTrainedTokenizer, + RobertaConfig, + RobertaForMaskedLM, + RobertaTokenizer, + get_linear_schedule_with_warmup, +) + + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + from tensorboardX import SummaryWriter + + +logger = logging.getLogger(__name__) + + +MODEL_CLASSES = { + "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer), + "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), + "bert": (BertConfig, BertForMaskedLM, BertTokenizer), + 
"bert-seq": (BertConfig, BertModel, BertTokenizer), + "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer), + "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer), + "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer), +} + + +def _norm_text(text): + w, *toks = text.strip().split() + try: + w = float(w) + except Exception: + toks = [w] + toks + w = 1.0 + return w, ' '.join(toks) + + +def gelu(x): + """ Original Implementation of the gelu activation function in Google Bert repo when initialy created. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + + +class myTextDataset(Dataset): + def __init__(self, tokenizer, args, text, dtype, block_size=512): + cached_features_file = os.path.join( + "./cached/", args.model_name_or_path.replace("/", "-") + "_cached_lm_" + str(block_size) + "_" + dtype + "_all" + ) + print("cached_features_file", cached_features_file) + + if not os.path.exists("./cached"): os.mkdir("./cached") + + if os.path.exists(cached_features_file) and not args.overwrite_cache: + logger.info("Loading features from cached file %s", cached_features_file) + with open(cached_features_file, "rb") as handle: + self.examples = pickle.load(handle) + + else: + self.examples = [] + + tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) + + for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size + self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])) + # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) + # If your dataset is small, first you should loook for a bigger one :-) and second you + # can change this behavior by adding (model specific) padding. 
+ + logger.info("Saving features into cached file %s", cached_features_file) + with open(cached_features_file, "wb") as handle: + pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, item): + return torch.tensor(self.examples[item]) + + +def my_load_and_cache_examples(args, tokenizer, text, dtype): + dataset = myTextDataset( + tokenizer, + args, + text, + dtype, + block_size=args.block_size, + ) + return dataset + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False): + if not args.save_total_limit: + return + if args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix))) + if len(glob_checkpoints) <= args.save_total_limit: + return + + ordering_and_checkpoint_path = [] + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) + else: + regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path) + if regex_match and regex_match.groups(): + ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint)) + shutil.rmtree(checkpoint) + + +def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]: + """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
""" + + inputs = inputs.to("cpu") + + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + probability_matrix = torch.full(labels.shape, args.mlm_probability) + special_tokens_mask = [ + tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist() + ] + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + + # padding position value = 0 + inputs_pad_pos = (inputs == 0).cpu() + probability_matrix.masked_fill_(inputs_pad_pos, value=0.0) + + masked_indices = torch.bernoulli(probability_matrix).bool() + try: + labels[~masked_indices] = -100 # We only compute loss on masked tokens + except: + masked_indices = masked_indices.byte() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + try: + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + except: + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool().byte() & masked_indices + inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + try: + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) + if inputs.is_cuda: + indices_random = indices_random.to(args.device) + random_words = random_words.to(args.device) + inputs[indices_random] = random_words[indices_random] + except: + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool().byte() & masked_indices & ~indices_replaced + random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) + if inputs.is_cuda: + indices_random = indices_random.to(args.device) + random_words = random_words.to(args.device) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + + +def mask_for_response_selection(batch, tokenizer, args, cand_uttr_sys_dict, others): + + inputs = batch if args.concat_all_data else batch["context"] + inputs = inputs.to("cpu") + + batch_size = inputs.size(0) + probability_matrix = torch.full(inputs.shape, 1) + usr_token_idx = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(args.usr_token))[0] + sys_token_idx = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(args.sys_token))[0] + + cand_uttr_sys = list(cand_uttr_sys_dict.keys()) + cand_uttr_sys_tokens = list(cand_uttr_sys_dict.values()) + + last_sys_position, last_usr_position = [], [] + for bsz_i, batch_sample in enumerate(inputs): + nb_sys_token = len((batch_sample == sys_token_idx).nonzero()) + nb_usr_token = len((batch_sample == usr_token_idx).nonzero()) + + if nb_sys_token == 0 or nb_usr_token == 0: + last_sys_position.append(len(batch_sample)//2) + last_usr_position.append(len(batch_sample)) + else: + if nb_sys_token > 2 and nb_usr_token > 2: + rand_pos = random.randint(1, min(nb_sys_token, nb_usr_token)-1) + else: + rand_pos = -1 + + temp1 = (batch_sample == sys_token_idx).nonzero()[rand_pos][0].item() + last_sys_position.append(temp1) + temp2 = (batch_sample == usr_token_idx).nonzero()[rand_pos][0].item() + + if temp2 > 
temp1: + last_usr_position.append(temp2) + else: + if temp1 + 10 < len(batch_sample): + last_usr_position.append(temp1 + 10) + else: + last_usr_position.append(len(batch_sample)) + + set_max_resp_len = 150 + + last_usr_position = np.array(last_usr_position) + last_sys_position = np.array(last_sys_position) + max_last_sys_position = max(last_sys_position) + max_response_len = max(last_usr_position-last_sys_position) + 1 + max_response_len = max_response_len if max_response_len < set_max_resp_len else set_max_resp_len + + input_contexts = torch.zeros(batch_size, max_last_sys_position).long()#.to(args.device) + input_responses = torch.zeros(batch_size, max_response_len).long()#.to(args.device) + output_labels = torch.tensor(np.arange(batch_size)).long()#.to(args.device) + + responses = [] + for bsz_i, (sys_pos, usr_pos) in enumerate(zip(last_sys_position, last_usr_position)): + input_contexts[bsz_i, :sys_pos] = inputs[bsz_i, :sys_pos] + input_responses[bsz_i, 0] = inputs[bsz_i, 0] + responses.append(tokenizer.decode(inputs[bsz_i, sys_pos+1:usr_pos]).replace(" ", "")) + s, e = (sys_pos, usr_pos) if usr_pos-sys_pos < max_response_len else (sys_pos, sys_pos+max_response_len-1) + input_responses[bsz_i, 1:e-s+1] = inputs[bsz_i, s:e] + + if args.negative_sampling_by_kmeans: + candidates_tokens = [] + for ri, resp in enumerate(responses): + if resp in others["ToD_BERT_SYS_UTTR_KMEANS"].keys(): + cur_cluster = others["ToD_BERT_SYS_UTTR_KMEANS"][resp] + candidates = others["KMEANS_to_SENTS"][cur_cluster] + nb_selected = min(args.nb_neg_sample_rs, len(candidates)-1) + start_pos = random.randint(0, len(candidates)-nb_selected-1) + sampled_neg_resps = candidates[start_pos:start_pos+nb_selected] + candidates_tokens += [cand_uttr_sys_dict[r] for r in sampled_neg_resps] + else: + start_pos = random.randint(0, len(cand_uttr_sys)-args.nb_neg_sample_rs-1) + candidates_tokens += cand_uttr_sys_tokens[start_pos:start_pos+args.nb_neg_sample_rs] + else: + candidates_tokens = [] + for i in range(args.nb_negative_samples): + pos = random.randint(0, len(cand_uttr_sys_tokens)-1) + candidates_tokens.append(cand_uttr_sys_tokens[pos]) + + input_responses_neg = torch.zeros(len(candidates_tokens), max_response_len).long() + for i in range(len(candidates_tokens)): + if len(candidates_tokens[i]) > input_responses.size(1): + input_responses_neg[i] = candidates_tokens[i][:input_responses.size(1)] + else: + input_responses_neg[i, :len(candidates_tokens[i])] = candidates_tokens[i] + + input_responses = torch.cat([input_responses, input_responses_neg], 0) + + return input_contexts, input_responses, output_labels + + +def get_candidate_embeddings(uttr_sys_dict, tokenizer, model): + + print("Start obtaining representations from model...") + + uttr_sys = list(uttr_sys_dict.keys()) + uttr_sys_tokens = list(uttr_sys_dict.values()) + + ToD_BERT_SYS_UTTR_EMB = {} + batch_size = 100 + for start in tqdm(range(0, len(uttr_sys), batch_size)): #len(uttr_sys) + if start+batch_size > len(uttr_sys): + inputs = uttr_sys[start:] + inputs_ids = uttr_sys_tokens[start:] + else: + inputs = uttr_sys[start:start+batch_size] + inputs_ids = uttr_sys_tokens[start:start+batch_size] + + inputs_ids = pad_sequence(inputs_ids, batch_first=True, padding_value=0) + if torch.cuda.is_available(): inputs_ids = inputs_ids.cuda() + + with torch.no_grad(): + outputs = model.bert(input_ids=inputs_ids, attention_mask=inputs_ids>0) + sequence_output = outputs[0] + cls_rep = sequence_output[:, 0, :] + #cls_rep = pool_out + + for i in range(cls_rep.size(0)): + 
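+            # Key by the whitespace-stripped sentence: the kmeans tables built from
+            # these embeddings are later queried with tokenizer.decode(...).replace(" ", ""),
+            # so stripping spaces here keeps lookups robust to decoding differences.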
ToD_BERT_SYS_UTTR_EMB[inputs[i].replace(" ", "")] = { + "sent":inputs[i], + "emb":cls_rep[i, :].cpu().numpy() + } + return ToD_BERT_SYS_UTTR_EMB + +def get_candidate_kmeans(args, uttr_sys_dict, tokenizer, model): + ToD_BERT_SYS_UTTR_EMB = get_candidate_embeddings(uttr_sys_dict, tokenizer, model) + + print("Start computing kmeans...") + ToD_BERT_SYS_UTTR_KMEANS = {} + KMEANS_to_SENTS = {i:[] for i in range(args.nb_kmeans)} + + # faiss + data = [v["emb"] for v in ToD_BERT_SYS_UTTR_EMB.values()] + data = np.array(data) + kmeans_1k = faiss.Kmeans(data.shape[1], args.nb_kmeans, niter=20, nredo=5, verbose=True) + kmeans_1k.train(data) + D, I = kmeans_1k.index.search(data, 1) + for i, key in enumerate(ToD_BERT_SYS_UTTR_EMB.keys()): + ToD_BERT_SYS_UTTR_KMEANS[key] = I[i][0] + KMEANS_to_SENTS[I[i][0]].append(ToD_BERT_SYS_UTTR_EMB[key]["sent"]) + + return ToD_BERT_SYS_UTTR_KMEANS, KMEANS_to_SENTS + + +def train(args, trn_loader, dev_loader, model, tokenizer, cand_uttr_sys_dict, others): + """ Train the model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter("runs/"+args.output_dir.replace("/","-")) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(trn_loader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(trn_loader) // args.gradient_accumulation_steps * args.num_train_epochs + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.weight_decay, + }, + {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total + ) + + # Check if saved optimizer or scheduler states exist + if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( + os.path.join(args.model_name_or_path, "scheduler.pt") + ): + # Load in optimizer and scheduler states + optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) + scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) + + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True + ) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Num batches = %d", len(trn_loader)) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info( + " Total train batch size (w. 
parallel, distributed & accumulation) = %d", + args.train_batch_size + * args.gradient_accumulation_steps + * (torch.distributed.get_world_size() if args.local_rank != -1 else 1), + ) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + global_step = 0 + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + + tr_loss, logging_loss = 0.0, 0.0 + loss_mlm, loss_rs = 0, 0 + patience, best_loss = 0, 1e10 + xeloss = torch.nn.CrossEntropyLoss() + + model_to_resize = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training + model_to_resize.resize_token_embeddings(len(tokenizer)) + + model.zero_grad() + train_iterator = trange( + epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0] + ) + set_seed(args) # Added here for reproducibility + for _ in train_iterator: + + if args.negative_sampling_by_kmeans: + ToD_BERT_SYS_UTTR_KMEANS, KMEANS_to_SENTS = get_candidate_kmeans(args, cand_uttr_sys_dict, tokenizer, model) + trn_loader = get_loader(vars(args), "train", tokenizer, others["datasets"], others["unified_meta"], "train") + + epoch_iterator = tqdm(trn_loader, disable=args.local_rank not in [-1, 0]) + for step, batch in enumerate(epoch_iterator): + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + continue + + if args.add_rs_loss: # add response selection into pretraining + + if args.negative_sampling_by_kmeans: + kmeans_others = {"ToD_BERT_SYS_UTTR_KMEANS":ToD_BERT_SYS_UTTR_KMEANS, + "KMEANS_to_SENTS":KMEANS_to_SENTS} + else: + kmeans_others = {} + + input_cont, input_resp, resp_label = mask_for_response_selection(batch, + tokenizer, + args, + cand_uttr_sys_dict, + kmeans_others) + + input_cont, labels = mask_tokens(input_cont, tokenizer, args) if args.mlm else (input_cont, input_cont) + + input_cont = input_cont.to(args.device) + input_resp = input_resp.to(args.device) + resp_label = resp_label.to(args.device) + labels = labels.to(args.device) + + outputs = model.bert( + input_cont, + attention_mask=input_cont>0, + ) + sequence_output = outputs[0] + hid_cont = sequence_output[:, 0, :] + prediction_scores = model.cls(sequence_output) + + loss = xeloss(prediction_scores.view(-1, model.config.vocab_size), labels.view(-1)) + loss_mlm = loss.item() + + outputs = model.bert( + input_resp, + attention_mask=input_resp>0, + ) + sequence_output = outputs[0] + hid_resp = sequence_output[:, 0, :] + + scores = torch.matmul(hid_cont, hid_resp.transpose(1, 0)) + + loss_rs = xeloss(scores, resp_label) + loss += loss_rs + loss_rs = loss_rs.item() + + else: + inputs = batch if args.concat_all_data else batch["context"].clone() + model.train() + if args.mlm: + inputs, labels = mask_tokens(inputs, tokenizer, args) + + inputs = inputs.to(args.device) + labels = labels.to(args.device) + + outputs = model(inputs, + masked_lm_labels=labels, + attention_mask=inputs>0) + else: + labels = inputs.clone() + masked_indices = (labels == 0) + labels[masked_indices] = -100 # We only compute loss on masked tokens + outputs = model(inputs, labels=labels) + + loss = outputs[0] # model outputs are always tuple in transformers (see doc) + loss_mlm = loss.item() + + epoch_iterator.set_description("Loss:{:.4f} MLM:{:.4f} RS:{:.4f}".format(loss.item(), + loss_mlm, + loss_rs)) + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + 
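+            # With accumulation, each backward() below adds into .grad and the optimizer
+            # only steps every gradient_accumulation_steps batches; dividing the loss
+            # keeps the accumulated gradient an average rather than a sum, so e.g.
+            # batch_size=8 with 4 accumulation steps behaves like batch_size=32.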
if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step) + tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step) + logging_loss = tr_loss + + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + + if args.evaluate_during_training and args.n_gpu == 1: + results = evaluate(args, model, dev_loader, tokenizer) + for key, value in results.items(): + tb_writer.add_scalar("eval_{}".format(key), value, global_step) + else: + results = {} + results["loss"] = best_loss - 0.1 # always saving + + if results["loss"] < best_loss: + patience = 0 + best_loss = results["loss"] + + checkpoint_prefix = "checkpoint" + # Save model checkpoint + output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = ( + model.module if hasattr(model, "module") else model + ) # Take care of distributed/parallel training + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + torch.save(args, os.path.join(output_dir, "training_args.bin")) + logger.info("Saving model checkpoint to %s", output_dir) + + _rotate_checkpoints(args, checkpoint_prefix) + + torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) + torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) + logger.info("Saving optimizer and scheduler states to %s", output_dir) + else: + patience += 1 + logger.info("Current patience: patience {}".format(patience)) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + if patience > args.patience: + logger.info("Ran out of patience...") + break + + if (args.max_steps > 0 and global_step > args.max_steps) or patience > args.patience: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + return global_step, tr_loss / global_step + + +def evaluate(args, model, dev_loader, tokenizer, prefix=""): + eval_output_dir = args.output_dir + + if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: + os.makedirs(eval_output_dir) + + eval_dataloader = dev_loader + + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Eval! 
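+    # Note: len(eval_dataloader) below counts batches, not individual examples, and
+    # the perplexity reported at the end is exp(mean masked-LM loss); e.g. a mean
+    # loss of 2.0 corresponds to a perplexity of about 7.39.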
+ logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(eval_dataloader)) + logger.info(" Batch size = %d", args.eval_batch_size) + eval_loss = 0.0 + nb_eval_steps = 0 + model.eval() + + for batch in tqdm(eval_dataloader, desc="Evaluating"): + + inputs = batch if args.concat_all_data else batch["context"].clone() + + #inputs, labels = mask_tokens(inputs, tokenizer, args) if args.mlm else (inputs, inputs) + if args.mlm: + inputs, labels = mask_tokens(inputs, tokenizer, args) + else: + labels = inputs.clone() + masked_indices = (labels == 0) + labels[masked_indices] = -100 + + inputs = inputs.to(args.device) + labels = labels.to(args.device) + + with torch.no_grad(): + outputs = model(inputs, + masked_lm_labels=labels, + attention_mask=inputs>0) if args.mlm else model(inputs, labels=labels) + + lm_loss = outputs[0] + eval_loss += lm_loss.mean().item() + nb_eval_steps += 1 + + eval_loss = eval_loss / nb_eval_steps + perplexity = torch.exp(torch.tensor(eval_loss)) + + result = {"perplexity": perplexity, "loss":eval_loss} + + output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results {} *****".format(prefix)) + for key in sorted(result.keys()): + logger.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) + + return result + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + #parser.add_argument( + # "--train_data_file", default=None, type=str, required=True, help="The input training data file (a text file)." + #) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--model_type", default="bert", type=str, help="The model architecture to be fine-tuned.") + parser.add_argument( + "--model_name_or_path", + default="bert-base-uncased", + type=str, + help="The model checkpoint for weights initialization.", + ) + parser.add_argument( + "--mlm", action="store_true", help="Train with masked-language modeling loss instead of language modeling." + ) + parser.add_argument( + "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss" + ) + parser.add_argument( + "--config_name", + default="", + type=str, + help="Optional pretrained config name or path if not the same as model_name_or_path", + ) + parser.add_argument( + "--tokenizer_name", + default="", + type=str, + help="Optional pretrained tokenizer name or path if not the same as model_name_or_path", + ) + parser.add_argument( + "--cache_dir", + default="", + type=str, + help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)", + ) + parser.add_argument( + "--block_size", + default=-1, + type=int, + help="Optional input sequence length after tokenization." + "The training dataset will be truncated in block of this size for training." + "Default to the model max input length for single sentence inputs (take into account special tokens).", + ) + parser.add_argument("--do_train", action="store_true", help="Whether to run training.") + parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") + parser.add_argument( + "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." 
+ ) + parser.add_argument( + "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." + ) + + parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for training.") + parser.add_argument( + "--per_gpu_eval_batch_size", default=4, type=int, help="Batch size per GPU/CPU for evaluation." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--num_train_epochs", default=300, type=int, help="Total number of training epochs to perform." + ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") + + parser.add_argument("--logging_steps", type=int, default=100, help="Log every X updates steps.") + parser.add_argument("--save_steps", type=int, default=100, help="Save checkpoint every X updates steps.") + parser.add_argument( + "--save_total_limit", + type=int, + default=1, + help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default", + ) + parser.add_argument( + "--eval_all_checkpoints", + action="store_true", + help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number", + ) + parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") + parser.add_argument( + "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" + ) + parser.add_argument( + "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" + ) + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + + parser.add_argument( + "--fp16", + action="store_true", + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", + ) + parser.add_argument( + "--fp16_opt_level", + type=str, + default="O1", + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
+ "See details at https://nvidia.github.io/apex/amp.html", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") + parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") + + # My add + parser.add_argument("--data_source", type=str, default="", help="For distant debugging.") + parser.add_argument("--patience", type=int, default=20, help="earlystop") + parser.add_argument("--db_folder", type=str, default="./RedditDB/", help="") + parser.add_argument( + "--shuffle_dial_during_training", + action="store_true", + help="", + ) + parser.add_argument("--db_ratio", type=float, default=0.1, help="") + parser.add_argument( + "--multithread", + action="store_true", + help="", + ) + parser.add_argument( + '-ds','--dataset', + help='which dataset to be used.', + required=False, + #default='["multiwoz"]', + default='["multiwoz", "camrest676", "woz", "smd", "frames", "msre2e", "taskmaster", "metalwoz", "schema"]', + type=str) + parser.add_argument( + '--example_type', + help='type in ["turn", "dial"]', + required=False, + default="turn") + parser.add_argument( + '--max_line', + help='maximum line for reading data (for quick testing)', + required=False, + default=None, + type=int) + parser.add_argument( + '-dpath','--data_path', + help='path to dataset folder', + required=False, + default='/export/home/dialog_datasets', + type=str) + parser.add_argument( + "--train_data_ratio", + default=1.0, + type=float, + help="") + parser.add_argument( + "--ratio_by_random", + action="store_true", + help="") + parser.add_argument( + "--domain_act", + action="store_true", + help="") + parser.add_argument( + '-task','--task', + help='task in ["nlu", "dst", "dm", "nlg", "e2e"] to decide which dataloader to use', + required=True) + parser.add_argument( + '-task_name', '--task_name', + help='', + required=False, + default="") + parser.add_argument( + '--usr_token', + help='', + required=False, + default="[USR]", + type=str) + parser.add_argument( + '--sys_token', + help='', + required=False, + default="[SYS]", + type=str) + parser.add_argument( + "--add_rs_loss", + action="store_true", + help="") + parser.add_argument( + "--only_last_turn", + action="store_true", + help="") + parser.add_argument( + "--concat_all_data", + action="store_true", + help="") + parser.add_argument( + "--oracle_domain", + action="store_true", + help="",) + parser.add_argument( + "--ontology_version", + default="", + type=str, + help="['', '1.0']") + parser.add_argument( + "--dstlm", + action="store_true", + help="",) + parser.add_argument( + "--max_seq_length", + default=512, + type=int, + help="") + parser.add_argument( + "--nb_negative_samples", + default=0, + type=int, + help="") + parser.add_argument( + "--negative_sampling_by_kmeans", + action="store_true", + help="",) + parser.add_argument( + "--nb_kmeans", + default=500, + type=int, + help="") + parser.add_argument( + "--nb_neg_sample_rs", + default=0, + type=int, + help="") + parser.add_argument( + "--nb_shots", + default=-1, + type=int, + help="") + + args = parser.parse_args() + args_dict = vars(args) + + if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm: + raise ValueError( + "BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm " + "flag (masked language modeling)." 
+ ) + #if args.eval_data_file is None and args.do_eval: + # raise ValueError( + # "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " + # "or remove the --do_eval argument." + # ) + + if ( + os.path.exists(args.output_dir) + and os.listdir(args.output_dir) + and args.do_train + and not args.overwrite_output_dir + ): + raise ValueError( + "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format( + args.output_dir + ) + ) + + # Setup distant debugging if needed + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend="nccl") + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, + ) + logger.warning( + "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, + device, + args.n_gpu, + bool(args.local_rank != -1), + args.fp16, + ) + + # Set seed + set_seed(args) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab + + config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + config = config_class.from_pretrained( + args.config_name if args.config_name else args.model_name_or_path, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + #config.output_hidden_states = True + + tokenizer = tokenizer_class.from_pretrained( + args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + + if args.block_size <= 0: + args.block_size = ( + tokenizer.max_len_single_sentence + ) # Our input block size will be the max possible for the model + args.block_size = min(args.block_size, tokenizer.max_len_single_sentence) + model = model_class.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + cache_dir=args.cache_dir if args.cache_dir else None, + ) + model.to(args.device) + + # Add new tokens to the vocabulary and embeddings of our model + tokenizer.add_tokens([args.sys_token, args.usr_token]) + model.resize_token_embeddings(len(tokenizer)) + + if args.local_rank == 0: + torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab + + logger.info("Training/evaluation parameters %s", args) + + # Training + if args.do_train: + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() + # Barrier to make sure only the first process in distributed training process 
the dataset, + # and the others will use the cache + + #train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False) + + datasets = {} + cand_uttr_sys = set() + for ds_name in ast.literal_eval(args.dataset): + data_trn, data_dev, data_tst, data_meta = globals()["prepare_data_{}".format(ds_name)](args_dict) + + # held-out mwoz for now + if ds_name == "multiwoz": + datasets[ds_name] = {"train": data_trn, "dev":data_dev, "test": data_tst, "meta":data_meta} + else: + datasets[ds_name] = {"train": data_trn + data_dev + data_tst, "dev":[], "test": [], "meta":data_meta} + + + for d in datasets[ds_name]["train"]: + cand_uttr_sys.add(d["turn_sys"]) + cand_uttr_sys.update(set([sent for si, sent in enumerate(d["dialog_history"]) if si%2==0])) + + unified_meta = get_unified_meta(datasets) + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + + # obtain candidate responses + if args.nb_negative_samples > 0: + cand_uttr_sys = list(cand_uttr_sys) + cand_uttr_sys = [s.lower() for s in cand_uttr_sys if len(s.split(" ")) <= 100] # remove too long responses + cand_uttr_sys_tokens = [] + for cand in tqdm(cand_uttr_sys): + cand_ids = tokenizer.tokenize("[CLS] [SYS]") + tokenizer.tokenize(cand) + cand_ids = torch.tensor(tokenizer.convert_tokens_to_ids(cand_ids)) + cand_uttr_sys_tokens.append(cand_ids) + cand_uttr_sys_dict = {a:b for a, b in zip(cand_uttr_sys, cand_uttr_sys_tokens)} + else: + cand_uttr_sys_dict = {} + print("len of cand_uttr_sys_dict:", len(cand_uttr_sys_dict)) + + args_dict["batch_size"] = args.train_batch_size + args_dict["eval_batch_size"] = args.eval_batch_size + + # Create Dataloader + trn_loader = get_loader(args_dict, "train", tokenizer, datasets, unified_meta, "train") + dev_loader = get_loader(args_dict, "dev" , tokenizer, datasets, unified_meta, "dev") + + others = {} + if args.negative_sampling_by_kmeans: + others["datasets"] = datasets + others["unified_meta"] = unified_meta + + if args.local_rank == 0: + torch.distributed.barrier() + + global_step, tr_loss = train(args, trn_loader, dev_loader, model, tokenizer, cand_uttr_sys_dict, others) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + + # Evaluation + results = {} + if args.do_eval and args.local_rank in [-1, 0]: + #checkpoints = [args.output_dir] + #if args.eval_all_checkpoints: + checkpoints = list( + os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True)) + ) + logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else "" + prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else "" + + model = model_class.from_pretrained(checkpoint) + model.to(args.device) + result = evaluate(args, model, dev_loader, tokenizer, prefix=prefix) + result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) + results.update(result) + + print(results) + return results + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/run_tod_lm_pretraining.sh b/run_tod_lm_pretraining.sh new file mode 100755 index 0000000..d468e6c --- /dev/null +++ b/run_tod_lm_pretraining.sh @@ -0,0 +1,26 @@ +gpu=$1 +model_type=$2 +bert_dir=$3 +output_dir=$4 +add1=$5 +add2=$6 +add3=$7 +add4=$8 +add5=$9 + +# 
./run_tod_lm_pretraining.sh 0 bert bert-base-uncased save/pretrain/ToD-BERT-MLM --only_last_turn
+# ./run_tod_lm_pretraining.sh 0 bert bert-base-uncased save/pretrain/ToD-BERT-JNT --only_last_turn --add_rs_loss
+
+CUDA_VISIBLE_DEVICES=$gpu python my_tod_pretraining.py \
+    --task=usdl \
+    --model_type=${model_type} \
+    --model_name_or_path=${bert_dir} \
+    --output_dir=${output_dir} \
+    --do_train \
+    --do_eval \
+    --mlm \
+    --do_lower_case \
+    --evaluate_during_training \
+    --save_steps=2500 --logging_steps=1000 \
+    --per_gpu_train_batch_size=8 --per_gpu_eval_batch_size=8 \
+    ${add1} ${add2} ${add3} ${add4} ${add5}
\ No newline at end of file
[binary diff entries for 21 utils/__pycache__/*.cpython-36.pyc files omitted: compiled Python caches with no reviewable content]
diff --git a/utils/config.py b/utils/config.py
new file mode 100755
index 0000000..55fea77
--- /dev/null
+++ b/utils/config.py
@@ -0,0 +1,197 @@
+import os
+import logging
+import argparse
+from tqdm import tqdm
+import torch
+import numpy as np
+
+
+parser = argparse.ArgumentParser(description='Task-oriented Dialogue System Benchmarking')
+
+
+## Training Setting
+parser.add_argument(
+    '--do_train', action='store_true', help="do training")
+parser.add_argument(
+    '-epoch','--epoch', help='number of epochs to train', required=False, default=300, type=int)
+parser.add_argument(
+    '-patience','--patience', help='patience for early stopping', required=False, default=10, type=int)
+parser.add_argument(
+    '-earlystop','--earlystop', help='metric for early stopping', required=False, default="loss", type=str)
+parser.add_argument(
+    '--my_model', help='my customized model', required=False, default="")
+parser.add_argument(
+    '-dr','--dropout', help='Dropout ratio', required=False, type=float, default=0.2)
+parser.add_argument(
+    '-lr','--learning_rate', help='Learning Rate', required=False, type=float, default=5e-5)
+parser.add_argument(
+    '-bsz','--batch_size', help='Batch size', required=False, type=int, default=16)
+parser.add_argument(
+    '-ebsz','--eval_batch_size', help='Evaluation batch size', required=False, type=int, default=16)
+parser.add_argument(
+    '-hdd','--hdd_size', help='Hidden size', required=False, type=int, default=400)
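+# Example invocation (hypothetical values; every flag shown is defined in this file):
+#   python main.py --do_train --task=nlu --task_name=intent --example_type=turn \
+#       --model_type=bert --model_name_or_path=bert-base-uncased \
+#       --batch_size=16 --earlystop=acc --output_dir=save/example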
+parser.add_argument(
+    '-emb','--emb_size', help='Embedding size', required=False, type=int, default=400)
+parser.add_argument(
+    '-clip','--grad_clip', help='gradient clipping', required=False, default=1, type=int)
+parser.add_argument(
+    '-tfr','--teacher_forcing_ratio', help='teacher_forcing_ratio', type=float, required=False, default=0.5)
+parser.add_argument(
+    '-loadEmb','--load_embedding', help='Load Pretrained Glove and Char Embeddings', required=False, default=False, type=bool)
+parser.add_argument(
+    '-fixEmb','--fix_embedding', help='fix (do not update) the loaded embeddings', required=False, default=False, type=bool)
+parser.add_argument(
+    '--n_gpu', help='number of GPUs to use', required=False, default=1, type=int)
+parser.add_argument(
+    '--eval_by_step', help='run evaluation every this many training steps', required=False, default=-1, type=int)
+parser.add_argument(
+    '--fix_encoder', action='store_true', help="fix (do not update) the encoder parameters")
+parser.add_argument(
+    '--model_type', help='model architecture type, e.g. "bert"', required=False, default="bert", type=str)
+parser.add_argument(
+    '--model_name_or_path', help='pretrained checkpoint to initialize from', required=False, default="bert", type=str)
+parser.add_argument(
+    '--usr_token', help='special token marking user utterances', required=False, default="[USR]", type=str)
+parser.add_argument(
+    '--sys_token', help='special token marking system utterances', required=False, default="[SYS]", type=str)
+parser.add_argument(
+    '--warmup_proportion', help='warm up training in the beginning', required=False, default=0.1, type=float)
+parser.add_argument(
+    "--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+parser.add_argument(
+    "--gradient_accumulation_steps", type=int, default=1,
+    help="Number of updates steps to accumulate before performing a backward/update pass.",)
+parser.add_argument(
+    "--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
+parser.add_argument(
+    "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
+parser.add_argument(
+    "--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
+parser.add_argument(
+    "--fp16", action="store_true", help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",)
+parser.add_argument(
+    "--fp16_opt_level", type=str, default="O1",
+    help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
+    "See details at https://nvidia.github.io/apex/amp.html",)
+parser.add_argument(
+    "--output_mode", default="classification", type=str, help="")
+parser.add_argument(
+    "--max_steps", default=-1, type=int, help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.",) +parser.add_argument( + "--rand_seed", default=0, type=int, help="") +parser.add_argument( + "--fix_rand_seed", action="store_true", help="fix random seed for training",) +parser.add_argument( + "--nb_runs", default=1, type=int, help="number of runs to conduct during training") +parser.add_argument( + "--nb_evals", default=1, type=int, help="number of runs to conduct during inference") +parser.add_argument( + "--max_seq_length", default=512, type=int, help="") +parser.add_argument( + "--input_name", default="context", type=str, help="") + + +## Dataset or Input/Output Setting +parser.add_argument( + '-dpath','--data_path', help='path to dataset folder, need to change to your local folder', + required=False, default='/export/home/dialog_datasets', type=str) +parser.add_argument( + '-task','--task', help='task in ["nlu", "dst", "dm", "nlg", "usdl"] to decide which dataloader to use', required=True) +parser.add_argument( + '-task_name','--task_name', help='task in ["intent", "sysact","rs"]', required=False, default="") +parser.add_argument( + '--example_type', help='type in ["turn", "dial"]', required=False, default="turn") +parser.add_argument( + '-ds','--dataset', help='which dataset to be used.', required=False, default='["multiwoz"]', type=str) +parser.add_argument( + '-load_path','--load_path', help='path of the saved model to load from', required=False) +parser.add_argument( + '-an','--add_name', help='An added name for the save folder', required=False, default='') +parser.add_argument( + '--max_line', help='maximum line for reading data (for quick testing)', required=False, default=None, type=int) +parser.add_argument( + '--output_dir', help='', required=False, default="save/temp/", type=str) +parser.add_argument( + '--overwrite', action='store_true', help="") +parser.add_argument( + "--cache_dir", default=None, type=str, + help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)",) +parser.add_argument( + "--logging_steps", default=500, type=int, help="") +parser.add_argument( + "--save_steps", default=1000, type=int, help="") +parser.add_argument( + "--save_total_limit", type=int, default=1, + help="Limit the total amount of checkpoints, delete the older checkpoints in the output_dir",) +parser.add_argument( + "--train_data_ratio", default=1.0, type=float, help="") +parser.add_argument( + "--domain_act", action="store_true", help="",) +parser.add_argument( + "--only_last_turn", action="store_true", help="",) +parser.add_argument( + "--error_analysis", action="store_true", help="",) +parser.add_argument( + "--not_save_model", action="store_true", help="") +parser.add_argument( + "--nb_shots", default=-1, type=int, help="") + + +## Others (May be able to delete or not used in this repo) +parser.add_argument( + '--do_embeddings', action='store_true') +parser.add_argument( + '--create_own_vocab', action='store_true', help="") +parser.add_argument( + '-um','--unk_mask', help='mask out input token to UNK', type=bool, required=False, default=True) +parser.add_argument( + '-paral','--parallel_decode', help='', required=False, default=True, type=bool) +parser.add_argument( + '--self_supervised', help='', required=False, default="generative", type=str) +parser.add_argument( + "--oracle_domain", action="store_true", help="",) +parser.add_argument( + "--more_linear_mapping", action="store_true", help="",) +parser.add_argument( + "--gate_supervision_for_dst", action="store_true", help="",) +parser.add_argument( + 
"--sum_token_emb_for_value", action="store_true", help="",) +parser.add_argument( + "--nb_neg_sample_rs", default=0, type=int, help="") +parser.add_argument( + "--sample_negative_by_kmeans", action="store_true", help="",) +parser.add_argument( + "--nb_kmeans", default=1000, type=int, help="") +parser.add_argument( + "--bidirect", action="store_true", help="",) +parser.add_argument( + '--rnn_type', help='rnn type ["gru", "lstm"]', required=False, type=str, default="gru") +parser.add_argument( + '--num_rnn_layers', help='rnn layers size', required=False, type=int, default=1) +parser.add_argument( + '--zero_init_rnn',action='store_true', help="set initial hidden of rnns zero") +parser.add_argument( + "--do_zeroshot", action="store_true", help="",) +parser.add_argument( + "--oos_threshold", action="store_true", help="",) +parser.add_argument( + "--ontology_version", default="", type=str, help="1.0 is the cleaned version but not used") +parser.add_argument( + "--dstlm", action="store_true", help="",) +parser.add_argument( + '-viz','--vizualization', help='vizualization', type=int, required=False, default=0) + + + +args = vars(parser.parse_args()) +# args = parser.parse_args() +print(str(args)) + +# check output_dir +if os.path.exists(args["output_dir"]) and os.listdir(args["output_dir"]) and args["do_train"] and (not args["overwrite"]): + raise ValueError("Output directory ({}) already exists and is not empty.".format(args["output_dir"])) +os.makedirs(args["output_dir"], exist_ok=True) + +# Dictionary Predefined +SEEDS = np.arange(0, 100, 5) + diff --git a/utils/dataloader_dm.py b/utils/dataloader_dm.py new file mode 100644 index 0000000..a422b56 --- /dev/null +++ b/utils/dataloader_dm.py @@ -0,0 +1,104 @@ +import torch +import torch.utils.data as data +# from .config import * +from .utils_function import to_cuda, merge, merge_multi_response, merge_sent_and_word + +class Dataset_dm(torch.utils.data.Dataset): + """Custom data.Dataset compatible with data.DataLoader.""" + def __init__(self, data_info, tokenizer, args, unified_meta, mode, max_length=512): + """Reads source and target sequences from txt files.""" + self.data = data_info + self.tokenizer = tokenizer + self.num_total_seqs = len(data_info["ID"]) + self.usr_token = args["usr_token"] + self.sys_token = args["sys_token"] + self.max_length = max_length + self.args = args + self.unified_meta = unified_meta + + if "bert" in self.args["model_type"] or "electra" in self.args["model_type"]: + self.start_token = self.tokenizer.cls_token + self.sep_token = self.tokenizer.sep_token + else: + self.start_token = self.tokenizer.bos_token + self.sep_token = self.tokenizer.eos_token + + def __getitem__(self, index): + """Returns one data pair (source and target).""" + + if self.args["example_type"] == "turn": + dialog_history_str = self.get_concat_context(self.data["dialog_history"][index]) + context_plain = self.concat_dh_sys_usr(dialog_history_str, self.data["turn_sys"][index], self.data["turn_usr"][index]) + context = self.preprocess(context_plain) + act_plain = self.data["sys_act"][index] + + turn_sys_plain = "{} {}".format(self.sys_token, self.data["turn_sys"][index]) + turn_sys = self.preprocess(turn_sys_plain) + + act_one_hot = [0] * len(self.unified_meta["sysact"]) + for act in act_plain: + act_one_hot[self.unified_meta["sysact"][act]] = 1 + + elif self.args["example_type"] == "dial": + #TODO + print("Not Implemented dial for nlu yet...") + + item_info = { + "ID":self.data["ID"][index], + "turn_id":self.data["turn_id"][index], + 
"context":context, + "context_plain":context_plain, + "sysact":act_one_hot, + "sysact_plain":act_plain, + "turn_sys":turn_sys} + + return item_info + + def __len__(self): + return self.num_total_seqs + + def preprocess(self, sequence): + """Converts words to ids.""" + tokens = self.tokenizer.tokenize(self.start_token) + self.tokenizer.tokenize(sequence)[-self.max_length+1:] + story = torch.Tensor(self.tokenizer.convert_tokens_to_ids(tokens)) + return story + + def concat_dh_sys_usr(self, dialog_history, sys, usr): + return dialog_history + " {} ".format(self.sys_token) + " {} ".format(self.sep_token) + sys + " {} ".format(self.usr_token) + usr + + def get_concat_context(self, dialog_history): + dialog_history_str = "" + for ui, uttr in enumerate(dialog_history): + if ui%2 == 0: + dialog_history_str += "{} {} ".format(self.sys_token, uttr) + else: + dialog_history_str += "{} {} ".format(self.usr_token, uttr) + dialog_history_str = dialog_history_str.strip() + return dialog_history_str + + +def collate_fn_dm_turn(data): + # sort a list by sequence length (descending order) to use pack_padded_sequence + data.sort(key=lambda x: len(x['context']), reverse=True) + + item_info = {} + for key in data[0].keys(): + item_info[key] = [d[key] for d in data] + + # merge sequences + src_seqs, src_lengths = merge(item_info['context']) + turn_sys, _ = merge(item_info["turn_sys"]) + sysact = torch.tensor(item_info["sysact"]).float() + + item_info["context"] = to_cuda(src_seqs) + item_info["context_len"] = src_lengths + item_info["sysact"] = to_cuda(sysact) + item_info["turn_sys"] = to_cuda(turn_sys) + + return item_info + + +def collate_fn_nlu_dial(data): + # TODO + return + diff --git a/utils/dataloader_dst.py b/utils/dataloader_dst.py new file mode 100644 index 0000000..0c2ccae --- /dev/null +++ b/utils/dataloader_dst.py @@ -0,0 +1,159 @@ +import torch +import numpy as np +import torch.utils.data as data +from .utils_function import to_cuda, merge, merge_multi_response, merge_sent_and_word + +# SLOT_GATE = {"ptr":0, "dontcare":1, "none":2} + +class Dataset_dst(torch.utils.data.Dataset): + """Custom data.Dataset compatible with data.DataLoader.""" + def __init__(self, data_info, tokenizer, args, unified_meta, mode, max_length=512): + """Reads source and target sequences from txt files.""" + self.data = data_info + self.tokenizer = tokenizer + self.num_total_seqs = len(data_info["ID"]) + self.usr_token = args["usr_token"] + self.sys_token = args["sys_token"] + self.max_length = max_length + self.args = args + self.unified_meta = unified_meta + self.slots = list(unified_meta["slots"].keys()) + self.mask_token_idx = tokenizer.convert_tokens_to_ids("[MASK]") + self.sep_token_idx = tokenizer.convert_tokens_to_ids("[SEP]") + + self.start_token = self.tokenizer.cls_token if "bert" in self.args["model_type"] else self.tokenizer.bos_token + self.sep_token = self.tokenizer.sep_token if "bert" in self.args["model_type"] else self.tokenizer.eos_token + + def __getitem__(self, index): + """Returns one data pair (source and target).""" + + if self.args["example_type"] == "turn": + dialog_history_str = self.get_concat_context(self.data["dialog_history"][index]) + gate_label = self.data["slot_gate"][index] + context_plain = self.concat_dh_sys_usr(dialog_history_str, + self.data["turn_sys"][index], + self.data["turn_usr"][index]) + slot_values_plain = self.data["slot_values"][index] + slot_values = self.preprocess_slot(slot_values_plain) + + triggered_domains = set([domain_slot.split("-")[0] for domain_slot in 
self.data["belief"][index].keys()])
+            triggered_domains.add(self.data["turn_domain"][index])
+            assert len(triggered_domains) != 0
+
+            triggered_ds_mask = [1 if s.split("-")[0] in triggered_domains else 0 for s in self.slots]
+            triggered_ds_idx = []
+            triggered_ds_pos = []
+
+            context = self.preprocess(context_plain)
+
+            ontology_idx = []
+            for si, sv in enumerate(slot_values_plain):
+                try:
+                    ontology_idx.append(self.unified_meta["slots"][self.slots[si]][sv])
+                except Exception as e:
+                    # -1 marks a value that is not in the ontology.
+                    print("Not In Ontology")
+                    print(e)
+                    print(self.slots[si], sv)
+                    ontology_idx.append(-1)
+
+        elif self.args["example_type"] == "dial":
+            raise NotImplementedError()
+
+        item_info = {
+            "ID":self.data["ID"][index],
+            "turn_id":self.data["turn_id"][index],
+            "del_belief":self.data["del_belief"][index],
+            "slot_gate":gate_label,
+            "context":context,
+            "context_plain":context_plain,
+            "slot_values":slot_values,
+            "belief":self.data["belief"][index],
+            "slots":self.data["slots"][index],
+            "belief_ontology":ontology_idx,
+            "triggered_ds_mask":triggered_ds_mask,
+            "triggered_ds_idx":triggered_ds_idx,
+            "triggered_ds_pos":triggered_ds_pos}
+
+        return item_info
+
+    def __len__(self):
+        return self.num_total_seqs
+
+    def concat_dh_sys_usr(self, dialog_history, sys, usr):
+        return dialog_history + " {} ".format(self.sep_token) + " {} ".format(self.sys_token) + sys + " {} ".format(self.usr_token) + usr
+
+    def preprocess(self, sequence):
+        """Converts words to ids."""
+        #story = torch.Tensor(self.tokenizer.encode(sequence))
+        tokens = self.tokenizer.tokenize(self.start_token) + self.tokenizer.tokenize(sequence)[-self.max_length+1:]
+        story = torch.Tensor(self.tokenizer.convert_tokens_to_ids(tokens))
+        return story
+
+    def preprocess_slot(self, sequence):
+        """Converts words to ids."""
+        story = []
+        for value in sequence:
+            v = list(self.tokenizer.encode(value + " {}".format(self.sep_token)))
+            story.append(v)
+        return story
+
+    def get_concat_context(self, dialog_history):
+        dialog_history_str = ""
+        for ui, uttr in enumerate(dialog_history):
+            if ui%2 == 0:
+                dialog_history_str += "{} {} ".format(self.sys_token, uttr)
+            else:
+                dialog_history_str += "{} {} ".format(self.usr_token, uttr)
+        dialog_history_str = dialog_history_str.strip()
+        return dialog_history_str
+
+
+def collate_fn_dst_turn(data):
+    # sort a list by sequence length (descending order) to use pack_padded_sequence
+    data.sort(key=lambda x: len(x['context']), reverse=True)
+
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]
+
+    # merge sequences
+    src_seqs, src_lengths = merge(item_info['context'])
+    y_seqs, y_lengths = merge_multi_response(item_info["slot_values"])
+    gates = torch.tensor(item_info["slot_gate"])
+    belief_ontology = torch.tensor(item_info["belief_ontology"])
+    triggered_ds_mask = torch.tensor(item_info["triggered_ds_mask"])
+
+    item_info["context"] = to_cuda(src_seqs)
+    item_info["context_len"] = src_lengths
+    item_info["slot_gate"] = to_cuda(gates)
+    item_info["slot_values"] = to_cuda(y_seqs)
+    item_info["slot_values_len"] = y_lengths
+    item_info["belief_ontology"] = to_cuda(belief_ontology)
+    item_info["triggered_ds_mask"] = to_cuda(triggered_ds_mask)
+
+    return item_info
+
+def collate_fn_dst_dial(data):
+    # sort a list by sequence length (descending order) to use pack_padded_sequence
+    data.sort(key=lambda x: len(x['context']), reverse=True)
+
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]
+
+    # merge sequences
+    src_seqs, src_lengths = 
merge_sent_and_word(item_info['context']) + y = [merge_multi_response(sv) for sv in item_info["slot_values"]] + y_seqs = [_y[0] for _y in y] + y_lengths = [_y[1] for _y in y] + gates, gate_lengths = merge_sent_and_word(item_info['slot_gate'], ignore_idx=-1) + belief_ontology = torch.tensor(item_info["belief_ontology"]) + + item_info["context"] = to_cuda(src_seqs) + item_info["context_len"] = src_lengths + item_info["slot_gate"] = to_cuda(gates) + item_info["slot_values"] = [to_cuda(y) for y in y_seqs] # TODO + item_info["slot_values_len"] = y_lengths # TODO + + return item_info + diff --git a/utils/dataloader_nlg.py b/utils/dataloader_nlg.py new file mode 100644 index 0000000..c32fd3f --- /dev/null +++ b/utils/dataloader_nlg.py @@ -0,0 +1,169 @@ +import torch +import torch.utils.data as data +import random + +from .utils_function import to_cuda, merge +# from .config import * + + +class Dataset_nlg(torch.utils.data.Dataset): + """Custom data.Dataset compatible with data.DataLoader.""" + def __init__(self, data_info, tokenizer, args, unified_meta, mode, max_length=512, max_sys_resp_len=50): + """Reads source and target sequences from txt files.""" + self.data = data_info + self.tokenizer = tokenizer + self.max_length = max_length + self.num_total_seqs = len(data_info["ID"]) + self.usr_token = args["usr_token"] + self.sys_token = args["sys_token"] + self.unified_meta = unified_meta + self.args = args + self.mode = mode + + if "bert" in self.args["model_type"] or "electra" in self.args["model_type"]: + self.start_token = self.tokenizer.cls_token + self.sep_token = self.tokenizer.sep_token + else: + self.start_token = self.tokenizer.bos_token + self.sep_token = self.tokenizer.eos_token + + self.resp_cand_trn = list(self.unified_meta["resp_cand_trn"]) + random.shuffle(self.resp_cand_trn) + self.max_sys_resp_len = max_sys_resp_len + self.others = unified_meta["others"] + + def __getitem__(self, index): + """Returns one data pair (source and target).""" + + if self.args["example_type"] == "turn": + context_plain = self.get_concat_context(self.data["dialog_history"][index]) + context_plain_delex = self.get_concat_context(self.data["dialog_history_delex"][index]) + context = self.preprocess(context_plain) + context_delex = self.preprocess(context_plain_delex) + response_plain = "{} ".format(self.sys_token) + self.data["turn_sys"][index] + response = self.preprocess(response_plain)[:self.max_sys_resp_len] + response_plain_delex = "{} ".format(self.sys_token) + self.data["turn_sys_delex"][index] + response_delex = self.preprocess(response_plain_delex) + utterance_plain = "{} ".format(self.usr_token) + self.data["turn_usr"][index] + utterance = self.preprocess(utterance_plain) + utterance_plain_delex = "{} ".format(self.usr_token) + self.data["turn_usr_delex"][index] + utterance_delex = self.preprocess(utterance_plain_delex) + else: + raise NotImplementedError + + item_info = { + "ID":self.data["ID"][index], + "turn_id":self.data["turn_id"][index], + "context":context, + "context_plain":context_plain, + "context_delex":context_delex, + "context_plain_delex":context_plain_delex, + "response":response, + "response_plain":response_plain, + "response_delex":response_delex, + "response_plain_delex":response_plain_delex, + "utterance":utterance, + "utterance_plain":utterance_plain, + "utterance_delex":utterance_delex, + "utterance_plain_delex":utterance_plain_delex} + + if self.args["nb_neg_sample_rs"] != 0 and self.mode == "train": + + if self.args["sample_negative_by_kmeans"]: + try: + cur_cluster = 
self.others["ToD_BERT_SYS_UTTR_KMEANS"][self.data["turn_sys"][index]] + candidates = self.others["KMEANS_to_SENTS"][cur_cluster] + nb_selected = min(self.args["nb_neg_sample_rs"], len(candidates)) + try: + start_pos = random.randint(0, len(candidates)-nb_selected-1) + except: + start_pos = 0 + sampled_neg_resps = candidates[start_pos:start_pos+nb_selected] + + except: + start_pos = random.randint(0, len(self.resp_cand_trn)-self.args["nb_neg_sample_rs"]-1) + sampled_neg_resps = self.resp_cand_trn[start_pos:start_pos+self.args["nb_neg_sample_rs"]] + else: + start_pos = random.randint(0, len(self.resp_cand_trn)-self.args["nb_neg_sample_rs"]-1) + sampled_neg_resps = self.resp_cand_trn[start_pos:start_pos+self.args["nb_neg_sample_rs"]] + + neg_resp_arr, neg_resp_idx_arr = [], [] + for neg_resp in sampled_neg_resps: + neg_resp_plain = "{} ".format(self.sys_token) + neg_resp + neg_resp_idx = self.preprocess(neg_resp_plain)[:self.max_sys_resp_len] + neg_resp_idx_arr.append(neg_resp_idx) + neg_resp_arr.append(neg_resp_plain) + + item_info["neg_resp_idx_arr"] = neg_resp_idx_arr + item_info["neg_resp_arr"] = neg_resp_arr + + return item_info + + def __len__(self): + return self.num_total_seqs + + def preprocess(self, sequence): + """Converts words to ids.""" + tokens = self.tokenizer.tokenize(self.start_token) + self.tokenizer.tokenize(sequence)[-self.max_length+1:] + story = torch.Tensor(self.tokenizer.convert_tokens_to_ids(tokens)) + return story + + def get_concat_context(self, dialog_history): + dialog_history_str = "" + for ui, uttr in enumerate(dialog_history): + if ui%2 == 0: + dialog_history_str += "{} {} ".format(self.sys_token, uttr) + else: + dialog_history_str += "{} {} ".format(self.usr_token, uttr) + dialog_history_str = dialog_history_str.strip() + return dialog_history_str + + +def collate_fn_nlg_turn(data): + # sort a list by sequence length (descending order) to use pack_padded_sequence + data.sort(key=lambda x: len(x['context']), reverse=True) + + item_info = {} + for key in data[0].keys(): + item_info[key] = [d[key] for d in data] + + # augment negative samples + if "neg_resp_idx_arr" in item_info.keys(): + neg_resp_idx_arr = [] + for arr in item_info['neg_resp_idx_arr']: + neg_resp_idx_arr += arr + + # remove neg samples that are the same as one of the gold responses + #print('item_info["response"]', item_info["response"]) + #print('neg_resp_idx_arr', neg_resp_idx_arr) + + for bi, arr in enumerate(item_info['neg_resp_arr']): + for ri, neg_resp in enumerate(arr): + if neg_resp not in item_info["response_plain"]: + item_info["response"] += [item_info['neg_resp_idx_arr'][bi][ri]] + + #neg_resp_idx_arr = [ng for ng in neg_resp_idx_arr if ng not in item_info["response"]] + #item_info["response"] += neg_resp_idx_arr + + # merge sequences + context, context_lengths = merge(item_info['context']) + context_delex, context_delex_lengths = merge(item_info['context_delex']) + response, response_lengths = merge(item_info["response"]) + response_delex, response_delex_lengths = merge(item_info["response_delex"]) + utterance, utterance_lengths = merge(item_info["utterance"]) + utterance_delex, utterance_delex_lengths = merge(item_info["utterance_delex"]) + + #print("context", context.size()) + #print("response", response.size()) + + item_info["context"] = to_cuda(context) + item_info["context_lengths"] = context_lengths + item_info["response"] = to_cuda(response) + item_info["response_lengths"] = response_lengths + item_info["utterance"] = to_cuda(utterance) + item_info["utterance_lengths"] = 
utterance_lengths
+
+    return item_info
+
+
diff --git a/utils/dataloader_nlu.py b/utils/dataloader_nlu.py
new file mode 100644
index 0000000..a676db5
--- /dev/null
+++ b/utils/dataloader_nlu.py
@@ -0,0 +1,117 @@
+import torch
+import torch.utils.data as data
+# from .config import *
+from .utils_function import to_cuda, merge, merge_multi_response, merge_sent_and_word
+
+class Dataset_nlu(torch.utils.data.Dataset):
+    """Custom data.Dataset compatible with data.DataLoader."""
+    def __init__(self, data_info, tokenizer, args, unified_meta, mode, max_length=512):
+        """Reads source and target sequences from txt files."""
+        self.data = data_info
+        self.tokenizer = tokenizer
+        self.num_total_seqs = len(data_info["ID"])
+        self.usr_token = args["usr_token"]
+        self.sys_token = args["sys_token"]
+        self.max_length = max_length
+        self.args = args
+        self.unified_meta = unified_meta
+
+        if "bert" in self.args["model_type"] or "electra" in self.args["model_type"]:
+            self.start_token = self.tokenizer.cls_token
+            self.sep_token = self.tokenizer.sep_token
+        else:
+            self.start_token = self.tokenizer.bos_token
+            self.sep_token = self.tokenizer.eos_token
+
+    def __getitem__(self, index):
+        """Returns one data pair (source and target)."""
+
+        if self.args["example_type"] == "turn":
+            context_plain = "{} {} {} {} {}".format(self.start_token,
+                                                    self.sys_token,
+                                                    self.data["turn_sys"][index],
+                                                    self.usr_token,
+                                                    self.data["turn_usr"][index])
+
+            context = self.preprocess(context_plain)
+            intent_plain = self.data["intent"][index]
+
+            turn_sys_plain = "{} {}".format(self.sys_token, self.data["turn_sys"][index])
+            turn_sys = self.preprocess(turn_sys_plain)
+
+            try:
+                intent_idx = self.unified_meta["intent"][intent_plain]
+            except KeyError:
+                intent_idx = -100
+
+            try:
+                domain_idx = self.unified_meta["turn_domain"][self.data["turn_domain"][index]]
+            except KeyError:
+                domain_idx = -100
+
+            try:
+                turn_slot_one_hot = [0] * len(self.unified_meta["turn_slot"])
+                for ts in self.data["turn_slot"][index]:
+                    turn_slot_one_hot[self.unified_meta["turn_slot"][ts]] = 1
+            except KeyError:
+                turn_slot_one_hot = -100
+
+        elif self.args["example_type"] == "dial":
+            raise NotImplementedError("dial-level examples are not implemented for nlu yet")
+
+        item_info = {
+            "ID":self.data["ID"][index],
+            "turn_id":self.data["turn_id"][index],
+            "context":context,
+            "context_plain":context_plain,
+            "intent":intent_idx,
+            "intent_plain":intent_plain,
+            "domain_plain":self.data["turn_domain"][index],
+            "turn_domain": domain_idx,
+            "turn_sys":turn_sys,
+            "turn_slot":turn_slot_one_hot,
+            "turn_sys_plain":turn_sys_plain
+        }
+
+        return item_info
+
+    def __len__(self):
+        return self.num_total_seqs
+
+    def preprocess(self, sequence):
+        """Converts words to ids."""
+        tokens = self.tokenizer.tokenize(self.start_token) + self.tokenizer.tokenize(sequence)[-self.max_length+1:]
+        story = torch.Tensor(self.tokenizer.convert_tokens_to_ids(tokens))
+        return story
+
+
+def collate_fn_nlu_turn(data):
+    # sort a list by sequence length (descending order) to use pack_padded_sequence
+    data.sort(key=lambda x: len(x['context']), reverse=True)
+
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]
+
+    # merge sequences
+    src_seqs, src_lengths = merge(item_info['context'])
+    turn_sys, _ = merge(item_info["turn_sys"])
+    intent = torch.tensor(item_info["intent"])
+    turn_domain = torch.tensor(item_info["turn_domain"])
+    turn_slot = torch.tensor(item_info["turn_slot"]).float()
+
+    item_info["context"] = to_cuda(src_seqs)
+    item_info["context_len"] = 
src_lengths + item_info["intent"] = to_cuda(intent) + item_info["turn_domain"] = to_cuda(turn_domain) + item_info["turn_sys"] = to_cuda(turn_sys) + item_info["turn_slot"] = to_cuda(turn_slot) + + return item_info + + +def collate_fn_nlu_dial(data): + # TODO + return + diff --git a/utils/dataloader_usdl.py b/utils/dataloader_usdl.py new file mode 100644 index 0000000..8522e0b --- /dev/null +++ b/utils/dataloader_usdl.py @@ -0,0 +1,131 @@ +import torch +import torch.utils.data as data +from .utils_function import to_cuda, merge, merge_multi_response, merge_sent_and_word + +class Dataset_usdl(torch.utils.data.Dataset): + """Custom data.Dataset compatible with data.DataLoader.""" + def __init__(self, data_info, tokenizer, args, unified_meta, mode, max_length=512): + """Reads source and target sequences from txt files.""" + self.data = data_info + self.tokenizer = tokenizer + self.num_total_seqs = len(data_info["ID"]) + self.usr_token = args["usr_token"] + self.sys_token = args["sys_token"] + self.usr_token_id = self.tokenizer.convert_tokens_to_ids(args["usr_token"]) + self.sys_token_id = self.tokenizer.convert_tokens_to_ids(args["sys_token"]) + self.max_length = max_length + self.args = args + self.unified_meta = unified_meta + self.start_token = self.tokenizer.cls_token if "bert" in self.args["model_type"] else self.tokenizer.bos_token + self.sep_token = self.tokenizer.sep_token if "bert" in self.args["model_type"] else self.tokenizer.eos_token + self.mode = mode + + def __getitem__(self, index): + """Returns one data pair (source and target).""" + item_info = {} + + if self.args["example_type"] == "turn": + dialog_history_str = self.get_concat_context(self.data["dialog_history"][index]) + context_plain = self.concat_dh_sys_usr(dialog_history_str, + self.data["turn_sys"][index], + self.data["turn_usr"][index]) + + context = self.preprocess(context_plain) + + elif self.args["example_type"] == "dial": + context_plain = self.data["dialog_history"][index] + context = self.preprocess_slot(context_plain) + + item_info["ID"] = self.data["ID"][index] + item_info["turn_id"] = self.data["turn_id"][index] + item_info["context"] = context + item_info["context_plain"] = context_plain + + return item_info + + def __len__(self): + return self.num_total_seqs + + def concat_dh_sys_usr(self, dialog_history, sys, usr): + return dialog_history + " {} ".format(self.sys_token) + sys + " {} ".format(self.usr_token) + usr + + def preprocess(self, sequence): + """Converts words to ids.""" + tokens = self.tokenizer.tokenize(self.start_token) + self.tokenizer.tokenize(sequence)[-self.max_length+1:] + story = torch.Tensor(self.tokenizer.convert_tokens_to_ids(tokens)) + return story + + def preprocess_slot(self, sequence): + """Converts words to ids.""" + story = [] + for value in sequence: + #v = list(self.tokenizer.encode(value))# + self.tokenizer.encode("[SEP]")) + v = list(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value))) + story.append(v) + return story + + def get_concat_context(self, dialog_history): + candidate_sys_responses = [] + dialog_history_str = "" + for ui, uttr in enumerate(dialog_history): + if ui%2 == 0: + dialog_history_str += "{} {} ".format(self.sys_token, uttr) + else: + dialog_history_str += "{} {} ".format(self.usr_token, uttr) + dialog_history_str = dialog_history_str.strip() + return dialog_history_str + + +def collate_fn_usdl_turn(data): + # sort a list by sequence length (descending order) to use pack_padded_sequence + data.sort(key=lambda x: len(x['context']), 
reverse=True)
+
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]
+
+    # merge sequences
+    src_seqs, src_lengths = merge(item_info['context'])
+
+    item_info["context"] = to_cuda(src_seqs)
+    item_info["context_len"] = src_lengths
+
+    return item_info
+
+def collate_fn_usdl_dial(data):
+    # sort a list by sequence length (descending order) to use pack_padded_sequence
+    data.sort(key=lambda x: len(x['context']), reverse=True)
+
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]
+
+    # merge sequences
+    src_seqs, src_lengths = merge_sent_and_word(item_info['context'])
+
+    item_info["context"] = to_cuda(src_seqs)
+    item_info["context_len"] = src_lengths
+
+    return item_info
+
+def collate_fn_usdl_dial_flat(data):
+    # sort a list by sequence length (descending order) to use pack_padded_sequence
+    data.sort(key=lambda x: len(x['context_flat']), reverse=True)
+
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]
+
+    # merge sequences
+    src_flat_seqs, src_flat_lengths = merge(item_info['context_flat'])
+    src_seqs, src_lengths = merge_sent_and_word(item_info['context'])
+    src_pos_seqs, src_pos_lengths = merge(item_info["sys_usr_id_positions"])
+
+    item_info["context"] = to_cuda(src_seqs)
+    item_info["context_len"] = src_lengths
+    item_info["context_flat"] = to_cuda(src_flat_seqs)
+    item_info["context_flat_len"] = src_flat_lengths
+    item_info["sys_usr_id_positions"] = to_cuda(src_pos_seqs)
+
+    return item_info
+
diff --git a/utils/loss_function/masked_cross_entropy.py b/utils/loss_function/masked_cross_entropy.py
new file mode 100755
index 0000000..1ca1f8c
--- /dev/null
+++ b/utils/loss_function/masked_cross_entropy.py
@@ -0,0 +1,171 @@
+import torch
+from torch.nn import functional
+from torch.autograd import Variable
+from utils.config import *
+import torch.nn as nn
+import numpy as np
+
+def sequence_mask(sequence_length, max_len=None):
+    if max_len is None:
+        max_len = sequence_length.data.max()
+    batch_size = sequence_length.size(0)
+    seq_range = torch.arange(0, max_len).long()
+    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+    seq_range_expand = Variable(seq_range_expand)
+    if sequence_length.is_cuda:
+        seq_range_expand = seq_range_expand.cuda()
+    seq_length_expand = (sequence_length.unsqueeze(1)
+                         .expand_as(seq_range_expand))
+    return seq_range_expand < seq_length_expand
+
+def cross_entropy(logits, target):
+    batch_size = logits.size(0)
+    log_probs_flat = functional.log_softmax(logits, dim=1)
+    # gather expects the index tensor to have the same rank as the input
+    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target.view(-1, 1))
+    loss = losses_flat.sum() / batch_size
+    return loss
+
+def masked_cross_entropy(logits, target, length):
+    """
+    Args:
+        logits: A Variable containing a FloatTensor of size
+            (batch, max_len, num_classes) which contains the
+            unnormalized probability for each class.
+        target: A Variable containing a LongTensor of size
+            (batch, max_len) which contains the index of the true
+            class for each corresponding step.
+        length: A Variable containing a LongTensor of size (batch,)
+            which contains the length of each data in a batch.
+
+    Returns:
+        loss: An average loss value masked by the length.
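+
+    A minimal usage sketch (shapes and lengths are illustrative only,
+    assuming a CPU run):
+        logits = torch.randn(2, 4, 10)          # (batch, max_len, num_classes)
+        target = torch.randint(0, 10, (2, 4))   # (batch, max_len)
+        loss = masked_cross_entropy(logits, target, [4, 2])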
+ """ + if USE_CUDA: + length = Variable(torch.LongTensor(length)).cuda() + else: + length = Variable(torch.LongTensor(length)) + + # logits_flat: (batch * max_len, num_classes) + logits_flat = logits.view(-1, logits.size(-1)) ## -1 means infered from other dimentions + # log_probs_flat: (batch * max_len, num_classes) + log_probs_flat = functional.log_softmax(logits_flat, dim=1) + # target_flat: (batch * max_len, 1) + target_flat = target.view(-1, 1) + # losses_flat: (batch * max_len, 1) + losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat) + # losses: (batch, max_len) + losses = losses_flat.view(*target.size()) + # mask: (batch, max_len) + mask = sequence_mask(sequence_length=length, max_len=target.size(1)) + losses = losses * mask.float() + loss = losses.sum() / length.float().sum() + return loss + +def masked_binary_cross_entropy(logits, target, length): + ''' + logits: (batch, max_len, num_class) + target: (batch, max_len, num_class) + ''' + if USE_CUDA: + length = Variable(torch.LongTensor(length)).cuda() + else: + length = Variable(torch.LongTensor(length)) + bce_criterion = nn.BCEWithLogitsLoss() + loss = 0 + for bi in range(logits.size(0)): + for i in range(logits.size(1)): + if i < length[bi]: + loss += bce_criterion(logits[bi][i], target[bi][i]) + loss = loss / length.float().sum() + return loss + + +def masked_cross_entropy_(logits, target, length, take_log=False): + if USE_CUDA: + length = Variable(torch.LongTensor(length)).cuda() + else: + length = Variable(torch.LongTensor(length)) + + # logits_flat: (batch * max_len, num_classes) + logits_flat = logits.view(-1, logits.size(-1)) ## -1 means infered from other dimentions + if take_log: + logits_flat = torch.log(logits_flat) + # target_flat: (batch * max_len, 1) + target_flat = target.view(-1, 1) + # losses_flat: (batch * max_len, 1) + losses_flat = -torch.gather(logits_flat, dim=1, index=target_flat) + # losses: (batch, max_len) + losses = losses_flat.view(*target.size()) + # mask: (batch, max_len) + mask = sequence_mask(sequence_length=length, max_len=target.size(1)) + losses = losses * mask.float() + loss = losses.sum() / length.float().sum() + return loss + +def masked_coverage_loss(coverage, attention, length): + if USE_CUDA: + length = Variable(torch.LongTensor(length)).cuda() + else: + length = Variable(torch.LongTensor(length)) + mask = sequence_mask(sequence_length=length) + min_ = torch.min(coverage, attention) + mask = mask.unsqueeze(2).expand_as(min_) + min_ = min_ * mask.float() + loss = min_.sum() / (len(length)*1.0) + return loss + +def masked_cross_entropy_for_slot(logits, target, mask, use_softmax=True): + # print("logits", logits) + # print("target", target) + logits_flat = logits.view(-1, logits.size(-1)) ## -1 means infered from other dimentions + # print(logits_flat.size()) + if use_softmax: + log_probs_flat = functional.log_softmax(logits_flat, dim=1) + else: + log_probs_flat = logits_flat #torch.log(logits_flat) + # print("log_probs_flat", log_probs_flat) + target_flat = target.view(-1, 1) + # print("target_flat", target_flat) + losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat) + losses = losses_flat.view(*target.size()) # b * |s| + losses = losses * mask.float() + loss = losses.sum() / (losses.size(0)*losses.size(1)) + # print("loss inside", loss) + return loss + +def masked_cross_entropy_for_value(logits, target, mask): + # logits: b * |s| * m * |v| + # target: b * |s| * m + # mask: b * |s| + logits_flat = logits.view(-1, logits.size(-1)) ## -1 means infered from 
other dimentions + # print(logits_flat.size()) + log_probs_flat = torch.log(logits_flat) + # print("log_probs_flat", log_probs_flat) + target_flat = target.view(-1, 1) + # print("target_flat", target_flat) + losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat) + losses = losses_flat.view(*target.size()) # b * |s| * m + loss = masking(losses, mask) + return loss + +def masking(losses, mask): + mask_ = [] + batch_size = mask.size(0) + max_len = losses.size(2) + for si in range(mask.size(1)): + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + if mask[:,si].is_cuda: + seq_range_expand = seq_range_expand.cuda() + seq_length_expand = mask[:, si].unsqueeze(1).expand_as(seq_range_expand) + mask_.append( (seq_range_expand < seq_length_expand) ) + mask_ = torch.stack(mask_) + mask_ = mask_.transpose(0, 1) + if losses.is_cuda: + mask_ = mask_.cuda() + losses = losses * mask_.float() + loss = losses.sum() / (mask_.sum().float()) + return loss + + + diff --git a/utils/metrics/__pycache__/measures.cpython-36.pyc b/utils/metrics/__pycache__/measures.cpython-36.pyc new file mode 100644 index 0000000..118d596 Binary files /dev/null and b/utils/metrics/__pycache__/measures.cpython-36.pyc differ diff --git a/utils/metrics/measures.py b/utils/metrics/measures.py new file mode 100755 index 0000000..f1edc7b --- /dev/null +++ b/utils/metrics/measures.py @@ -0,0 +1,116 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +import numpy + +import os +import re +import subprocess +import tempfile +import numpy as np + +from six.moves import urllib + +def word_error_rate(r, h): + """ + This is a function that calculate the word error rate in ASR. + You can use it like this: wer("what is it".split(), "what is".split()) + """ + #build the matrix + d = numpy.zeros((len(r)+1)*(len(h)+1), dtype=numpy.uint8).reshape((len(r)+1, len(h)+1)) + for i in range(len(r)+1): + for j in range(len(h)+1): + if i == 0: d[0][j] = j + elif j == 0: d[i][0] = i + for i in range(1,len(r)+1): + for j in range(1, len(h)+1): + if r[i-1] == h[j-1]: + d[i][j] = d[i-1][j-1] + else: + substitute = d[i-1][j-1] + 1 + insert = d[i][j-1] + 1 + delete = d[i-1][j] + 1 + d[i][j] = min(substitute, insert, delete) + result = float(d[len(r)][len(h)]) / len(r) * 100 + # result = str("%.2f" % result) + "%" + return result + +# -*- coding: utf-8 -*- +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BLEU metric implementation. +""" + + +def moses_multi_bleu(hypotheses, references, lowercase=False): + """Calculate the bleu score for hypotheses and references + using the MOSES ulti-bleu.perl script. + Args: + hypotheses: A numpy array of strings where each string is a single example. + references: A numpy array of strings where each string is a single example. 
+ lowercase: If true, pass the "-lc" flag to the multi-bleu script + Returns: + The BLEU score as a float32 value. + """ + + if np.size(hypotheses) == 0: + return np.float32(0.0) + + + # Get MOSES multi-bleu script + try: + multi_bleu_path, _ = urllib.request.urlretrieve( + "https://raw.githubusercontent.com/moses-smt/mosesdecoder/" + "master/scripts/generic/multi-bleu.perl") + os.chmod(multi_bleu_path, 0o755) + except: #pylint: disable=W0702 + print("Unable to fetch multi-bleu.perl script, using local.") + metrics_dir = os.path.dirname(os.path.realpath(__file__)) + bin_dir = os.path.abspath(os.path.join(metrics_dir, "..", "..", "bin")) + multi_bleu_path = os.path.join(bin_dir, "tools/multi-bleu.perl") + + + # Dump hypotheses and references to tempfiles + hypothesis_file = tempfile.NamedTemporaryFile() + hypothesis_file.write("\n".join(hypotheses).encode("utf-8")) + hypothesis_file.write(b"\n") + hypothesis_file.flush() + reference_file = tempfile.NamedTemporaryFile() + reference_file.write("\n".join(references).encode("utf-8")) + reference_file.write(b"\n") + reference_file.flush() + + + # Calculate BLEU using multi-bleu script + with open(hypothesis_file.name, "r") as read_pred: + bleu_cmd = [multi_bleu_path] + if lowercase: + bleu_cmd += ["-lc"] + bleu_cmd += [reference_file.name] + try: + bleu_out = subprocess.check_output(bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT) + bleu_out = bleu_out.decode("utf-8") + bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1) + bleu_score = float(bleu_score) + except subprocess.CalledProcessError as error: + if error.output is not None: + print("multi-bleu.perl script returned non-zero exit code") + print(error.output) + bleu_score = np.float32(0.0) + + # Close temp files + hypothesis_file.close() + reference_file.close() + return bleu_score \ No newline at end of file diff --git a/utils/metrics/multi-bleu.perl b/utils/metrics/multi-bleu.perl new file mode 100755 index 0000000..92645cb --- /dev/null +++ b/utils/metrics/multi-bleu.perl @@ -0,0 +1,177 @@ +#!/usr/bin/env perl +# +# This file is part of moses. Its use is licensed under the GNU Lesser General +# Public License version 2.1 or, at your option, any later version. 
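+#
+# Usage sketch (file names are illustrative; per the script's own usage
+# message, references may also be split across reference0, reference1, ...):
+#
+#   perl multi-bleu.perl [-lc] reference < hypothesis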
+
+# $Id$
+use warnings;
+use strict;
+
+my $lowercase = 0;
+if ($ARGV[0] eq "-lc") {
+    $lowercase = 1;
+    shift;
+}
+
+my $stem = $ARGV[0];
+if (!defined $stem) {
+    print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n";
+    print STDERR "Reads the references from reference or reference0, reference1, ...\n";
+    exit(1);
+}
+
+$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0";
+
+my @REF;
+my $ref=0;
+while(-e "$stem$ref") {
+    &add_to_ref("$stem$ref",\@REF);
+    $ref++;
+}
+&add_to_ref($stem,\@REF) if -e $stem;
+die("ERROR: could not find reference file $stem") unless scalar @REF;
+
+# add additional references explicitly specified on the command line
+shift;
+foreach my $stem (@ARGV) {
+    &add_to_ref($stem,\@REF) if -e $stem;
+}
+
+
+
+sub add_to_ref {
+    my ($file,$REF) = @_;
+    my $s=0;
+    if ($file =~ /.gz$/) {
+        open(REF,"gzip -dc $file|") or die "Can't read $file";
+    } else {
+        open(REF,$file) or die "Can't read $file";
+    }
+    while(<REF>) {
+        chop;
+        push @{$$REF[$s++]}, $_;
+    }
+    close(REF);
+}
+
+my(@CORRECT,@TOTAL,$length_translation,$length_reference);
+my $s=0;
+while(<STDIN>) {
+    chop;
+    $_ = lc if $lowercase;
+    my @WORD = split;
+    my %REF_NGRAM = ();
+    my $length_translation_this_sentence = scalar(@WORD);
+    my ($closest_diff,$closest_length) = (9999,9999);
+    foreach my $reference (@{$REF[$s]}) {
+#       print "$s $_ <=> $reference\n";
+        $reference = lc($reference) if $lowercase;
+        my @WORD = split(' ',$reference);
+        my $length = scalar(@WORD);
+        my $diff = abs($length_translation_this_sentence-$length);
+        if ($diff < $closest_diff) {
+            $closest_diff = $diff;
+            $closest_length = $length;
+            # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n";
+        } elsif ($diff == $closest_diff) {
+            $closest_length = $length if $length < $closest_length;
+            # from two references with the same closeness to me
+            # take the *shorter* into account, not the "first" one.
+        }
+        for(my $n=1;$n<=4;$n++) {
+            my %REF_NGRAM_N = ();
+            for(my $start=0;$start<=$#WORD-($n-1);$start++) {
+                my $ngram = "$n";
+                for(my $w=0;$w<$n;$w++) {
+                    $ngram .= " ".$WORD[$start+$w];
+                }
+                $REF_NGRAM_N{$ngram}++;
+            }
+            foreach my $ngram (keys %REF_NGRAM_N) {
+                if (!defined($REF_NGRAM{$ngram}) ||
+                    $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) {
+                    $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram};
+#                   print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; + } + } + } + } + $length_translation += $length_translation_this_sentence; + $length_reference += $closest_length; + for(my $n=1;$n<=4;$n++) { + my %T_NGRAM = (); + for(my $start=0;$start<=$#WORD-($n-1);$start++) { + my $ngram = "$n"; + for(my $w=0;$w<$n;$w++) { + $ngram .= " ".$WORD[$start+$w]; + } + $T_NGRAM{$ngram}++; + } + foreach my $ngram (keys %T_NGRAM) { + $ngram =~ /^(\d+) /; + my $n = $1; + # my $corr = 0; +# print "$i e $ngram $T_NGRAM{$ngram}
\n"; + $TOTAL[$n] += $T_NGRAM{$ngram}; + if (defined($REF_NGRAM{$ngram})) { + if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { + $CORRECT[$n] += $T_NGRAM{$ngram}; + # $corr = $T_NGRAM{$ngram}; +# print "$i e correct1 $T_NGRAM{$ngram}
\n"; + } + else { + $CORRECT[$n] += $REF_NGRAM{$ngram}; + # $corr = $REF_NGRAM{$ngram}; +# print "$i e correct2 $REF_NGRAM{$ngram}
\n"; + } + } + # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; + # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" + } + } + $s++; +} +my $brevity_penalty = 1; +my $bleu = 0; + +my @bleu=(); + +for(my $n=1;$n<=4;$n++) { + if (defined ($TOTAL[$n])){ + $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; + # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; + }else{ + $bleu[$n]=0; + } +} + +if ($length_reference==0){ + printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; + exit(1); +} + +if ($length_translation<$length_reference) { + $brevity_penalty = exp(1-$length_reference/$length_translation); +} +$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + + my_log( $bleu[2] ) + + my_log( $bleu[3] ) + + my_log( $bleu[4] ) ) / 4) ; +printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", + 100*$bleu, + 100*$bleu[1], + 100*$bleu[2], + 100*$bleu[3], + 100*$bleu[4], + $brevity_penalty, + $length_translation / $length_reference, + $length_translation, + $length_reference; + + +print STDERR "It is in-advisable to publish scores from multi-bleu.perl. The scores depend on your tokenizer, which is unlikely to be reproducible from your paper or consistent across research groups. Instead you should detokenize then use mteval-v14.pl, which has a standard tokenization. Scores from multi-bleu.perl can still be used for internal purposes when you have a consistent tokenizer.\n"; + +sub my_log { + return -9999999999 unless $_[0]; + return log($_[0]); +} \ No newline at end of file diff --git a/utils/multiwoz/__pycache__/fix_label.cpython-36.pyc b/utils/multiwoz/__pycache__/fix_label.cpython-36.pyc new file mode 100644 index 0000000..8271e34 Binary files /dev/null and b/utils/multiwoz/__pycache__/fix_label.cpython-36.pyc differ diff --git a/utils/multiwoz/dbPointer.py b/utils/multiwoz/dbPointer.py new file mode 100644 index 0000000..420aff7 --- /dev/null +++ b/utils/multiwoz/dbPointer.py @@ -0,0 +1,172 @@ + +import sqlite3 + +import numpy as np + +from .nlp import normalize + + +# loading databases +domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital']#, 'police'] +dbs = {} +for domain in domains: + db = 'data/multi-woz/db/{}-dbase.db'.format(domain) + conn = sqlite3.connect(db) + c = conn.cursor() + dbs[domain] = c + + +def oneHotVector(num, domain, vector): + """Return number of available entities for particular domain.""" + number_of_options = 6 + if domain != 'train': + idx = domains.index(domain) + if num == 0: + vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0,0]) + elif num == 1: + vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0]) + elif num == 2: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0]) + elif num == 3: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0]) + elif num == 4: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0]) + elif num >= 5: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1]) + else: + idx = domains.index(domain) + if num == 0: + vector[idx * 6: idx * 6 + 6] = np.array([1, 0, 0, 0, 0, 0]) + elif num <= 2: + vector[idx * 6: idx * 6 + 6] = np.array([0, 1, 0, 0, 0, 0]) + elif num <= 5: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 1, 0, 0, 0]) + elif num <= 10: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 1, 0, 0]) + elif num <= 40: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 1, 0]) + elif num > 40: + vector[idx * 6: idx * 6 + 6] = np.array([0, 0, 0, 0, 0, 1]) + + return vector + +def 
queryResult(domain, turn):
+    """Returns the list of entities for a given domain
+    based on the annotation of the belief state"""
+    # query the db
+    sql_query = "select * from {}".format(domain)
+
+    flag = True
+    #print turn['metadata'][domain]['semi']
+    for key, val in turn['metadata'][domain]['semi'].items():
+        if val == "" or val == "dont care" or val == 'not mentioned' or val == "don't care" or val == "dontcare" or val == "do n't care":
+            pass
+        else:
+            if flag:
+                sql_query += " where "
+                val2 = val.replace("'", "''")
+                #val2 = normalize(val2)
+                # change query for trains
+                if key == 'leaveAt':
+                    sql_query += r" " + key + " > " + r"'" + val2 + r"'"
+                elif key == 'arriveBy':
+                    sql_query += r" " + key + " < " + r"'" + val2 + r"'"
+                else:
+                    sql_query += r" " + key + "=" + r"'" + val2 + r"'"
+                flag = False
+            else:
+                val2 = val.replace("'", "''")
+                #val2 = normalize(val2)
+                if key == 'leaveAt':
+                    sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
+                elif key == 'arriveBy':
+                    sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
+                else:
+                    sql_query += r" and " + key + "=" + r"'" + val2 + r"'"
+
+    #try: # "select * from attraction where name = 'queens college'"
+    #print sql_query
+    #print domain
+    num_entities = len(dbs[domain].execute(sql_query).fetchall())
+
+    return num_entities
+
+
+def queryResultVenues(domain, turn, real_belief=False):
+    # query the db
+    sql_query = "select * from {}".format(domain)
+
+    if real_belief == True:
+        items = turn.items()
+    elif real_belief == 'tracking':
+        flag = True
+        for slot in turn[domain]:
+            key = slot[0].split("-")[1]
+            val = slot[0].split("-")[2]
+            if key == "price range":
+                key = "pricerange"
+            elif key == "leave at":
+                key = "leaveAt"
+            elif key == "arrive by":
+                key = "arriveBy"
+            if val == "do n't care":
+                pass
+            else:
+                if flag:
+                    sql_query += " where "
+                    val2 = val.replace("'", "''")
+                    val2 = normalize(val2)
+                    if key == 'leaveAt':
+                        sql_query += key + " > " + r"'" + val2 + r"'"
+                    elif key == 'arriveBy':
+                        sql_query += key + " < " + r"'" + val2 + r"'"
+                    else:
+                        sql_query += r" " + key + "=" + r"'" + val2 + r"'"
+                    flag = False
+                else:
+                    val2 = val.replace("'", "''")
+                    val2 = normalize(val2)
+                    if key == 'leaveAt':
+                        sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
+                    elif key == 'arriveBy':
+                        sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
+                    else:
+                        sql_query += r" and " + key + "=" + r"'" + val2 + r"'"
+
+        try: # "select * from attraction where name = 'queens college'"
+            return dbs[domain].execute(sql_query).fetchall()
+        except:
+            return [] # TODO test it
+    else:
+        items = turn['metadata'][domain]['semi'].items()
+
+    flag = True
+    for key, val in items:
+        if val == "" or val == "dontcare" or val == 'not mentioned' or val == "don't care" or val == "dont care" or val == "do n't care":
+            pass
+        else:
+            if flag:
+                sql_query += " where "
+                val2 = val.replace("'", "''")
+                val2 = normalize(val2)
+                if key == 'leaveAt':
+                    sql_query += r" " + key + " > " + r"'" + val2 + r"'"
+                elif key == 'arriveBy':
+                    sql_query += r" " + key + " < " + r"'" + val2 + r"'"
+                else:
+                    sql_query += r" " + key + "=" + r"'" + val2 + r"'"
+                flag = False
+            else:
+                val2 = val.replace("'", "''")
+                val2 = normalize(val2)
+                if key == 'leaveAt':
+                    sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
+                elif key == 'arriveBy':
+                    sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
+                else:
+                    sql_query += r" and " + key + "=" + r"'" + val2 + r"'"
+
+    try: # "select * from attraction where name = 'queens college'"
+        return dbs[domain].execute(sql_query).fetchall()
+
except: + return [] # TODO test it \ No newline at end of file diff --git a/utils/multiwoz/delexicalize.py b/utils/multiwoz/delexicalize.py new file mode 100644 index 0000000..302b193 --- /dev/null +++ b/utils/multiwoz/delexicalize.py @@ -0,0 +1,148 @@ +import re + +import simplejson as json + +from .nlp import normalize + +digitpat = re.compile('\d+') +timepat = re.compile("\d{1,2}[:]\d{1,2}") +pricepat2 = re.compile("\d{1,3}[.]\d{1,2}") + +# FORMAT +# domain_value +# restaurant_postcode +# restaurant_address +# taxi_car8 +# taxi_number +# train_id etc.. + + +def prepareSlotValuesIndependent(): + domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police'] + requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + dic = [] + dic_area = [] + dic_food = [] + dic_price = [] + + # read databases + for domain in domains: + try: + fin = open('data/multi-woz/db/' + domain + '_db.json', 'r') + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if val == '?' or val == 'free': + pass + elif key == 'address': + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + if "road" in val: + val = val.replace("road", "rd") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "rd" in val: + val = val.replace("rd", "road") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "st" in val: + val = val.replace("st", "street") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif "street" in val: + val = val.replace("street", "st") + dic.append((normalize(val), '[' + domain + '_' + 'address' + ']')) + elif key == 'name': + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + if "b & b" in val: + val = val.replace("b & b", "bed and breakfast") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "bed and breakfast" in val: + val = val.replace("bed and breakfast", "b & b") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "hotel" in val and 'gonville' not in val: + val = val.replace("hotel", "") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif "restaurant" in val: + val = val.replace("restaurant", "") + dic.append((normalize(val), '[' + domain + '_' + 'name' + ']')) + elif key == 'postcode': + dic.append((normalize(val), '[' + domain + '_' + 'postcode' + ']')) + elif key == 'phone': + dic.append((val, '[' + domain + '_' + 'phone' + ']')) + elif key == 'trainID': + dic.append((normalize(val), '[' + domain + '_' + 'id' + ']')) + elif key == 'department': + dic.append((normalize(val), '[' + domain + '_' + 'department' + ']')) + + # NORMAL DELEX + elif key == 'area': + dic_area.append((normalize(val), '[' + 'value' + '_' + 'area' + ']')) + elif key == 'food': + dic_food.append((normalize(val), '[' + 'value' + '_' + 'food' + ']')) + elif key == 'pricerange': + dic_price.append((normalize(val), '[' + 'value' + '_' + 'pricerange' + ']')) + else: + pass + # TODO car type? 
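+                        # (illustrative note) each dic entry pairs a normalized
+                        # surface form with its placeholder, e.g.
+                        #   ("pizza hut fen ditton", "[restaurant_name]"),
+                        # so delexicalise() can swap values out token by token later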
+ except: + pass + + if domain == 'hospital': + dic.append((normalize('Hills Rd'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('Hills Road'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB20QQ'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223245151', '[' + domain + '_' + 'phone' + ']')) + dic.append(('0122324515', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Addenbrookes Hospital'), '[' + domain + '_' + 'name' + ']')) + + elif domain == 'police': + dic.append((normalize('Parkside'), '[' + domain + '_' + 'address' + ']')) + dic.append((normalize('CB11JG'), '[' + domain + '_' + 'postcode' + ']')) + dic.append(('01223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append(('1223358966', '[' + domain + '_' + 'phone' + ']')) + dic.append((normalize('Parkside Police Station'), '[' + domain + '_' + 'name' + ']')) + + # add at the end places from trains + fin = open('data/multi-woz/db/' + 'train' + '_db.json', 'r') + db_json = json.load(fin) + fin.close() + + for ent in db_json: + for key, val in ent.items(): + if key == 'departure' or key == 'destination': + dic.append((normalize(val), '[' + 'value' + '_' + 'place' + ']')) + + # add specific values: + for key in ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']: + dic.append((normalize(key), '[' + 'value' + '_' + 'day' + ']')) + + # more general values add at the end + dic.extend(dic_area) + dic.extend(dic_food) + dic.extend(dic_price) + + return dic + + +def delexicalise(utt, dictionary): + for key, val in dictionary: + utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + utt = utt[1:-1] # why this? + + return utt + + +def delexicaliseDomain(utt, dictionary, domain): + for key, val in dictionary: + if key == domain or key == 'value': + utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + utt = utt[1:-1] # why this? + + # go through rest of domain in case we are missing something out? + for key, val in dictionary: + utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ') + utt = utt[1:-1] # why this? 
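+    # note: wrapping utt in single spaces lets replace() match only whole,
+    # space-delimited tokens; the utt[1:-1] slices then strip the two
+    # sentinel spaces again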
+ return utt + +if __name__ == '__main__': + prepareSlotValuesIndependent() \ No newline at end of file diff --git a/utils/multiwoz/fix_label.py b/utils/multiwoz/fix_label.py new file mode 100644 index 0000000..e723bcb --- /dev/null +++ b/utils/multiwoz/fix_label.py @@ -0,0 +1,407 @@ + +def fix_general_label_error(labels, type, slots, ontology_version=""): + label_dict = dict([ (l[0], l[1]) for l in labels]) if type else dict([ (l["slots"][0][0], l["slots"][0][1]) for l in labels]) + + GENERAL_TYPO = { + # type + "guesthouse":"guest house","guesthouses":"guest house","guest":"guest house","mutiple sports":"multiple sports", + "mutliple sports":"multiple sports","sports":"multiple sports","swimmingpool":"swimming pool", + "concerthall":"concert hall", "concert":"concert hall", "pool":"swimming pool", "night club":"nightclub", "mus":"museum", + "colleges":"college", "coll":"college","architectural":"architecture", "musuem":"museum", "churches":"church", + + # area + "center":"centre", "center of town":"centre", "near city center":"centre", "in the north":"north", + "cen":"centre", "east side":"east","east area":"east", "west part of town":"west", "ce":"centre", + "town center":"centre", "centre of cambridge":"centre", + "city center":"centre", "the south":"south", "scentre":"centre", "town centre":"centre", "in town":"centre", + "north part of town":"north", "centre of town":"centre", "cb30aq": "none", + + # price + "mode":"moderate", "moderate -ly": "moderate", "mo":"moderate", + + # day + "monda": "monday", + + # parking + "free parking":"free", + + # internet + "free internet":"yes", + + # star + "4 star":"4", "4 stars":"4", "0 star rarting":"none", + + # others + "y":"yes", "any":"do n't care", "does not care":"do n't care", "not men":"none", "not":"none", + "not mentioned":"none", '':"none", "not mendtioned":"none", "3 .":"3", "does not":"no", "fun":"none", + } + + for slot in slots: + if slot in label_dict.keys(): + + # general typos + if label_dict[slot] in GENERAL_TYPO.keys(): + label_dict[slot] = label_dict[slot].replace(label_dict[slot], GENERAL_TYPO[label_dict[slot]]) + + # do not care + if label_dict[slot] in ["doesn't care", "don't care", "dont care", "does not care", "do not care", "dontcare"]: + label_dict[slot] = "do n't care" + + # miss match slot and value + if slot == "hotel-type" and label_dict[slot] in ["nigh", "moderate -ly priced", "bed and breakfast", "centre", "venetian", "intern", "a cheap -er hotel"]: + label_dict[slot] = "none" + if slot == "hotel-internet" and label_dict[slot] == "4": + label_dict[slot] = "none" + if slot == "hotel-internet" and label_dict[slot] == "4": + label_dict[slot] = "none" + if slot == "hotel-pricerange" and label_dict[slot] == "2": + label_dict[slot] = "none" + if "area" in slot and label_dict[slot] in ["moderate"]: + label_dict[slot] = "none" + if "day" in slot and label_dict[slot] == "t": + label_dict[slot] = "none" + if slot == "hotel-type" and label_dict[slot] in ["hotel with free parking and free wifi", "4", "3 star hotel"]: + label_dict[slot] = "hotel" + if slot == "hotel-star" and label_dict[slot] == "3 star hotel": + label_dict[slot] = "3" + + if "area" in slot: + if label_dict[slot] == "no": + label_dict[slot] = "north" + elif label_dict[slot] == "we": + label_dict[slot] = "west" + elif label_dict[slot] == "cent": + label_dict[slot] = "centre" + + if "day" in slot: + if label_dict[slot] == "we": + label_dict[slot] = "wednesday" + elif label_dict[slot] == "no": + label_dict[slot] = "none" + + if "price" in slot and 
label_dict[slot] == "ch": + label_dict[slot] = "cheap" + if "internet" in slot and label_dict[slot] == "free": + label_dict[slot] = "yes" + + + # Add on May, 2020 + if ontology_version in ["1.0"]: + + label_dict[slot] = label_dict[slot].replace("theater", "theatre").replace("guesthouse", "guest house") + + # Typo or naming + if label_dict[slot] == "cafe uno": + label_dict[slot] = "caffe uno" + if label_dict[slot] == "alpha milton guest house": + label_dict[slot] = "alpha-milton guest house" + if label_dict[slot] in ["churchills college", "churchhill college", "churchill", "the churchill college"]: + label_dict[slot] = "churchill college" + if label_dict[slot] == "portugese": + label_dict[slot] = "portuguese" + if label_dict[slot] == "pizza hut fenditton": + label_dict[slot] = "pizza hut fen ditton" + if label_dict[slot] == "restaurant 17": + label_dict[slot] = "restaurant one seven" + if label_dict[slot] == "restaurant 2 two": + label_dict[slot] = "restaurant two two" + if label_dict[slot] == "gallery at 12 a high street": + label_dict[slot] = "gallery at twelve a high street" + if label_dict[slot] == "museum of archaelogy": + label_dict[slot] = "museum of archaelogy and anthropology" + if label_dict[slot] in ["huntingdon marriot hotel", "marriot hotel"]: + label_dict[slot] = "huntingdon marriott hotel" + if label_dict[slot] in ["sheeps green and lammas land park fen causeway", "sheeps green and lammas land park"]: + label_dict[slot] = "sheep's green and lammas land park fen causeway" + if label_dict[slot] in ["cambridge and country folk museum", "county folk museum"]: + label_dict[slot] = "cambridge and county folk museum" + if label_dict[slot] == "ambridge": + label_dict[slot] = "cambridge" + if label_dict[slot] == "cambridge contemporary art museum": + label_dict[slot] = "cambridge contemporary art" + if label_dict[slot] == "molecular gastonomy": + label_dict[slot] = "molecular gastronomy" + if label_dict[slot] == "2 two and cote": + label_dict[slot] = "two two and cote" + if label_dict[slot] == "caribbeanindian": + label_dict[slot] = "caribbean|indian" + if label_dict[slot] == "whipple museum": + label_dict[slot] = "whipple museum of the history of science" + if label_dict[slot] == "ian hong": + label_dict[slot] = "ian hong house" + if label_dict[slot] == "sundaymonday": + label_dict[slot] = "sunday|monday" + if label_dict[slot] == "mondaythursday": + label_dict[slot] = "monday|thursday" + if label_dict[slot] == "fridaytuesday": + label_dict[slot] = "friday|tuesday" + if label_dict[slot] == "cheapmoderate": + label_dict[slot] = "cheap|moderate" + if label_dict[slot] == "golden house golden house": + label_dict[slot] = "the golden house" + if label_dict[slot] == "golden house": + label_dict[slot] = "the golden house" + if label_dict[slot] == "sleeperz": + label_dict[slot] = "sleeperz hotel" + if label_dict[slot] == "jamaicanchinese": + label_dict[slot] = "jamaican|chinese" + if label_dict[slot] == "shiraz": + label_dict[slot] = "shiraz restaurant" + if label_dict[slot] == "museum of archaelogy and anthropogy": + label_dict[slot] = "museum of archaelogy and anthropology" + if label_dict[slot] == "yipee noodle bar": + label_dict[slot] = "yippee noodle bar" + if label_dict[slot] == "abc theatre": + label_dict[slot] = "adc theatre" + if label_dict[slot] == "wankworth house": + label_dict[slot] = "warkworth house" + if label_dict[slot] in ["cherry hinton water play park", "cherry hinton water park"]: + label_dict[slot] = "cherry hinton water play" + if label_dict[slot] == "the gallery at 
12": + label_dict[slot] = "the gallery at twelve" + if label_dict[slot] == "barbequemodern european": + label_dict[slot] = "barbeque|modern european" + if label_dict[slot] == "north americanindian": + label_dict[slot] = "north american|indian" + if label_dict[slot] == "chiquito": + label_dict[slot] = "chiquito restaurant bar" + + + # Abbreviation + if label_dict[slot] == "city centre north bed and breakfast": + label_dict[slot] = "city centre north b and b" + if label_dict[slot] == "north bed and breakfast": + label_dict[slot] = "north b and b" + + # Article and 's + if label_dict[slot] == "christ college": + label_dict[slot] = "christ's college" + if label_dict[slot] == "kings college": + label_dict[slot] = "king's college" + if label_dict[slot] == "saint johns college": + label_dict[slot] = "saint john's college" + if label_dict[slot] == "kettles yard": + label_dict[slot] = "kettle's yard" + if label_dict[slot] == "rosas bed and breakfast": + label_dict[slot] = "rosa's bed and breakfast" + if label_dict[slot] == "saint catharines college": + label_dict[slot] = "saint catharine's college" + if label_dict[slot] == "little saint marys church": + label_dict[slot] = "little saint mary's church" + if label_dict[slot] == "great saint marys church": + label_dict[slot] = "great saint mary's church" + if label_dict[slot] in ["queens college", "queens' college"]: + label_dict[slot] = "queen's college" + if label_dict[slot] == "peoples portraits exhibition at girton college": + label_dict[slot] = "people's portraits exhibition at girton college" + if label_dict[slot] == "st johns college": + label_dict[slot] = "saint john's college" + if label_dict[slot] == "whale of time": + label_dict[slot] = "whale of a time" + if label_dict[slot] in ["st catharines college", "saint catharines college"]: + label_dict[slot] = "saint catharine's college" + + # Time + if label_dict[slot] == "16,15": + label_dict[slot] = "16:15" + if label_dict[slot] == "1330": + label_dict[slot] = "13:30" + if label_dict[slot] == "1430": + label_dict[slot] = "14:30" + if label_dict[slot] == "1532": + label_dict[slot] = "15:32" + if label_dict[slot] == "845": + label_dict[slot] = "08:45" + if label_dict[slot] == "1145": + label_dict[slot] = "11:45" + if label_dict[slot] == "1545": + label_dict[slot] = "15:45" + if label_dict[slot] == "1329": + label_dict[slot] = "13:29" + if label_dict[slot] == "1345": + label_dict[slot] = "13:45" + if label_dict[slot] == "1715": + label_dict[slot] = "17:15" + if label_dict[slot] == "929": + label_dict[slot] = "09:29" + + + # restaurant + if slot == "restaurant-name" and "meze bar" in label_dict[slot]: + label_dict[slot] = "meze bar restaurant" + if slot == "restaurant-name" and label_dict[slot] == "alimentum": + label_dict[slot] = "restaurant alimentum" + if slot == "restaurant-name" and label_dict[slot] == "good luck": + label_dict[slot] = "the good luck chinese food takeaway" + if slot == "restaurant-name" and label_dict[slot] == "grafton hotel": + label_dict[slot] = "grafton hotel restaurant" + if slot == "restaurant-name" and label_dict[slot] == "2 two": + label_dict[slot] = "restaurant two two" + if slot == "restaurant-name" and label_dict[slot] == "hotpot": + label_dict[slot] = "the hotpot" + if slot == "restaurant-name" and label_dict[slot] == "hobsons house": + label_dict[slot] = "hobson house" + if slot == "restaurant-name" and label_dict[slot] == "shanghai": + label_dict[slot] = "shanghai family restaurant" + if slot == "restaurant-name" and label_dict[slot] == "17": + label_dict[slot] = 
"restaurant one seven" + if slot == "restaurant-name" and label_dict[slot] in ["22", "restaurant 22"]: + label_dict[slot] = "restaurant two two" + if slot == "restaurant-name" and label_dict[slot] == "the maharajah tandoor": + label_dict[slot] = "maharajah tandoori restaurant" + if slot == "restaurant-name" and label_dict[slot] == "the grafton hotel": + label_dict[slot] = "grafton hotel restaurant" + if slot == "restaurant-name" and label_dict[slot] == "gardenia": + label_dict[slot] = "the gardenia" + if slot == "restaurant-name" and label_dict[slot] == "el shaddia guest house": + label_dict[slot] = "el shaddai" + if slot == "restaurant-name" and label_dict[slot] == "the bedouin": + label_dict[slot] = "bedouin" + if slot == "restaurant-name" and label_dict[slot] == "the kohinoor": + label_dict[slot] = "kohinoor" + if slot == "restaurant-name" and label_dict[slot] == "the peking": + label_dict[slot] = "peking restaurant" + if slot == "restaurant-book time" and label_dict[slot] == "7pm": + label_dict[slot] = "19:00" + if slot == "restaurant-book time" and label_dict[slot] == "4pm": + label_dict[slot] = "16:00" + if slot == "restaurant-book time" and label_dict[slot] == "8pm": + label_dict[slot] = "20:00" + if slot == "restaurant-name" and label_dict[slot] == "sitar": + label_dict[slot] = "sitar tandoori" + if slot == "restaurant-name" and label_dict[slot] == "binh": + label_dict[slot] = "thanh binh" + if slot == "restaurant-name" and label_dict[slot] == "mahal": + label_dict[slot] = "mahal of cambridge" + + # attraction + if slot == "attraction-name" and label_dict[slot] == "scudamore": + label_dict[slot] = "scudamores punting co" + if slot == "attraction-name" and label_dict[slot] == "salsa": + label_dict[slot] = "club salsa" + if slot == "attraction-name" and label_dict[slot] in ["abbey pool", "abbey pool and astroturf"]: + label_dict[slot] = "abbey pool and astroturf pitch" + if slot == "attraction-name" and label_dict[slot] == "cherry hinton hall": + label_dict[slot] = "cherry hinton hall and grounds" + if slot == "attraction-name" and label_dict[slot] == "trinity street college": + label_dict[slot] = "trinity college" + if slot == "attraction-name" and label_dict[slot] == "the wandlebury": + label_dict[slot] = "wandlebury country park" + if slot == "attraction-name" and label_dict[slot] == "king hedges learner pool": + label_dict[slot] = "kings hedges learner pool" + if slot == "attraction-name" and label_dict[slot] in ["botanic gardens", "cambridge botanic gardens"]: + label_dict[slot] = "cambridge university botanic gardens" + if slot == "attraction-name" and label_dict[slot] == "soultree": + label_dict[slot] = "soul tree nightclub" + if slot == "attraction-name" and label_dict[slot] == "queens": + label_dict[slot] = "queen's college" + if slot == "attraction-name" and label_dict[slot] == "sheeps green": + label_dict[slot] = "sheep's green and lammas land park fen causeway" + if slot == "attraction-name" and label_dict[slot] == "jesus green": + label_dict[slot] = "jesus green outdoor pool" + if slot == "attraction-name" and label_dict[slot] == "adc": + label_dict[slot] = "adc theatre" + if slot == "attraction-name" and label_dict[slot] == "hobsons house": + label_dict[slot] = "hobson house" + if slot == "attraction-name" and label_dict[slot] == "cafe jello museum": + label_dict[slot] = "cafe jello gallery" + if slot == "attraction-name" and label_dict[slot] == "whippple museum": + label_dict[slot] = "whipple museum of the history of science" + if slot == "attraction-type" and 
label_dict[slot] == "boating": + label_dict[slot] = "boat" + if slot == "attraction-name" and label_dict[slot] == "peoples portraits exhibition": + label_dict[slot] = "people's portraits exhibition at girton college" + if slot == "attraction-name" and label_dict[slot] == "lammas land park": + label_dict[slot] = "sheep's green and lammas land park fen causeway" + + # taxi + if slot in ["taxi-destination", "taxi-departure"] and label_dict[slot] == "meze bar": + label_dict[slot] = "meze bar restaurant" + if slot in ["taxi-destination", "taxi-departure"] and label_dict[slot] == "el shaddia guest house": + label_dict[slot] = "el shaddai" + if slot == "taxi-departure" and label_dict[slot] == "centre of town at my hotel": + label_dict[slot] = "hotel" + + # train + if slot == "train-departure" and label_dict[slot] in ["liverpool", "london liverpool"]: + label_dict[slot] = "london liverpool street" + if slot == "train-destination" and label_dict[slot] == "liverpool street": + label_dict[slot] = "london liverpool street" + if slot == "train-departure" and label_dict[slot] == "alpha milton": + label_dict[slot] = "alpha-milton" + + # hotel + if slot == "hotel-name" and label_dict[slot] == "el shaddia guest house": + label_dict[slot] = "el shaddai" + if slot == "hotel-name" and label_dict[slot] == "alesbray lodge guest house": + label_dict[slot] = "aylesbray lodge guest house" + if slot == "hotel-name" and label_dict[slot] == "the gonvile hotel": + label_dict[slot] = "the gonville hotel" + if slot == "hotel-name" and label_dict[slot] == "no": + label_dict[slot] = "none" + if slot == "hotel-name" and label_dict[slot] in ["holiday inn", "holiday inn cambridge"]: + label_dict[slot] = "express by holiday inn cambridge" + if slot == "hotel-name" and label_dict[slot] == "wartworth": + label_dict[slot] = "warkworth house" + + # Suppose to be a wrong annotation + if slot == "restaurant-name" and label_dict[slot] == "south": + label_dict[slot] = "none" + if slot == "attraction-type" and label_dict[slot] == "churchill college": + label_dict[slot] = "none" + if slot == "attraction-name" and label_dict[slot] == "boat": + label_dict[slot] = "none" + if slot == "attraction-type" and label_dict[slot] == "museum kettles yard": + label_dict[slot] = "none" + if slot == "attraction-type" and label_dict[slot] == "hotel": + label_dict[slot] = "none" + if slot == "attraction-type" and label_dict[slot] == "camboats": + label_dict[slot] = "boat" + + + # TODO: Need to check with dialogue data to deal with strange labels before + + # if slot == "restaurant-name" and label_dict[slot] == "eraina and michaelhouse cafe": + # label_dict[slot] = "eraina|michaelhouse cafe" + # if slot == "attraction-name" and label_dict[slot] == "gonville hotel": + # label_dict[slot] = "none" + # if label_dict[slot] == "good luck": + # label_dict[slot] = "the good luck chinese food takeaway" + # if slot == "restaurant-book time" and label_dict[slot] == "9": + # label_dict[slot] = "21:00" + # if slot == "taxi-departure" and label_dict[slot] == "girton college": + # label_dict[slot] = "people's portraits exhibition at girton college" + # if slot == "restaurant-name" and label_dict[slot] == "molecular gastronomy": + # label_dict[slot] = "none" + # [Info] Adding Slot: restaurant-name with value: primavera + # [Info] Adding Slot: train-departure with value: huntingdon + # [Info] Adding Slot: attraction-name with value: aylesbray lodge guest house + # [Info] Adding Slot: attraction-name with value: gallery + # [Info] Adding Slot: hotel-name with value: 
eraina + # [Info] Adding Slot: restaurant-name with value: india west + # [Info] Adding Slot: restaurant-name with value: autumn house + # [Info] Adding Slot: train-destination with value: norway + # [Info] Adding Slot: attraction-name with value: cinema cinema + # [Info] Adding Slot: hotel-name with value: lan hon + # [Info] Adding Slot: restaurant-food with value: sushi + # [Info] Adding Slot: attraction-name with value: university arms hotel + # [Info] Adding Slot: train-departure with value: stratford + # [Info] Adding Slot: attraction-name with value: history of science museum + # [Info] Adding Slot: restaurant-name with value: nil + # [Info] Adding Slot: train-leaveat with value: 9 + # [Info] Adding Slot: restaurant-name with value: ashley hotel + # [Info] Adding Slot: taxi-destination with value: the cambridge shop + # [Info] Adding Slot: hotel-name with value: acorn place + # [Info] Adding Slot: restaurant-name with value: de luca cucina and bar riverside brasserie + # [Info] Adding Slot: hotel-name with value: super 5 + # [Info] Adding Slot: attraction-name with value: archway house + # [Info] Adding Slot: train-arriveby with value: 8 + # [Info] Adding Slot: train-leaveat with value: 10 + # [Info] Adding Slot: restaurant-book time with value: 9 + # [Info] Adding Slot: hotel-name with value: nothamilton lodge + # [Info] Adding Slot: attraction-name with value: st christs college + + return label_dict + + + \ No newline at end of file diff --git a/utils/multiwoz/mapping.pair b/utils/multiwoz/mapping.pair new file mode 100644 index 0000000..34df41d --- /dev/null +++ b/utils/multiwoz/mapping.pair @@ -0,0 +1,83 @@ +it's it is +don't do not +doesn't does not +didn't did not +you'd you would +you're you are +you'll you will +i'm i am +they're they are +that's that is +what's what is +couldn't could not +i've i have +we've we have +can't cannot +i'd i would +i'd i would +aren't are not +isn't is not +wasn't was not +weren't were not +won't will not +there's there is +there're there are +. . . 
+restaurants restaurant -s +hotels hotel -s +laptops laptop -s +cheaper cheap -er +dinners dinner -s +lunches lunch -s +breakfasts breakfast -s +expensively expensive -ly +moderately moderate -ly +cheaply cheap -ly +prices price -s +places place -s +venues venue -s +ranges range -s +meals meal -s +locations location -s +areas area -s +policies policy -s +children child -s +kids kid -s +kidfriendly kid friendly +cards card -s +upmarket expensive +inpricey cheap +inches inch -s +uses use -s +dimensions dimension -s +driverange drive range +includes include -s +computers computer -s +machines machine -s +families family -s +ratings rating -s +constraints constraint -s +pricerange price range +batteryrating battery rating +requirements requirement -s +drives drive -s +specifications specification -s +weightrange weight range +harddrive hard drive +batterylife battery life +businesses business -s +hours hour -s +one 1 +two 2 +three 3 +four 4 +five 5 +six 6 +seven 7 +eight 8 +nine 9 +ten 10 +eleven 11 +twelve 12 +anywhere any where +good bye goodbye diff --git a/utils/multiwoz/nlp.py b/utils/multiwoz/nlp.py new file mode 100644 index 0000000..779fd37 --- /dev/null +++ b/utils/multiwoz/nlp.py @@ -0,0 +1,245 @@ +import math +import re +from collections import Counter + +from nltk.util import ngrams + +timepat = re.compile("\d{1,2}[:]\d{1,2}") +pricepat = re.compile("\d{1,3}[.]\d{1,2}") + + +fin = open('utils/multiwoz/mapping.pair', 'r') +replacements = [] +for line in fin.readlines(): + tok_from, tok_to = line.replace('\n', '').split('\t') + replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' ')) + + +def insertSpace(token, text): + sidx = 0 + while True: + sidx = text.find(token, sidx) + if sidx == -1: + break + if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \ + re.match('[0-9]', text[sidx + 1]): + sidx += 1 + continue + if text[sidx - 1] != ' ': + text = text[:sidx] + ' ' + text[sidx:] + sidx += 1 + if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ': + text = text[:sidx + 1] + ' ' + text[sidx + 1:] + sidx += 1 + return text + + +def normalize(text, clean_value=True): + # lower case every word + text = text.lower() + + # replace white spaces in front and end + text = re.sub(r'^\s*|\s*$', '', text) + + # hotel domain pfb30 + text = re.sub(r"b&b", "bed and breakfast", text) + text = re.sub(r"b and b", "bed and breakfast", text) + + if clean_value: + # normalize phone number + ms = re.findall('\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m[0], sidx) + if text[sidx - 1] == '(': + sidx -= 1 + eidx = text.find(m[-1], sidx) + len(m[-1]) + text = text.replace(text[sidx:eidx], ''.join(m)) + + # normalize postcode + ms = re.findall('([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})', + text) + if ms: + sidx = 0 + for m in ms: + sidx = text.find(m, sidx) + eidx = sidx + len(m) + text = text[:sidx] + re.sub('[,\. ]', '', m) + text[eidx:] + + # weird unicode bug + text = re.sub(u"(\u2018|\u2019)", "'", text) + + if clean_value: + # replace time and and price + text = re.sub(timepat, ' [value_time] ', text) + text = re.sub(pricepat, ' [value_price] ', text) + #text = re.sub(pricepat2, '[value_price]', text) + + # replace st. 
+    text = text.replace(';', ',')
+    text = re.sub('$\/', '', text)
+    text = text.replace('/', ' and ')
+
+    # replace other special characters
+    text = text.replace('-', ' ')
+    text = re.sub('[\"\<>@\(\)]', '', text)  # remove quotes, angle brackets, @ and parentheses
+
+    # insert white space before and after tokens:
+    for token in ['?', '.', ',', '!']:
+        text = insertSpace(token, text)
+
+    # insert white space for 's
+    text = insertSpace('\'s', text)
+
+    # replace it's, doesn't, you'd ... etc
+    text = re.sub('^\'', '', text)
+    text = re.sub('\'$', '', text)
+    text = re.sub('\'\s', ' ', text)
+    text = re.sub('\s\'', ' ', text)
+    for fromx, tox in replacements:
+        text = ' ' + text + ' '
+        text = text.replace(fromx, tox)[1:-1]
+
+    # remove multiple spaces
+    text = re.sub(' +', ' ', text)
+
+    # concatenate numbers
+    tokens = text.split()
+    i = 1
+    while i < len(tokens):
+        if re.match(u'^\d+$', tokens[i]) and \
+                re.match(u'\d+$', tokens[i - 1]):
+            tokens[i - 1] += tokens[i]
+            del tokens[i]
+        else:
+            i += 1
+    text = ' '.join(tokens)
+
+    return text
+
+
+class BLEUScorer(object):
+    ## BLEU score calculator via GentScorer interface
+    ## it calculates the BLEU-4 by taking the entire corpus in
+    ## and is calculated over multiple candidates against multiple references
+    def __init__(self):
+        pass
+
+    def score(self, hypothesis, corpus, n=1):
+        # containers
+        count = [0, 0, 0, 0]
+        clip_count = [0, 0, 0, 0]
+        r = 0
+        c = 0
+        weights = [0.25, 0.25, 0.25, 0.25]
+
+        # accumulate ngram statistics
+        for hyps, refs in zip(hypothesis, corpus):
+            if type(hyps[0]) is list:
+                hyps = [hyp.split() for hyp in hyps[0]]
+            else:
+                hyps = [hyp.split() for hyp in hyps]
+
+            refs = [ref.split() for ref in refs]
+
+            # Shawn's evaluation
+            refs[0] = [u'GO_'] + refs[0] + [u'EOS_']
+            hyps[0] = [u'GO_'] + hyps[0] + [u'EOS_']
+
+            for idx, hyp in enumerate(hyps):
+                for i in range(4):
+                    # accumulate ngram counts
+                    hypcnts = Counter(ngrams(hyp, i + 1))
+                    cnt = sum(hypcnts.values())
+                    count[i] += cnt
+
+                    # compute clipped counts
+                    max_counts = {}
+                    for ref in refs:
+                        refcnts = Counter(ngrams(ref, i + 1))
+                        for ng in hypcnts:
+                            max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
+                    clipcnt = dict((ng, min(count, max_counts[ng])) \
+                                   for ng, count in hypcnts.items())
+                    clip_count[i] += sum(clipcnt.values())
+
+                # accumulate r & c
+                bestmatch = [1000, 1000]
+                for ref in refs:
+                    if bestmatch[0] == 0: break
+                    diff = abs(len(ref) - len(hyp))
+                    if diff < bestmatch[0]:
+                        bestmatch[0] = diff
+                        bestmatch[1] = len(ref)
+                r += bestmatch[1]
+                c += len(hyp)
+                if n == 1:
+                    break
+        # computing bleu score
+        p0 = 1e-7
+        bp = 1 if c > r else math.exp(1 - float(r) / float(c))
+        p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \
+                for i in range(4)]
+        s = math.fsum(w * math.log(p_n) \
+                      for w, p_n in zip(weights, p_ns) if p_n)
+        bleu = bp * math.exp(s)
+        return bleu
+
+
+class GentScorer(object):
+    def __init__(self, detectfile=None):  # detectfile is unused here; kept for interface compatibility
+        self.bleuscorer = BLEUScorer()
+
+    def scoreBLEU(self, parallel_corpus):
+        # parallel_corpus is assumed to be an iterable of (hypotheses, references)
+        # pairs; unzip it so the call matches BLEUScorer.score(hypothesis, corpus)
+        hypothesis, corpus = zip(*parallel_corpus)
+        return self.bleuscorer.score(hypothesis, corpus)
+
+
+def sentence_bleu_4(hyp, refs, weights=[0.25, 0.25, 0.25, 0.25]):
+    # input : single sentence, multiple references
+    count = [0, 0, 0, 0]
+    clip_count = [0, 0, 0, 0]
+    r = 0
+    c = 0
+
+    for i in range(4):
+        hypcnts = Counter(ngrams(hyp, i + 1))
+        cnt = sum(hypcnts.values())
+        count[i] += cnt
+
+        # compute clipped counts
+        max_counts = {}
+        for ref in refs:
+            refcnts = Counter(ngrams(ref, i + 1))
+            for ng in hypcnts:
+                max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
+        clipcnt = dict((ng, min(count, max_counts[ng])) \
for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: + break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r = bestmatch[1] + c = len(hyp) + + p0 = 1e-7 + bp = math.exp(-abs(1.0 - float(r) / float(c + p0))) + + p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] + s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) + bleu_hyp = bp * math.exp(s) + + return bleu_hyp + +if __name__ == '__main__': + text = "restaurant's CB39AL one seven" + text = "I'm I'd restaurant's CB39AL 099939399 one seven" + text = "ndd 19.30 nndd" + m = re.findall("(\d+\.\d+)", text) \ No newline at end of file diff --git a/utils/utils_camrest676.py b/utils/utils_camrest676.py new file mode 100644 index 0000000..817b732 --- /dev/null +++ b/utils/utils_camrest676.py @@ -0,0 +1,73 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, file_name, max_line = None): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + + with open(file_name) as f: + dials = json.load(f) + + cnt_lin = 1 + for dial_dict in dials: + dialog_history = [""] + + # Reading data + for ti, turn in enumerate(dial_dict["dial"]): + assert ti == turn["turn"] + turn_usr = turn["usr"]["transcript"].lower().strip() + turn_sys = turn["sys"]["sent"].lower().strip() + + data_detail = get_input_example("turn") + data_detail["ID"] = "camrest676-"+str(cnt_lin) + data_detail["turn_id"] = turn["turn"] + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if not args["only_last_turn"]: + data.append(data_detail) + + dialog_history.append(turn_usr) + dialog_history.append(turn_sys) + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + + raise NotImplementedError + + + +def prepare_data_camrest676(args): + example_type = args["example_type"] + max_line = args["max_line"] + + file_trn = os.path.join(args["data_path"], 'CamRest676/CamRest676.json') + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, file_trn, max_line) + pair_dev = [] + pair_tst = [] + + print("Read %s pairs train from CamRest676" % len(pair_trn)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_frames.py b/utils/utils_frames.py new file mode 100644 index 0000000..6c69ec4 --- /dev/null +++ b/utils/utils_frames.py @@ -0,0 +1,80 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, file_name, max_line = None, ds_name=""): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + + with open(file_name) as f: + dials = json.load(f) + + cnt_lin = 1 + for dial_dict in dials: + dialog_history = [] + + turn_usr = "" + turn_sys = "" + for ti, turn in enumerate(dial_dict["turns"]): + if turn["author"] == "user": + turn_usr = turn["text"].lower().strip() + + data_detail = get_input_example("turn") + data_detail["ID"] = 
"{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = ti % 2 + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if not args["only_last_turn"]: + data.append(data_detail) + + dialog_history.append(turn_sys) + dialog_history.append(turn_usr) + + elif turn["author"] == "wizard": + turn_sys = turn["text"].lower().strip() + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + + raise NotImplementedError + + + +def prepare_data_frames(args): + ds_name = "FRAMES" + + example_type = args["example_type"] + max_line = args["max_line"] + + file_trn = os.path.join(args["data_path"], "frames.json") + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, file_trn, max_line, ds_name) + pair_dev = [] + pair_tst = [] + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_function.py b/utils/utils_function.py new file mode 100644 index 0000000..0cfedac --- /dev/null +++ b/utils/utils_function.py @@ -0,0 +1,124 @@ +import torch +import numpy as np + +PAD_token = 0 + +def to_cuda(x): + if torch.cuda.is_available(): x = x.cuda() + return x + + +def merge(sequences, ignore_idx=None): + ''' + merge from batch * sent_len to batch * max_len + ''' + pad_token = PAD_token if type(ignore_idx)==type(None) else ignore_idx + lengths = [len(seq) for seq in sequences] + max_len = 1 if max(lengths)==0 else max(lengths) + padded_seqs = torch.ones(len(sequences), max_len).long() * pad_token + for i, seq in enumerate(sequences): + end = lengths[i] + padded_seqs[i, :end] = seq[:end] + padded_seqs = padded_seqs.detach() #torch.tensor(padded_seqs) + return padded_seqs, lengths + + +def merge_multi_response(sequences, ignore_idx=None): + ''' + merge from batch * nb_slot * slot_len to batch * nb_slot * max_slot_len + ''' + pad_token = PAD_token if type(ignore_idx)==type(None) else ignore_idx + lengths = [] + for bsz_seq in sequences: + length = [len(v) for v in bsz_seq] + lengths.append(length) + max_len = max([max(l) for l in lengths]) + padded_seqs = [] + for bsz_seq in sequences: + pad_seq = [] + for v in bsz_seq: + v = v + [pad_token] * (max_len-len(v)) + pad_seq.append(v) + padded_seqs.append(pad_seq) + padded_seqs = torch.tensor(padded_seqs).long() + lengths = torch.tensor(lengths) + return padded_seqs, lengths + + +def merge_sent_and_word(sequences, ignore_idx=None): + ''' + merge from batch * nb_sent * nb_word to batch * max_nb_sent * max_nb_word + ''' + + max_nb_sent = max([len(seq) for seq in sequences]) + max_nb_word, lengths = [], [] + for seq in sequences: + length = [len(sent) for sent in seq] + max_nb_word += length + lengths.append(length) + max_nb_word = max(max_nb_word) + + pad_token = PAD_token if type(ignore_idx)==type(None) else ignore_idx + padded_seqs = np.ones((len(sequences), max_nb_sent, max_nb_word)) * pad_token + + for i, seq in enumerate(sequences): + for ii, sent in enumerate(seq): + padded_seqs[i, ii, :len(sent)] = 
np.array(sent) + padded_seqs = torch.LongTensor(padded_seqs) + padded_seqs = padded_seqs.detach() + return padded_seqs, lengths + + +def get_input_example(example_type): + + if example_type == "turn": + + data_detail = { + "ID":"", + "turn_id":0, + "domains":[], + "turn_domain":[], + "turn_usr":"", + "turn_sys":"", + "turn_usr_delex":"", + "turn_sys_delex":"", + "belief_state_vec":[], + "db_pointer":[], + "dialog_history":[], + "dialog_history_delex":[], + "belief":{}, + "del_belief":{}, + "slot_gate":[], + "slot_values":[], + "slots":[], + "sys_act":[], + "usr_act":[], + "intent":"", + "turn_slot":[]} + + elif example_type == "dial": + + data_detail = { + "ID":"", + "turn_id":[], + "domains":[], + "turn_domain":[], + "turn_usr":[], + "turn_sys":[], + "turn_usr_delex":[], + "turn_sys_delex":[], + "belief_state_vec":[], + "db_pointer":[], + "dialog_history":[], + "dialog_history_delex":[], + "belief":[], + "del_belief":[], + "slot_gate":[], + "slot_values":[], + "slots":[], + "sys_act":[], + "usr_act":[], + "intent":[], + "turn_slot":[]} + + return data_detail diff --git a/utils/utils_general.py b/utils/utils_general.py new file mode 100644 index 0000000..b91d6de --- /dev/null +++ b/utils/utils_general.py @@ -0,0 +1,81 @@ +import torch +import torch.utils.data as data +import random +import math + +from .dataloader_dst import * +from .dataloader_nlg import * +from .dataloader_nlu import * +from .dataloader_dm import * +from .dataloader_usdl import * + +def get_loader(args, mode, tokenizer, datasets, unified_meta, shuffle=False): + task = args["task"] + batch_size = args["batch_size"] if mode == "train" else args["eval_batch_size"] + + combined_ds = [] + for ds in datasets: + combined_ds += datasets[ds][mode] + + # do not consider empty system responses + if (args["task_name"] == "rs") or (args["task"] == "dm"): + print("[Info] Remove turns with empty system response...") + combined_ds = [d for d in combined_ds if d["turn_sys"]!=""] + + if (args["task_name"] == "rs"): + print("[Info] Remove turn=0 system response...") + combined_ds = [d for d in combined_ds if d["turn_id"]!=0] + + # control data ratio + if (args["train_data_ratio"] != 1 or args["nb_shots"] != -1) and (mode == "train"): + original_len = len(combined_ds) + + if ("oos_intent" in args["dataset"]): + nb_train_sample_per_class = int(100 * args["train_data_ratio"]) + class_count = {k: 0 for k in unified_meta["intent"]} + random.Random(args["rand_seed"]).shuffle(combined_ds) + pair_trn_new = [] + for d in combined_ds: + if class_count[d["intent"]] < nb_train_sample_per_class: + pair_trn_new.append(d) + class_count[d["intent"]] += 1 + combined_ds = pair_trn_new + else: + if args["train_data_ratio"] != 1: + random.Random(args["rand_seed"]).shuffle(combined_ds) + combined_ds = combined_ds[:int(len(combined_ds)*args["train_data_ratio"])] + else: + random.Random(args["rand_seed"]).shuffle(combined_ds) + combined_ds = combined_ds[:args["nb_shots"]] + print("[INFO] Use Training Data: from {} to {}".format(original_len, len(combined_ds))) + + data_info = {k: [] for k in combined_ds[0].keys()} + for d in combined_ds: + for k in combined_ds[0].keys(): + data_info[k].append(d[k]) + + dataset = globals()["Dataset_"+task](data_info, tokenizer, args, unified_meta, mode, args["max_seq_length"]) + + bool_shuffle = (mode=="train" or shuffle) + + data_loader = torch.utils.data.DataLoader(dataset=dataset, + batch_size=batch_size, + shuffle=bool_shuffle, + collate_fn=globals()["collate_fn_{}_{}".format(task, args["example_type"])]) + return data_loader 
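+
+# Usage sketch for get_loader (illustrative only, not part of the original
+# pipeline): `args`, `tokenizer`, and the dataset dict below are assumptions,
+# mirroring how main.py is expected to wire the prepare_data_* outputs in.
+#
+#   datasets = {"multiwoz": {"train": pair_trn, "dev": pair_dev,
+#                            "test": pair_tst, "meta": meta_data}}
+#   unified_meta = get_unified_meta(datasets)
+#   train_loader = get_loader(args, "train", tokenizer, datasets, unified_meta)
+#   for batch in train_loader:
+#       ...  # padded tensors built by the task-specific collate_fn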
+ + +def get_unified_meta(datasets): + unified_meta = {"others":None} + for ds in datasets: + for key, value in datasets[ds]["meta"].items(): + if key not in unified_meta.keys(): + unified_meta[key] = {} + if type(value) == list: + for v in value: + if v not in unified_meta[key].keys(): + unified_meta[key][v] = len(unified_meta[key]) + else: + unified_meta[key] = value + + return unified_meta diff --git a/utils/utils_metalwoz.py b/utils/utils_metalwoz.py new file mode 100644 index 0000000..5204b16 --- /dev/null +++ b/utils/utils_metalwoz.py @@ -0,0 +1,79 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, dial_files, max_line = None, ds_name=""): + print(("Reading from {} for read_langs_turn".format(ds_name))) + + data = [] + + cnt_lin = 1 + for dial_file in dial_files: + + f_dials = open(dial_file, 'r') + dials = f_dials.readlines() + + for dial in dials: + dialog_history = [] + dial_dict = json.loads(dial) + # Reading data + for ti, turn in enumerate(dial_dict["turns"]): + if ti%2 == 0: + turn_sys = turn.lower().strip() + else: + turn_usr = turn.lower().strip() + data_detail = get_input_example("turn") + data_detail["ID"] = "{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = ti % 2 + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if not args["only_last_turn"]: + data.append(data_detail) + + dialog_history.append(turn_sys) + dialog_history.append(turn_usr) + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + + raise NotImplementedError + + + +def prepare_data_metalwoz(args): + ds_name = "MetaLWOZ" + + example_type = args["example_type"] + max_line = args["max_line"] + + onlyfiles = [os.path.join(args["data_path"], 'metalwoz/dialogues/{}'.format(f)) for f in os.listdir(os.path.join(args["data_path"], "metalwoz/dialogues/")) if ".txt" in f] + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, onlyfiles, max_line, ds_name) + pair_dev = [] + pair_tst = [] + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_msre2e.py b/utils/utils_msre2e.py new file mode 100644 index 0000000..4beb62e --- /dev/null +++ b/utils/utils_msre2e.py @@ -0,0 +1,95 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, file_name, max_line = None, ds_name=""): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + + with open(file_name) as f: + dials = f.readlines() + + cnt_lin = 1 + dialog_history = [] + turn_usr = "" + turn_sys = "" + turn_idx = 0 + + for dial in dials[1:]: + dial_split = dial.split("\t") + session_ID, Message_ID, Message_from, Message = dial_split[0], dial_split[1], dial_split[3], dial_split[4] + + if Message_ID == "1" and turn_sys != "": + + if args["only_last_turn"]: + data.append(data_detail) + + turn_usr = "" + 
turn_sys = "" + dialog_history = [] + cnt_lin += 1 + turn_idx = 0 + + if Message_from == "user": + turn_usr = Message.lower().strip() + data_detail = get_input_example("turn") + data_detail["ID"] = "{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = turn_idx + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if not args["only_last_turn"]: + data.append(data_detail) + + dialog_history.append(turn_sys) + dialog_history.append(turn_usr) + turn_idx += 1 + elif Message_from == "agent": + turn_sys = Message.lower().strip() + + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + + raise NotImplementedError + + + +def prepare_data_msre2e(args): + ds_name = "MSR-E2E" + + example_type = args["example_type"] + max_line = args["max_line"] + + file_mov = os.path.join(args["data_path"], 'e2e_dialog_challenge/data/movie_all.tsv') + file_rst = os.path.join(args["data_path"], 'e2e_dialog_challenge/data/restaurant_all.tsv') + file_tax = os.path.join(args["data_path"], 'e2e_dialog_challenge/data/taxi_all.tsv') + + _example_type = "dial" if "dial" in example_type else example_type + pair_mov = globals()["read_langs_{}".format(_example_type)](args, file_mov, max_line, ds_name+"-mov") + pair_rst = globals()["read_langs_{}".format(_example_type)](args, file_rst, max_line, ds_name+"-rst") + pair_tax = globals()["read_langs_{}".format(_example_type)](args, file_tax, max_line, ds_name+"-tax") + + pair_trn = pair_mov + pair_rst + pair_tax + pair_dev = [] + pair_tst = [] + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_multiwoz.py b/utils/utils_multiwoz.py new file mode 100644 index 0000000..2b20081 --- /dev/null +++ b/utils/utils_multiwoz.py @@ -0,0 +1,251 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example +from .multiwoz.fix_label import * + +EXPERIMENT_DOMAINS = ["hotel", "train", "restaurant", "attraction", "taxi"] #, "hospital", "police"] + + +def read_langs_turn(args, file_name, ontology, dialog_act, max_line = None, domain_act_flag=False, update_ont_flag=False): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + SLOTS = [k for k in ontology.keys()] + max_resp_len, max_value_len = 0, 0 + domain_counter = {} + response_candidates = set() + add_slot_values = set() + + with open(file_name) as f: + dials = json.load(f) + + cnt_lin = 1 + for dial_dict in dials: + dialog_history, dialog_history_delex = [], [] + + # Filtering and counting domains + for domain in dial_dict["domains"]: + if domain not in EXPERIMENT_DOMAINS: + continue + if domain not in domain_counter.keys(): + domain_counter[domain] = 0 + domain_counter[domain] += 1 + + # Reading data + for ti, turn in enumerate(dial_dict["dialogue"]): + + belief_dict = fix_general_label_error(turn["belief_state"], False, SLOTS, args["ontology_version"]) + belief_list = [str(k)+'-'+str(v) for k, v in belief_dict.items()] + turn_slot_dict = fix_general_label_error(turn["turn_label"], True, SLOTS, args["ontology_version"]) + turn_slot_list = 
[str(k)+'-'+str(v) for k, v in turn_slot_dict.items()]
+                turn_slot = list(set([k.split("-")[1] for k, v in turn_slot_dict.items()]))
+
+                # build slot gates (none vs. value) and slot-value targets for DST
+                slot_values, gates = [], []
+                for slot in SLOTS:
+                    if slot in belief_dict.keys():
+
+                        # update ontology
+                        if "the {}".format(belief_dict[slot]) in ontology[slot].keys():
+                            belief_dict[slot] = "the {}".format(belief_dict[slot])
+
+                        if belief_dict[slot] not in ontology[slot].keys() and update_ont_flag:
+                            if slot+"-"+belief_dict[slot] not in add_slot_values:
+                                print("[Info] Adding Slot: {} with value: [{}]".format(slot, belief_dict[slot]))
+                                add_slot_values.add(slot+"-"+belief_dict[slot])
+
+                            ontology[slot][belief_dict[slot]] = len(ontology[slot])
+
+                        slot_values.append(belief_dict[slot])
+
+                        if belief_dict[slot] == "none":
+                            gates.append(0)
+                        else:
+                            gates.append(1)
+                    else:
+                        slot_values.append("none")
+                        gates.append(0)
+
+                # dialogue act (exclude domain)
+                if turn["turn_idx"] == 0 and turn["system_transcript"] == "":
+                    cur_sys_acts = set()
+                elif str(turn["turn_idx"]) not in dialog_act[dial_dict["dialogue_idx"].replace(".json", "")].keys():
+                    cur_sys_acts = set()
+                elif dialog_act[dial_dict["dialogue_idx"].replace(".json", "")][str(turn["turn_idx"])] == "No Annotation":
+                    cur_sys_acts = set()
+                else:
+                    cur_sys_acts = dialog_act[dial_dict["dialogue_idx"].replace(".json", "")][str(turn["turn_idx"])]
+
+                    if domain_act_flag:
+                        cur_sys_acts = set([key.lower() for key in cur_sys_acts.keys()])
+                    else:
+                        cur_sys_acts = set([key.split("-")[1].lower() for key in cur_sys_acts.keys()])
+
+                data_detail = get_input_example("turn")
+                data_detail["slots"] = SLOTS
+                data_detail["ID"] = dial_dict["dialogue_idx"]
+                data_detail["turn_id"] = turn["turn_idx"]
+                data_detail["domains"] = dial_dict["domains"]
+                data_detail["turn_domain"] = turn["domain"]
+                data_detail["turn_usr"] = turn["transcript"].strip()
+                data_detail["turn_sys"] = turn["system_transcript"].strip()
+                data_detail["turn_usr_delex"] = turn["transcript_delex"].strip()
+                data_detail["turn_sys_delex"] = turn["system_transcript_delex"].strip()
+                data_detail["belief_state_vec"] = ast.literal_eval(turn["belief_state_vec"])
+                data_detail["db_pointer"] = ast.literal_eval(turn["db_pointer"])
+                data_detail["dialog_history"] = list(dialog_history)
+                data_detail["dialog_history_delex"] = list(dialog_history_delex)
+                data_detail["belief"] = belief_dict
+                data_detail["del_belief"] = turn_slot_dict
+                data_detail["slot_gate"] = gates
+                data_detail["slot_values"] = slot_values
+                data_detail["sys_act"] = cur_sys_acts
+                data_detail["turn_slot"] = turn_slot
+
+                if not args["only_last_turn"]:
+                    data.append(data_detail)
+
+                dialog_history.append(turn["system_transcript"])
+                dialog_history.append(turn["transcript"])
+                dialog_history_delex.append(turn["system_transcript_delex"])
+                dialog_history_delex.append(turn["transcript_delex"])
+                response_candidates.add(str(data_detail["turn_sys"]))
+
+            if args["only_last_turn"]:
+                data.append(data_detail)
+
+            cnt_lin += 1
+            if(max_line and cnt_lin >= max_line):
+                break
+
+    #print("MultiWOZ domain counter: ", domain_counter)
+    return data, ontology, response_candidates
+
+
+def read_langs_dial(args, file_name, ontology, dialog_act, max_line = None, domain_act_flag=False, update_ont_flag=False):
+    print(("Reading from {} for read_langs_dial".format(file_name)))
+    raise NotImplementedError
+
+
+def get_slot_information(args, ontology):
+    ontology_domains = dict([(k, v) for k, v in ontology.items() if k.split("-")[0] in EXPERIMENT_DOMAINS])
+    ontology_new = collections.OrderedDict()
+    for k, v in
ontology_domains.items(): + name = k.replace(" ","").lower() if ("book" not in k) else k.lower() + + if args["ontology_version"] != "": + v = clean_original_ontology(v) + + ontology_new[name] = {"none":0, "do n't care":1} + for vv in v: + if vv not in ontology_new[name].keys(): + ontology_new[name][vv] = len(ontology_new[name]) + return ontology_new + + +def prepare_data_multiwoz(args): + example_type = args["example_type"] + max_line = args["max_line"] + + version = "2.1" + print("[Info] Using Version", version) + + file_trn = os.path.join(args["data_path"], 'MultiWOZ-{}/train_dials.json'.format(version)) + file_dev = os.path.join(args["data_path"], 'MultiWOZ-{}/dev_dials.json'.format(version)) + file_tst = os.path.join(args["data_path"], 'MultiWOZ-{}/test_dials.json'.format(version)) + + path_to_ontology_mapping = os.path.join(args["data_path"], + "MultiWOZ-{}/ontology-mapping{}.json".format(version, args["ontology_version"])) + + if os.path.exists(path_to_ontology_mapping): + print("[Info] Load from old complete ontology from version {}...".format(args["ontology_version"])) + ontology_mapping = json.load(open(path_to_ontology_mapping, 'r')) + update_ont_flag = False + else: + print("[Info] Creating new ontology for version {}...".format(args["ontology_version"])) + ontology = json.load(open(os.path.join(args["data_path"], "MultiWOZ-{}/ontology.json".format(version)), 'r')) + ontology_mapping = get_slot_information(args, ontology) + update_ont_flag = True + + dialog_act = json.load(open(os.path.join(args["data_path"], "MultiWOZ-{}/dialogue_acts.json".format(version)), 'r')) + + _example_type = "dial" if "dial" in example_type else example_type + + pair_trn, ontology_mapping, resp_cand_trn = globals()["read_langs_{}".format(_example_type)](args, + file_trn, + ontology_mapping, + dialog_act, + max_line, + args["domain_act"], + update_ont_flag) + + pair_dev, ontology_mapping, resp_cand_dev = globals()["read_langs_{}".format(_example_type)](args, + file_dev, + ontology_mapping, + dialog_act, + max_line, + args["domain_act"], + update_ont_flag) + + pair_tst, ontology_mapping, resp_cand_tst = globals()["read_langs_{}".format(_example_type)](args, + file_tst, + ontology_mapping, + dialog_act, + max_line, + args["domain_act"], + update_ont_flag) + + + if not os.path.exists(path_to_ontology_mapping): + print("[Info] Dumping complete ontology...") + json.dump(ontology_mapping, open(path_to_ontology_mapping, "w"), indent=4) + + print("Read %s pairs train from MultiWOZ" % len(pair_trn)) + print("Read %s pairs valid from MultiWOZ" % len(pair_dev)) + print("Read %s pairs test from MultiWOZ" % len(pair_tst)) + + # print('args["task_name"]', args["task_name"]) + + if args["task_name"] == "dst": + meta_data = {"slots":ontology_mapping, "num_labels": len(ontology_mapping)} + + elif args["task_name"] == "turn_domain": + domain_set = set([d["turn_domain"] for d in pair_trn]) + domain_dict = {d:i for i, d in enumerate(domain_set)} + meta_data = {"turn_domain":domain_dict, "num_labels": len(domain_dict)} + + elif args["task_name"] == "turn_slot": + turn_slot_list = [] + for d in pair_trn: + turn_slot_list += d["turn_slot"] + turn_slot_list = list(set(turn_slot_list)) + turn_slot_mapping = {d:i for i, d in enumerate(turn_slot_list)} + meta_data = {"turn_slot":turn_slot_mapping, "num_labels": len(turn_slot_mapping)} + + elif args["task_name"] == "sysact": + act_set = set() + for pair in [pair_tst, pair_dev, pair_trn]: + for p in pair: + if type(p["sys_act"]) == list: + for sysact in p["sys_act"]: + 
act_set.update(sysact) + else: + act_set.update(p["sys_act"]) + + print("act_set", len(act_set), act_set) + sysact_lookup = {sysact:i for i, sysact in enumerate(act_set)} + meta_data = {"sysact":sysact_lookup, "num_labels":len(act_set)} + + elif args["task_name"] == "rs": + print("resp_cand_trn", len(resp_cand_trn)) + print("resp_cand_dev", len(resp_cand_dev)) + print("resp_cand_tst", len(resp_cand_tst)) + meta_data = {"num_labels":0, "resp_cand_trn": resp_cand_trn} + + else: + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_oos_intent.py b/utils/utils_oos_intent.py new file mode 100644 index 0000000..c90d202 --- /dev/null +++ b/utils/utils_oos_intent.py @@ -0,0 +1,55 @@ +import json +import ast +import os +import random +from .utils_function import get_input_example + + +def read_langs(args, dtype, _data, _oos_data): + print(("Reading [OOS Intent] for read_langs {}".format(dtype))) + + data = [] + intent_counter = {} + + for cur_data in [_data, _oos_data]: + for d in cur_data: + sentence, label = d[0], d[1] + + data_detail = get_input_example("turn") + data_detail["ID"] = "OOS-INTENT-{}-{}".format(dtype, len(data)) + data_detail["turn_usr"] = sentence + data_detail["intent"] = label + data.append(data_detail) + + # count number of each label + if label not in intent_counter.keys(): + intent_counter[label] = 0 + intent_counter[label] += 1 + + #print("len of OOS Intent counter: ", len(intent_counter)) + + return data, intent_counter + + +def prepare_data_oos_intent(args): + example_type = args["example_type"] + max_line = args["max_line"] + + file_input = os.path.join(args["data_path"], 'oos-intent/data/data_full.json') + data = json.load(open(file_input, "r")) + + pair_trn, intent_counter_trn = read_langs(args, "trn", data["train"], data["oos_train"]) + pair_dev, intent_counter_dev = read_langs(args, "dev", data["val"], data["oos_val"]) + pair_tst, intent_counter_tst = read_langs(args, "tst", data["test"], data["oos_test"]) + + print("Read %s pairs train from OOS Intent" % len(pair_trn)) + print("Read %s pairs valid from OOS Intent" % len(pair_dev)) + print("Read %s pairs test from OOS Intent" % len(pair_tst)) + + intent_class = list(intent_counter_trn.keys()) + + meta_data = {"intent":intent_class, "num_labels":len(intent_class)} + print("len(intent_class)", len(intent_class)) + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_schema.py b/utils/utils_schema.py new file mode 100644 index 0000000..2b85c88 --- /dev/null +++ b/utils/utils_schema.py @@ -0,0 +1,83 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, dial_files, max_line = None, ds_name=""): + print(("Reading from {} for read_langs_turn".format(ds_name))) + + data = [] + + cnt_lin = 1 + for dial_file in dial_files: + + f_dials = open(dial_file, 'r') + + dials = json.load(f_dials) + + turn_sys = "" + turn_usr = "" + + for dial_dict in dials: + dialog_history = [] + for ti, turn in enumerate(dial_dict["turns"]): + if turn["speaker"] == "USER": + turn_usr = turn["utterance"].lower().strip() + data_detail = get_input_example("turn") + data_detail["ID"] = "{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = ti % 2 + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if (not args["only_last_turn"]): + data.append(data_detail) + + dialog_history.append(turn_sys) + 
dialog_history.append(turn_usr) + + elif turn["speaker"] == "SYSTEM": + turn_sys = turn["utterance"].lower().strip() + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + raise NotImplementedError + + + +def prepare_data_schema(args): + ds_name = "Schema" + + example_type = args["example_type"] + max_line = args["max_line"] + + onlyfiles_trn = [os.path.join(args["data_path"], 'dstc8-schema-guided-dialogue/train/{}'.format(f)) for f in os.listdir(os.path.join(args["data_path"], "dstc8-schema-guided-dialogue/train/")) if "dialogues" in f] + onlyfiles_dev = [os.path.join(args["data_path"], 'dstc8-schema-guided-dialogue/dev/{}'.format(f)) for f in os.listdir(os.path.join(args["data_path"],"dstc8-schema-guided-dialogue/dev/")) if "dialogues" in f] + onlyfiles_tst = [os.path.join(args["data_path"], 'dstc8-schema-guided-dialogue/test/{}'.format(f)) for f in os.listdir(os.path.join(args["data_path"], "dstc8-schema-guided-dialogue/test/")) if "dialogues" in f] + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, onlyfiles_trn, max_line, ds_name) + pair_dev = globals()["read_langs_{}".format(_example_type)](args, onlyfiles_dev, max_line, ds_name) + pair_tst = globals()["read_langs_{}".format(_example_type)](args, onlyfiles_tst, max_line, ds_name) + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_smd.py b/utils/utils_smd.py new file mode 100644 index 0000000..6ea0556 --- /dev/null +++ b/utils/utils_smd.py @@ -0,0 +1,80 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, file_name, max_line = None, ds_name=""): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + + with open(file_name) as f: + dials = json.load(f) + + cnt_lin = 1 + for dial_dict in dials: + dialog_history = [] + + turn_usr = "" + turn_sys = "" + for ti, turn in enumerate(dial_dict["dialogue"]): + if turn["turn"] == "driver": + turn_usr = turn["data"]["utterance"].lower().strip() + + data_detail = get_input_example("turn") + data_detail["ID"] = "{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = ti % 2 + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if (not args["only_last_turn"]): + data.append(data_detail) + + dialog_history.append(turn_sys) + dialog_history.append(turn_usr) + elif turn["turn"] == "assistant": + turn_sys = turn["data"]["utterance"].lower().strip() + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + raise NotImplementedError + + + +def prepare_data_smd(args): + ds_name = "SMD" + + example_type = args["example_type"] + max_line = args["max_line"] + + file_trn = 
os.path.join(args["data_path"], "kvret/kvret_train_public.json") + file_dev = os.path.join(args["data_path"], "kvret/kvret_dev_public.json") + file_tst = os.path.join(args["data_path"], "kvret/kvret_test_public.json") + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, file_trn, max_line, ds_name) + pair_dev = globals()["read_langs_{}".format(_example_type)](args, file_dev, max_line, ds_name) + pair_tst = globals()["read_langs_{}".format(_example_type)](args, file_tst, max_line, ds_name) + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_taskmaster.py b/utils/utils_taskmaster.py new file mode 100644 index 0000000..9c7d9af --- /dev/null +++ b/utils/utils_taskmaster.py @@ -0,0 +1,86 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, dials, ds_name, max_line): + print(("Reading from {} for read_langs_turn".format(ds_name))) + + data = [] + turn_sys = "" + turn_usr = "" + + cnt_lin = 1 + for dial in dials: + dialog_history = [] + for ti, turn in enumerate(dial["utterances"]): + if turn["speaker"] == "USER": + turn_usr = turn["text"].lower().strip() + + data_detail = get_input_example("turn") + data_detail["ID"] = "{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = ti % 2 + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if (not args["only_last_turn"]): + data.append(data_detail) + + dialog_history.append(turn_sys) + dialog_history.append(turn_usr) + elif turn["speaker"] == "ASSISTANT": + turn_sys = turn["text"].lower().strip() + else: + turn_usr += " {}".format(turn["text"]) + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + + raise NotImplementedError + + + +def prepare_data_taskmaster(args): + ds_name = "TaskMaster" + + example_type = args["example_type"] + max_line = args["max_line"] + + fr_trn_id = open(os.path.join(args["data_path"], 'Taskmaster/TM-1-2019/train-dev-test/train.csv'), 'r') + fr_dev_id = open(os.path.join(args["data_path"], 'Taskmaster/TM-1-2019/train-dev-test/dev.csv'), 'r') + fr_trn_id = fr_trn_id.readlines() + fr_dev_id = fr_dev_id.readlines() + fr_trn_id = [_id.replace("\n", "").replace(",", "") for _id in fr_trn_id] + fr_dev_id = [_id.replace("\n", "").replace(",", "") for _id in fr_dev_id] + + fr_data_woz = open(os.path.join(args["data_path"], 'Taskmaster/TM-1-2019/woz-dialogs.json'), 'r') + fr_data_self = open(os.path.join(args["data_path"], 'Taskmaster/TM-1-2019/self-dialogs.json'), 'r') + dials_all = json.load(fr_data_woz) + json.load(fr_data_self) + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, dials_all, ds_name, max_line) + pair_dev = [] + pair_tst = [] + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + 
print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data + diff --git a/utils/utils_universal_act.py b/utils/utils_universal_act.py new file mode 100644 index 0000000..be1d40e --- /dev/null +++ b/utils/utils_universal_act.py @@ -0,0 +1,108 @@ +import json +import ast +import os + +from .utils_function import get_input_example + +def read_langs_turn(file_name, max_line = None): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + domain_counter = {} + + with open(file_name) as f: + dials = json.load(f) + + print("len dials", len(dials)) + + cnt_lin = 1 + for dial_list in dials: + dialog_history = [] + + # Reading data + for ti, turn in enumerate(dial_list): + + sys_first_flag = 1 if (ti==0 and turn["speaker"]=="[SYS]") else 0 + + data_detail = get_input_example("turn") + data_detail["ID"] = turn["conv_id"] + data_detail["dialog_history"] = list(dialog_history) + + if sys_first_flag and ti % 2 == 1: + data_detail["turn_id"] = ti//2 + data_detail["turn_usr"] = turn["raw_text"].strip() + data_detail["turn_sys"] = dial_list[ti-1]["raw_text"].strip() + data_detail["sys_act"] = dial_list[ti-1]["label"] + data.append(data_detail) + dialog_history.append(data_detail["turn_sys"]) + dialog_history.append(data_detail["turn_usr"]) + elif not sys_first_flag and ti % 2 == 0: + data_detail["turn_id"] = (ti+1)//2 + data_detail["turn_usr"] = turn["raw_text"].strip() + data_detail["turn_sys"] = dial_list[ti-1]["raw_text"].strip() if ti > 0 else "" + data_detail["sys_act"] = dial_list[ti-1]["label"] if ti > 0 else [] + data.append(data_detail) + dialog_history.append(data_detail["turn_sys"]) + dialog_history.append(data_detail["turn_usr"]) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, label_dict, max_line = None): + raise NotImplementedError + + +def prepare_data_universal_act_dstc2(args): + example_type = args["example_type"] + max_line = args["max_line"] + + file_trn = os.path.join(args["data_path"], 'universal_dialog_act/dstc2/train.json') + file_dev = os.path.join(args["data_path"], 'universal_dialog_act/dstc2/valid.json') + file_tst = os.path.join(args["data_path"], 'universal_dialog_act/dstc2/test.json') + file_label = os.path.join(args["data_path"], 'universal_dialog_act/dstc2/labels.txt') + #file_label = '/export/home/dialog_datasets/universal_dialog_act/acts.txt' + label_dict = {line.replace("\n", ""):i for i, line in enumerate(open(file_label, "r").readlines())} + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](file_trn, max_line) + pair_dev = globals()["read_langs_{}".format(_example_type)](file_dev, max_line) + pair_tst = globals()["read_langs_{}".format(_example_type)](file_tst, max_line) + + print("Read {} pairs train from {}".format(len(pair_trn), file_trn)) + print("Read {} pairs valid from {}".format(len(pair_dev), file_dev)) + print("Read {} pairs test from {}".format(len(pair_tst), file_tst)) + + meta_data = {"sysact":label_dict, "num_labels":len(label_dict)} + print("meta_data", meta_data) + + return pair_trn, pair_dev, pair_tst, meta_data + + +def prepare_data_universal_act_sim_joint(args): + example_type = args["example_type"] + max_line = args["max_line"] + + file_trn = os.path.join(args["data_path"], 'universal_dialog_act/sim_joint/train.json') + file_dev = os.path.join(args["data_path"], 
'universal_dialog_act/sim_joint/valid.json') + file_tst = os.path.join(args["data_path"], 'universal_dialog_act/sim_joint/test.json') + file_label = os.path.join(args["data_path"], 'universal_dialog_act/sim_joint/labels.txt') + #file_label = '/export/home/dialog_datasets/universal_dialog_act/acts.txt' + label_dict = {line.replace("\n", ""):i for i, line in enumerate(open(file_label, "r").readlines())} + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](file_trn, max_line) + pair_dev = globals()["read_langs_{}".format(_example_type)](file_dev, max_line) + pair_tst = globals()["read_langs_{}".format(_example_type)](file_tst, max_line) + + print("Read {} pairs train from {}".format(len(pair_trn), file_trn)) + print("Read {} pairs valid from {}".format(len(pair_dev), file_dev)) + print("Read {} pairs test from {}".format(len(pair_tst), file_tst)) + + meta_data = {"sysact":label_dict, "num_labels":len(label_dict)} + print("meta_data", meta_data) + + return pair_trn, pair_dev, pair_tst, meta_data diff --git a/utils/utils_woz.py b/utils/utils_woz.py new file mode 100644 index 0000000..d237dd3 --- /dev/null +++ b/utils/utils_woz.py @@ -0,0 +1,77 @@ +import json +import ast +import collections +import os + +from .utils_function import get_input_example + + +def read_langs_turn(args, file_name, max_line = None, ds_name=""): + print(("Reading from {} for read_langs_turn".format(file_name))) + + data = [] + + with open(file_name) as f: + dials = json.load(f) + + cnt_lin = 1 + for dial_dict in dials: + dialog_history = [] + + # Reading data + for ti, turn in enumerate(dial_dict["dialogue"]): + assert ti == turn["turn_idx"] + turn_usr = turn["transcript"].lower().strip() + turn_sys = turn["system_transcript"].lower().strip() + + data_detail = get_input_example("turn") + data_detail["ID"] = "{}-{}".format(ds_name, cnt_lin) + data_detail["turn_id"] = turn["turn_idx"] + data_detail["turn_usr"] = turn_usr + data_detail["turn_sys"] = turn_sys + data_detail["dialog_history"] = list(dialog_history) + + if not args["only_last_turn"]: + data.append(data_detail) + + dialog_history.append(turn_sys) + dialog_history.append(turn_usr) + + if args["only_last_turn"]: + data.append(data_detail) + + cnt_lin += 1 + if(max_line and cnt_lin >= max_line): + break + + return data + + +def read_langs_dial(file_name, ontology, dialog_act, max_line = None, domain_act_flag=False): + print(("Reading from {} for read_langs_dial".format(file_name))) + raise NotImplementedError + + +def prepare_data_woz(args): + ds_name = "WOZ" + + example_type = args["example_type"] + max_line = args["max_line"] + + file_trn = os.path.join(args["data_path"], "neural-belief-tracker/data/woz/woz_train_en.json") + file_dev = os.path.join(args["data_path"], "neural-belief-tracker/data/woz/woz_validate_en.json") + file_tst = os.path.join(args["data_path"], "neural-belief-tracker/data/woz/woz_test_en.json") + + _example_type = "dial" if "dial" in example_type else example_type + pair_trn = globals()["read_langs_{}".format(_example_type)](args, file_trn, max_line, ds_name) + pair_dev = globals()["read_langs_{}".format(_example_type)](args, file_dev, max_line, ds_name) + pair_tst = globals()["read_langs_{}".format(_example_type)](args, file_tst, max_line, ds_name) + + print("Read {} pairs train from {}".format(len(pair_trn), ds_name)) + print("Read {} pairs valid from {}".format(len(pair_dev), ds_name)) + print("Read {} pairs test from {}".format(len(pair_tst), ds_name)) + + 
meta_data = {"num_labels":0} + + return pair_trn, pair_dev, pair_tst, meta_data +