probability + more test data
MrPicklePinosaur committed Jun 5, 2022
1 parent 789b29a commit 7bec33c
Showing 3 changed files with 38 additions and 13 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -19,11 +19,13 @@ $ python bin/download.py

## TODO

- [ ] properly load intents data from json
- [x] properly load intents data from json
- [ ] add type annotations
- [ ] do some benchmarking (maybe)
- [ ] implement simple user interaction
- [x] implement simple user interaction
- [ ] python linter
- [ ] support regex in intents to generate more testing data
- [ ] support queries inside utterances

## RESOURCES

11 changes: 9 additions & 2 deletions data/intents.json
@@ -5,7 +5,12 @@
"utterances": [
"what is the weather currently",
"how's the weather",
"what's it like outside"
"what's it like outside",
"weather",
"can you tell me the weather",
"weather report",
"current weather",
"what's the weather like"
],
"responses": [
]
@@ -15,7 +20,9 @@
"utterances": [
"what time is it",
"tell me the time",
"what's the time"
"what's the time",
"current time",
"time"
],
"responses": [
]
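
The extra utterances above serve as additional test data for the intent classifier. As a reference, here is a minimal sketch of how a file shaped like `data/intents.json` could be loaded; the top-level `"intents"` key and the `"tag"` field are assumptions, since the diff only shows the `"utterances"` and `"responses"` arrays, and the project's actual loader (the README marks "properly load intents data from json" as done) may look different:

```python
import json

def load_intents(path="data/intents.json"):
    """Read intent definitions from disk.

    Assumed shape:
    {"intents": [{"tag": ..., "utterances": [...], "responses": [...]}]}
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

# Flatten into (utterance, tag) pairs, e.g. for building a training/test set.
data = load_intents()
pairs = [
    (utterance, intent["tag"])
    for intent in data["intents"]
    for utterance in intent["utterances"]
]
```
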
34 changes: 25 additions & 9 deletions main.py
@@ -1,5 +1,6 @@
import sys
import torch
from pipeop import pipes

from nn import NeuralNet
import preprocess
@@ -25,16 +26,31 @@
model.load_state_dict(model_state)
model.eval()

# preprocess query
x = preprocess.tokenize(query)
x = preprocess.stem(x)
x = preprocess.bag_words(x, word_dict)
x = x.reshape(1, x.shape[0])
x = torch.from_numpy(x)

output = model(x)
@pipes
def preprocess_query(query):
    x = (query
        >> preprocess.tokenize
        >> preprocess.stem
        >> preprocess.bag_words(word_dict)
    )
    x = x.reshape(1, x.shape[0])
    return torch.from_numpy(x)


preprocessed = preprocess_query(query)
# TODO catch a tensor that is all zero
print(preprocessed)
output = model(preprocessed)
_, predicted = torch.max(output, dim=1)
tag = tags[predicted.item()]

print(tag)
probs = torch.softmax(output, dim=1)
prob = probs[0][predicted.item()]

confidence_threshold = 0.75
if prob.item() > confidence_threshold:
    print(f"[prob={prob.item():.4f}] {tag}")
else:
    print("query not understood")

