Skip to content

Commit 8b4c2e5

Browse files
authored
Add unified intent/slot filling model (#198)
1 parent 796eb67 commit 8b4c2e5

23 files changed

+1748
-258
lines changed

.github/workflows/build.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
pre-commit:
1616
if: always()
1717
runs-on: ubuntu-latest
18-
timeout-minutes: 4
18+
timeout-minutes: 30
1919
steps:
2020
- uses: actions/checkout@v3
2121
with:
@@ -43,7 +43,7 @@ jobs:
4343
name: "Build and Test Python 3.9"
4444
runs-on: ubuntu-latest
4545
if: always()
46-
timeout-minutes: 20
46+
timeout-minutes: 30
4747

4848
steps:
4949
- uses: actions/checkout@v3

.gitignore

+11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ eval/data/extras/*
1010
# pytest coverage
1111
.coverage
1212

13+
# pytorch
14+
lightning_logs/
15+
**/wandb/*
16+
17+
# models
18+
models/
19+
**/checkpoints/*
20+
1321
# Exports and logs
1422
dialogue_export/
1523
data/users.db
24+
25+
# local experimentation
26+
local/

config/moviebot_config_no_integration.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ TELEGRAM: False # execute the code on Telegram
1919

2020
POLLING: False # True when using Telegram without server
2121

22-
FLASK_SOCKET: False
23-
FLASK_REST: True
22+
FLASK_SOCKET: True
23+
FLASK_REST: False
2424

2525
BOT_TOKEN_PATH: config/bot_token.yaml
2626

data/training/utterances.yaml

+652
Large diffs are not rendered by default.

moviebot/agent/agent.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
"""Types of conversational agents are available here."""
22
import logging
33
import os
4-
from typing import Any, Dict, List, Union
4+
from typing import Any, Dict
55

66
from dialoguekit.core import AnnotatedUtterance, Intent, Utterance
77
from dialoguekit.participant import Agent, DialogueParticipant
88

9+
from moviebot.core.core_types import DialogueOptions
910
from moviebot.core.intents.agent_intents import AgentIntents
1011
from moviebot.database.db_movies import DataBase
11-
from moviebot.dialogue_manager.dialogue_act import DialogueAct
1212
from moviebot.dialogue_manager.dialogue_manager import DialogueManager
1313
from moviebot.nlg.nlg import NLG
14-
from moviebot.nlu.nlu import NLU
14+
from moviebot.nlu.rule_based_nlu import RuleBasedNLU as NLU
1515
from moviebot.ontology.ontology import Ontology
1616
from moviebot.recommender.recommender_model import RecommenderModel
1717
from moviebot.recommender.slot_based_recommender_model import (
1818
SlotBasedRecommenderModel,
1919
)
2020

2121
logger = logging.getLogger(__name__)
22-
DialogueOptions = Dict[DialogueAct, Union[str, List[str]]]
2322

2423

2524
def _get_ontology(ontology_path: str) -> Ontology:

moviebot/controller/controller_flask.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from dialoguekit.core import Utterance
88

99
import moviebot.controller.http_data_formatter as http_formatter
10-
from moviebot.agent.agent import DialogueOptions, MovieBotAgent
10+
from moviebot.agent.agent import MovieBotAgent
1111
from moviebot.controller.controller import Controller
12+
from moviebot.core.core_types import DialogueOptions
1213
from moviebot.core.utterance.utterance import UserUtterance
1314

1415

moviebot/controller/controller_terminal.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from dialoguekit.platforms import TerminalPlatform
1111
from questionary.constants import INDICATOR_SELECTED
1212

13-
from moviebot.agent.agent import DialogueOptions, MovieBotAgent
13+
from moviebot.agent.agent import MovieBotAgent
1414
from moviebot.controller.controller import Controller
15+
from moviebot.core.core_types import DialogueOptions
1516

1617

1718
class ControllerTerminal(Controller, TerminalPlatform):

moviebot/controller/http_data_formatter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import asdict, dataclass, field
44
from typing import Any, Dict, List, Tuple
55

6-
from moviebot.agent.agent import DialogueOptions
6+
from moviebot.core.core_types import DialogueOptions
77

88
HTTP_OBJECT_MESSAGE = Dict[str, Dict[str, str]]
99

moviebot/core/core_types.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Core types for the moviebot package."""
2+
from typing import TYPE_CHECKING, Dict, List, Union
3+
4+
if TYPE_CHECKING:
5+
from moviebot.dialogue_manager.dialogue_act import DialogueAct
6+
7+
DialogueOptions = Dict["DialogueAct", Union[str, List[str]]]

moviebot/core/intents/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .agent_intents import AgentIntents
2+
from .user_intents import UserIntents
3+
4+
__all__ = ["UserIntents", "AgentIntents"]

moviebot/core/utterance/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .utterance import AgentUtterance, UserUtterance
2+
3+
__all__ = ["UserUtterance", "AgentUtterance"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .joint_bert import JointBERT
2+
3+
__all__ = ["JointBERT"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""Dataset loading for training and evaluating the JointBERT model. """
2+
import os
3+
import re
4+
from typing import Dict, Generator, List, Tuple
5+
6+
import torch
7+
import yaml
8+
from torch.utils.data import Dataset
9+
from transformers import BertTokenizer
10+
11+
from moviebot.nlu.annotation.joint_bert.slot_mapping import (
12+
JointBERTIntent,
13+
JointBERTSlot,
14+
)
15+
16+
# One dataset item: (input_ids, attention_mask, intent, labels) tensors.
DataPoint = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
17+
18+
# Label assigned to positions that are ignored by the loss function:
# [CLS]/[SEP], sub-word continuation tokens, and padding.
_IGNORE_INDEX = -100
19+
# Name of the pretrained tokenizer passed to BertTokenizer.from_pretrained.
_TOKENIZER_PATH = "bert-base-uncased"
20+
21+
22+
def load_yaml(path: str) -> Dict[str, List[str]]:
    """Loads the YAML file at the given path.

    The file is decoded as UTF-8 regardless of the platform's default
    encoding.

    Args:
        path: The path to the YAML file.

    Raises:
        FileNotFoundError: If the path does not point to an existing file.

    Returns:
        The data in the YAML file.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File not found: {path}")

    # Be explicit about the encoding: the default of open() depends on the
    # locale, which would garble non-ASCII utterances on some platforms.
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
39+
40+
41+
def parse_data(
    data: Dict[str, List[str]]
) -> Generator[Tuple[str, str, List[Tuple[str, str]]], None, None]:
    """Parses the input data to extract intent, text, and slot annotations.

    Slots are annotated inline as "[value](slot)", for example
    "I like [action](genre) movies".

    Args:
        data: Mapping from intent to its annotated example utterances.

    Yields:
        A tuple of the intent, the text with the annotation markup removed
        (slot values are kept), and the (value, slot) pairs in order of
        appearance.
    """
    # Compiled once and shared by extraction and clean-up so both always
    # agree on what counts as an annotation.
    slot_pattern = re.compile(r"\[(.*?)\]\((.*?)\)")

    for intent, annotated_examples in data.items():
        # An intent key without examples is parsed from YAML as None.
        for annotated_example in annotated_examples or []:
            # Extract slot information
            slot_annotations = slot_pattern.findall(annotated_example)

            # Remove slot annotations from the text
            clean_text = slot_pattern.sub(r"\1", annotated_example)

            yield intent, clean_text, slot_annotations
63+
64+
65+
class JointBERTDataset(Dataset):
    """Dataset of intent- and slot-annotated utterances for JointBERT."""

    def __init__(self, path: str, max_length: int = 32) -> None:
        """Initializes the dataset.

        Args:
            path: The path to the YAML file containing the data.
            max_length: The maximum length of the input sequence, including
              the [CLS] and [SEP] tokens. Defaults to 32.
        """
        self.data = load_yaml(path)
        self.max_length = max_length

        self.intent_label_count = len(JointBERTIntent)
        self.slot_label_count = len(JointBERTSlot)

        self.tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_PATH)

        self.examples: List[Tuple[List[int], List[int], int, List[int]]] = []
        self._build_dataset()

    def _build_dataset(self) -> None:
        """Builds the dataset.

        Every example is stored as a tuple of input ids, attention mask,
        intent index, and slot labels, each sequence having exactly
        `max_length` elements.
        """
        # Room left for utterance tokens once [CLS] and [SEP] are added.
        max_tokens = max(self.max_length - 2, 0)

        for intent, clean_text, slot_annotations in parse_data(self.data):
            intent_index, tokens, labels = self._tokenize_and_label(
                intent, clean_text, slot_annotations
            )

            # Truncate over-long utterances; otherwise the padding length
            # below is negative and examples end up with different lengths.
            tokens = tokens[:max_tokens]
            labels = labels[:max_tokens]

            input_ids = self.tokenizer.encode(tokens, add_special_tokens=True)
            attention_mask = [1] * len(input_ids)

            # Add [CLS] and [SEP] tokens to labels
            labels = [_IGNORE_INDEX] + labels + [_IGNORE_INDEX]

            # Pad input_ids, attention_mask, and labels
            padding_length = self.max_length - len(input_ids)
            input_ids = input_ids + (
                [self.tokenizer.pad_token_id] * padding_length
            )
            attention_mask = attention_mask + ([0] * padding_length)
            labels = labels + ([_IGNORE_INDEX] * padding_length)
            self.examples.append(
                (input_ids, attention_mask, intent_index, labels)
            )

    def _num_word_tokens(self, word: str) -> int:
        """Returns the number of word tokens in the input word.

        Args:
            word: The input word.

        Returns:
            The number of word tokens in the input word.
        """
        return len(self.tokenizer.tokenize(word))

    def _word_labels(self, word: str, label: int) -> List[int]:
        """Returns one label per token of a single word.

        Args:
            word: The input word.
            label: The label index of the word.

        Returns:
            The label for the first token followed by the ignore index for
            every remaining token; empty if the word yields no token.
        """
        num_tokens = self._num_word_tokens(word)
        if num_tokens == 0:
            return []
        return [label] + [_IGNORE_INDEX] * (num_tokens - 1)

    def _tokenize_and_label(
        self, intent: str, text: str, slot_annotations: List[Tuple[str, str]]
    ) -> Tuple[int, List[str], List[int]]:
        """Tokenizes the text and assigns labels based on slot annotations.

        The main purpose of this method is to convert the slot annotations into
        labels that can be used to train the model. The labels need to have the
        same length as the tokenized utterance.

        For example:

        Input: "I like scifi."
        Tokens: ["I", "like", "sci", "##fi", "."]
        Labels: ["OUT", "OUT", "B_GENRE", -100, "OUT"]
        Indexes: [0, 0, 3, -100, 0]

        Note that we put -100 to ignore evaluation of the loss function for
        tokens that are not beginning of a slot. This makes it easier to
        decode the labels later.

        Args:
            intent: The intent of the text.
            text: The text to tokenize.
            slot_annotations: The (slot text, slot name) pairs in the text, in
              order of appearance.

        Raises:
            ValueError: If a slot text cannot be found in the text or if the
              labels do not line up with the tokens.

        Returns:
            A tuple of the intent, tokenized text, and labels.
        """
        tokens = self.tokenizer.tokenize(text)
        labels: List[int] = []
        out_label = JointBERTSlot.to_index("OUT")

        start_idx = 0
        for slot_text, slot_label in slot_annotations:
            # Search from the end of the previous slot so that a value which
            # also occurs earlier in the text is not matched again.
            index = text.find(slot_text, start_idx)
            if index < 0:
                raise ValueError(
                    f"Slot text '{slot_text}' not found in '{text}'."
                )

            # Words between the previous slot and this one are outside.
            for word in text[start_idx:index].split():
                labels.extend(self._word_labels(word, out_label))

            # First word of a slot is B_<SLOT>, following words I_<SLOT>.
            for i, word in enumerate(slot_text.split()):
                slot_index = JointBERTSlot.to_index(
                    ("B_" if i == 0 else "I_") + slot_label.upper()
                )
                labels.extend(self._word_labels(word, slot_index))
            start_idx = index + len(slot_text)

        for word in text[start_idx:].split():
            labels.extend(self._word_labels(word, out_label))

        # Raised explicitly instead of asserted so the check survives -O.
        if len(tokens) != len(labels):
            raise ValueError(
                f"Got {len(labels)} labels for {len(tokens)} tokens in "
                f"'{text}'."
            )
        return JointBERTIntent.to_index(intent.upper()), tokens, labels

    def __len__(self) -> int:
        """Returns the number of examples in the dataset."""
        return len(self.examples)

    def __getitem__(self, idx: int) -> DataPoint:
        """Returns the example at the given index.

        Args:
            idx: The index of the example to return.

        Returns:
            A tuple of the input_ids, attention_mask, intent, and labels.
        """
        input_ids, attention_mask, intent, labels = self.examples[idx]

        return (
            torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(attention_mask, dtype=torch.long),
            torch.tensor(intent, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long),
        )

0 commit comments

Comments
 (0)