Added MNIST and Entity-Relation extraction tests
Currently works only on a subset of solvers. Closes #4
KareemYousrii committed Sep 3, 2020
1 parent 93d1500 commit 8bc32ec
Showing 2 changed files with 328 additions and 0 deletions.
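
For orientation, a minimal sketch (not part of this commit) of the constraint API these tests exercise. It assumes only names already used in the diffs below (constraint, WeightedSamplingSolver, and the only_one predicate from tests/test_mnist.py), and uses the compact form of the predicate that the test file leaves commented out:

import torch
from pytorch_constraints.constraint import constraint
from pytorch_constraints.sampling_solver import WeightedSamplingSolver

# a plain Python predicate over one decoded assignment of the outputs
def only_one(x):
    return sum(x) == 1

# a solver compiles the predicate into a differentiable penalty over logits
only_one_loss = constraint(only_one, WeightedSamplingSolver(num_samples=200))

logits = torch.randn(10, 2, requires_grad=True)  # 10 binary variables, 2 classes each
loss = only_one_loss(logits)                     # penalty encouraging the constraint to hold
loss.backward()                                  # gradients flow back to the logits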
129 changes: 129 additions & 0 deletions tests/test_entity_relation.py
@@ -0,0 +1,129 @@
import torch
import torch.nn.functional as F

from pytorch_constraints.constraint import constraint
from pytorch_constraints.sampling_solver import *

ENTITY_TO_ID = {"O": 0, "Loc": 1, "Org": 2, "Peop": 3, "Other": 4}
REL_TO_ID = {"*": 0, "Work_For_arg1": 1, "Kill_arg1": 2, "OrgBased_In_arg1": 3, "Live_In_arg1": 4,
             "Located_In_arg1": 5, "Work_For_arg2": 6, "Kill_arg2": 7, "OrgBased_In_arg2": 8,
             "Live_In_arg2": 9, "Located_In_arg2": 10}

def get_solvers(num_samples):
    return [WeightedSamplingSolver(num_samples)]

class NER_Net(torch.nn.Module):
    '''Simple Named Entity Recognition model'''

    def __init__(self, vocab_size, num_classes, hidden_dim=50, embedding_dim=100):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        # layers
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embedding_dim)
        #self.embedding.weight = torch.nn.Parameter(vocab.vectors)
        self.embedding.weight.data.uniform_(-1.0, 1.0)

        self.lstm = torch.nn.LSTM(self.embedding_dim, self.hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(self.hidden_dim, num_classes)

        # Initialize fully connected layer
        self.fc.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=1)

    def forward(self, s):
        s = self.embedding(s)  # dim: batch_size x batch_max_len x embedding_dim
        s, _ = self.lstm(s)    # dim: batch_size x batch_max_len x lstm_hidden_dim
        s = self.fc(s)         # dim: batch_size x batch_max_len x num_classes

        return s


class RE_Net(torch.nn.Module):
    '''Simple Relation extraction model'''

    def __init__(self, vocab_size, num_classes, hidden_dim=50, embedding_dim=100):
        super().__init__()

        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        # layers
        self.embedding = torch.nn.Embedding(self.vocab_size, self.embedding_dim)
        #self.embedding.weight = torch.nn.Parameter(vocab.vectors)
        self.embedding.weight.data.uniform_(-1.0, 1.0)

        self.lstm = torch.nn.LSTM(self.embedding_dim, self.hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(self.hidden_dim, num_classes)

        # Initialize fully connected layer
        self.fc.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.fc.weight, gain=1)

    def forward(self, s):
        s = self.embedding(s)  # dim: batch_size x batch_max_len x embedding_dim
        s, _ = self.lstm(s)    # dim: batch_size x batch_max_len x lstm_hidden_dim
        s = self.fc(s)         # dim: batch_size x batch_max_len x num_classes

        return s

def OrgBasedIn_Org_Loc(ne, re):
    # OrgBased_In(arg1, arg2) requires arg1 to be tagged Org and arg2 to be tagged Loc
    arg1 = (re == 3).nonzero(as_tuple=False)  # OrgBased_In_arg1
    arg2 = (re == 8).nonzero(as_tuple=False)  # OrgBased_In_arg2

    return all(ne[arg1] == 2) and all(ne[arg2] == 1)  # Org, Loc

def train(constraint):

    ner = NER_Net(vocab_size=3027, num_classes=len(ENTITY_TO_ID))
    re = RE_Net(vocab_size=3027, num_classes=len(REL_TO_ID))

    opt = torch.optim.SGD(list(ner.parameters()) + list(re.parameters()), lr=1.0)

    tokens, entities, relations = get_data()

    for i in range(100):
        opt.zero_grad()

        ner_logits = ner(tokens)
        ner_logits = ner_logits.view(-1, ner_logits.shape[2])

        re_logits = re(tokens)
        re_logits = re_logits.view(-1, re_logits.shape[2])

        # the NER net receives no direct supervision; it is trained only through the constraint
        re_loss = F.cross_entropy(re_logits, relations.view(-1))
        closs = constraint(ner_logits, re_logits)
        loss = 0.05 * closs + 10 * re_loss

        loss.backward()
        opt.step()

    return ner, re


def test_entity_relation():

    tokens, entities, relations = get_data()
    for solver in get_solvers(num_samples=200):

        cons = constraint(OrgBasedIn_Org_Loc, solver)
        ner, re = train(cons)

        re = torch.argmax(torch.softmax(re(tokens).view(-1, 11), dim=-1), dim=-1)
        ner = torch.argmax(torch.softmax(ner(tokens).view(-1, 5), dim=-1), dim=-1)

        assert (ner[re == 3] == 2).all() and (ner[re == 8] == 1).all()

def get_data():

    tokens = torch.tensor([[32, 1973, 2272, 15, 3, 0, 0, 5, 0, 389, 0, 12,
                            7, 823, 4, 2636, 4, 0, 114, 5, 3, 2701, 6]])
    entities = torch.LongTensor([0, 0, 0, 0, 0, 0, 2, 0, 4, 0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0])
    relations = torch.LongTensor([0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    return tokens, entities, relations
199 changes: 199 additions & 0 deletions tests/test_mnist.py
@@ -0,0 +1,199 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pytorch_constraints.constraint import constraint
from pytorch_constraints.tnorm_solver import *
from pytorch_constraints.sampling_solver import WeightedSamplingSolver
from pytorch_constraints.circuit_solver import SemanticLossCircuitSolver

def get_solvers(num_samples):
    return [WeightedSamplingSolver(num_samples), SemanticLossCircuitSolver(), ProductTNormLogicSolver()]

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 20)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)

        return x.reshape(10, 2)

def train(constraint=None, epoch=100):

    net = Net()
    X, y = get_mnist_data()
    optimizer = optim.Adadelta(net.parameters(), lr=1.0)

    for i in range(epoch):

        optimizer.zero_grad()

        output = net(X)
        loss = F.cross_entropy(output[:, 1].reshape(1, 10), y)

        if constraint:
            loss += constraint(output)

        loss.backward()
        optimizer.step()

    return net

def only_one(x):
    # holds when exactly one of the ten binary outputs is 1 (a one-hot assignment)
    #return sum(x) == 1
    return (x[0] == 1 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 1 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 1 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 1 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 1 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 1 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 1 and x[7] == 0 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 1 and x[8] == 0 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 1 and x[9] == 0) or\
           (x[0] == 0 and x[1] == 0 and x[2] == 0 and x[3] == 0 and x[4] == 0 and x[5] == 0 and x[6] == 0 and x[7] == 0 and x[8] == 0 and x[9] == 1)

def test_only_one_mnist():

    X, y = get_mnist_data()

    for solver in get_solvers(num_samples=200):
        only_one_constraint = constraint(only_one, solver)
        net = train(only_one_constraint)

        assert torch.argmax(torch.softmax(net(X), dim=-1), dim=-1).sum().item() == 1

def get_mnist_data():
    X = torch.tensor([[[
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, 0.2249, 1.5996, 2.7960, 1.5996, 0.2122, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
0.1867, 2.6051, 2.7833, 2.7833, 2.7833, 2.5924, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.2631,
2.4651, 2.7960, 2.7833, 2.6178, 2.5415, 2.7833, 0.3013,
-0.3478, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.2969, 0.3395, 2.4269,
2.7833, 2.7960, 2.7833, 2.1469, 0.6450, 2.7833, 2.7960,
1.1286, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, 1.6505, 2.7833, 2.7833,
2.7833, 2.7960, 2.7833, 2.7833, 0.7977, 1.9814, 2.7960,
1.7014, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, 0.2249, 2.6051, 2.7960, 2.7960,
1.9942, 1.0268, 2.7960, 2.4778, 0.1740, 0.5813, 2.8215,
1.7141, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, 0.1867, 2.6051, 2.7833, 2.7833, 1.8541,
-0.2715, 0.5304, 1.1159, -0.1569, -0.4242, -0.4242, 2.7960,
2.6687, 0.2122, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, 0.0595, 1.6759, 2.7960, 2.5415, 2.2233, 0.6450,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 2.7960,
2.7833, 1.6759, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.3351, 1.8414, 2.7833, 2.6306, 0.4795, -0.1824, -0.0678,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 2.7960,
2.7833, 2.0578, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
0.3013, 2.7833, 2.7833, 0.3777, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 2.7960,
2.7833, 2.0578, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
2.0960, 2.7960, 1.9942, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 2.8215,
2.7960, 2.0705, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.5431,
2.7069, 2.7833, 1.0013, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 2.7960,
2.7833, 1.4596, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6577,
2.7833, 2.5033, -0.1060, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.3351, 1.2941, 2.7960,
1.9432, -0.2715, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6577,
2.7833, 2.4142, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.3351, 1.2432, 2.7833, 2.4396,
0.4795, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6577,
2.7833, 1.4214, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, 0.1867, 1.6759, 2.7833, 1.7778, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6704,
2.7960, 2.4396, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, 1.0268, 2.6051, 2.7960, 1.6378, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6577,
2.7833, 2.7451, 1.4341, 0.1867, -0.0551, 0.6577, 1.8414,
2.4396, 2.7960, 2.4142, 1.7014, 0.2886, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, 0.6577,
2.7833, 2.7833, 2.7833, 2.4906, 2.3124, 2.7833, 2.7833,
2.7833, 2.0705, 1.2305, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.0678,
2.1087, 2.7833, 2.7833, 2.7960, 2.7833, 2.7833, 2.5415,
1.4214, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.1060, 1.2050, 2.7833, 2.7960, 2.7833, 1.3705, 0.0467,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]])
    y = torch.tensor([0])

    return X, y
