Commit: training loop
MrPicklePinosaur committed Jun 5, 2022
1 parent 2fcecc3 commit 073195a
Showing 5 changed files with 71 additions and 5 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@ very simple python chatbot to suck less at nlp

First create venv and install dependencies
```
-$ virtualenv venv
+$ virtualenv --python=<path to python3.7> venv
$ source venv/bin/activate
$ pip install -r requirements.txt
```
dataset.py (7 changes: 5 additions & 2 deletions)
@@ -1,5 +1,8 @@

-class IntentDataset:
+from torch.utils.data import Dataset
+
+
+class IntentDataset(Dataset):

    def __init__(self, x, y):
        self.x_data = x
@@ -8,5 +11,5 @@ def __init__(self, x, y):
    def __len__(self):
        return len(self.x_data)

-    def __get_item__(self, index):
+    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
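A side note on the `__getitem__` fix above: PyTorch's `DataLoader` indexes a `Dataset` through the `__getitem__`/`__len__` protocol methods, so the misspelled `__get_item__` was never called (in torch 1.11 the base `Dataset.__getitem__` raises `NotImplementedError`). A minimal usage sketch with made-up toy data, not part of the commit:

```python
import numpy as np
from torch.utils.data import DataLoader
from dataset import IntentDataset

# hypothetical toy data: 4 bag-of-words vectors of size 10, 4 integer labels
x = np.zeros((4, 10), dtype=np.float32)
y = np.array([0, 1, 2, 0])

ds = IntentDataset(x, y)
print(len(ds))  # calls __len__ -> 4
print(ds[0])    # calls __getitem__ -> (x[0], y[0])

# DataLoader drives both protocol methods when it batches and shuffles
for words, labels in DataLoader(ds, batch_size=2, shuffle=True):
    print(words.shape, labels.shape)  # torch.Size([2, 10]) torch.Size([2])
```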
main.py (41 changes: 39 additions & 2 deletions)
@@ -2,19 +2,21 @@
import numpy as np
import torch
import torch.nn as nn
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DataLoader

# TODO this is prob not needed lol
from pipeop import pipes
import preprocess
from dataset import IntentDataset
+from nn import NeuralNet

test_data = [
    ("advice", "In my younger and more vulnerable years my father gave me some advice that I've been turning over in my mind ever since."),
    ("critism", "Whenever you feel like criticizing any one, he told me, just remember that all the people in this world haven't had the advantages that you've had."),
    ("communication", "He didn't say any more but we've always been unusually communicative in a reserved way, and I understood that he meant a great deal more than that.")
]


@pipes
def run():
    word_dict = []
@@ -37,9 +39,11 @@ def run():
    x_data = np.array([
        preprocess.bag_words(tokenized, word_dict) for (tokenized, tag) in xy
    ])
-    y_data = np.array([tag for (tokenized, tag) in xy])
+    # TODO make this reference index of tag
+    y_data = np.array([i for i in range(len(test_data))])
    dataset = IntentDataset(x_data, y_data)

+    # build dataloader
    batch_size = 8
    num_workers = 2
    loader = DataLoader(
@@ -49,4 +53,37 @@ def run():
        num_workers=num_workers
    )

+    # build neural net
+    input_size = len(word_dict)
+    hidden_size = 8
+    output_size = 3
+    device = 'cpu'
+    model = NeuralNet(input_size, hidden_size, output_size).to(device)
+
+    # start training
+    learning_rate = 0.001
+    training_epochs = 1000
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+    for epoch in range(training_epochs):
+        for (words, labels) in loader:
+            words = words.to(device)
+            labels = labels.to(device)
+
+            # forward pass
+            outputs = model(words)
+            loss = criterion(outputs, labels)
+
+            # backwards pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        if epoch % 100 == 0:
+            print(f'epoch={epoch}/{training_epochs} loss={loss.item():.4f}')
+
+    print(f'final loss={loss.item():.4f}')
+
+
run()
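The `TODO make this reference index of tag` left in `run()` could be resolved by deriving each label from its tag rather than from the sample's position in `test_data`. A hedged sketch under that assumption; `tag_to_idx` is a name introduced here, not something in the repo:

```python
# build a stable tag -> index mapping from the data itself
tags = sorted({tag for (tokenized, tag) in xy})
tag_to_idx = {tag: i for i, tag in enumerate(tags)}

# label each sample by its tag's index, independent of ordering
y_data = np.array([tag_to_idx[tag] for (tokenized, tag) in xy])
```

With that mapping in place, `output_size` in the model setup could also be `len(tags)` rather than the hardcoded 3.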
nn.py (23 changes: 23 additions & 0 deletions)
@@ -0,0 +1,23 @@
+import torch
+import torch.nn as nn
+
+from pipeop import pipes
+
+
+class NeuralNet(nn.Module):
+
+    def __init__(self, input_size, hidden_size, output_size):
+        super(NeuralNet, self).__init__()
+        self.l1 = nn.Linear(input_size, hidden_size)
+        self.l2 = nn.Linear(hidden_size, hidden_size)
+        self.l3 = nn.Linear(hidden_size, output_size)
+        self.relu = nn.ReLU()
+
+    @pipes
+    def forward(self, x):
+        return (
+            x
+            >> self.l1 >> self.relu
+            >> self.l2 >> self.relu
+            >> self.l3
+        )
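For readers unfamiliar with `pipeop`: the `@pipes` decorator rewrites the decorated function's AST so that `a >> f` becomes `f(a)`. Under that reading, the `forward` above should be equivalent to this plain PyTorch version (a restatement, not code from the commit):

```python
def forward(self, x):
    # same computation with ordinary calls instead of the pipe operator
    out = self.relu(self.l1(x))
    out = self.relu(self.l2(out))
    return self.l3(out)
```

Leaving the final layer without a softmax is correct here, since the `nn.CrossEntropyLoss` used in main.py expects raw logits and applies log-softmax internally.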
requirements.txt (3 changes: 3 additions & 0 deletions)
@@ -1,6 +1,9 @@
click==8.1.3
joblib==1.1.0
nltk==3.7
numpy==1.22.4
pipeop==0.3.0
regex==2022.6.2
torch==1.11.0
tqdm==4.64.0
typing_extensions==4.2.0
