# SNLI_data.py — SNLI dataset loaders.
# (Removed GitHub file-view scrape residue: navigation text and the gutter
# line numbers 1-113 from the original listing, "113 lines (88 loc) | 3.68 KB".)
import torch
import numpy as np
from torch.utils.data import Dataset
import os.path as osp
import os
class Data:
    """Bare attribute container: callers assign arbitrary attributes
    (e.g. shard arrays and counters) onto instances at runtime."""
# Project root: the current working directory at import time. Data paths are
# resolved relative to it. NOTE(review): assumes the script is launched from
# the repository root — confirm against how training is invoked.
ROOT=os.getcwd()
# Number of .npy shard files per split (used as the loop bound when loading
# BERT shards in SNLI_Transformer.__init__): test/dev have 1 shard, train 2.
BERT = Data()
BERT.test = 1
BERT.dev = 1
BERT.train = 2
class SNLI_dataset(Dataset):
    """SNLI examples read from two memory-mapped .npy files of one split.

    Rows in each array alternate text/hypothesis: row 2k is a text and
    row 2k+1 its hypothesis. Contradiction pairs come first (label 1.0),
    entailment pairs after (label 0.0).

    Parameters
    ----------
    split : str
        Split directory name under ``<ROOT>/data`` (e.g. train/dev/test).
    """

    def __init__(self, split):
        super().__init__()
        base = osp.join(ROOT, 'data', split)
        # Memory-map so the (potentially large) arrays stay on disk.
        self.contra = np.load(osp.join(base, 'contradiction.npy'), mmap_mode='r')
        self.entail = np.load(osp.join(base, 'entailment.npy'), mmap_mode='r')

    def __len__(self):
        # Two consecutive rows make one (text, hypothesis) example.
        return (len(self.contra) + len(self.entail)) // 2

    def __getitem__(self, val):
        """Return (text, hypothesis, label) tensors for example *val*."""
        row = 2 * val
        if row < len(self.contra):
            source, label = self.contra, 1.0
        else:
            source, label = self.entail, 0.0
            row -= len(self.contra)
        text = torch.from_numpy(source[row])
        hyp = torch.from_numpy(source[row + 1])
        y = torch.tensor(label, dtype=torch.float32)
        return text, hyp, y
class SNLI_Transformer(Dataset):
    """SNLI examples as pre-computed BERT embeddings, sharded over .npy files.

    Each split has ``BERT.<split>`` shards per class under
    ``<ROOT>/data_8k/<split>``: ``bert_{i}_<class>.npy`` holds embedding rows
    (alternating text/hypothesis, like SNLI_dataset) and
    ``sent_bert_{i}_<class>.npy`` the matching sentence lengths.
    Contradiction pairs are labelled 1.0, entailment pairs 0.0.

    Fixes over the original: the magic shard size 50000 is hoisted to one
    class attribute, and the duplicated load/branch code is factored into
    helpers (behavior unchanged).
    """

    # Rows per shard: global row i lives in shard i // CHUNK at offset
    # i % CHUNK. NOTE(review): assumes every shard except possibly the last
    # holds exactly 50000 rows — confirm against the preprocessing script.
    CHUNK = 50000

    def __init__(self, split):
        super().__init__()
        self.contra = Data()
        self.entail = Data()
        self.contra.len = 0
        self.entail.len = 0
        base = "{}/{}/{}".format(ROOT, 'data_8k', split)
        for i in range(getattr(BERT, split)):
            self.contra.len += self._load_shard(
                self.contra, i,
                "{}/bert_{}_contradiction.npy".format(base, i),
                "{}/sent_bert_{}_contradiction.npy".format(base, i))
            self.entail.len += self._load_shard(
                self.entail, i,
                "{}/bert_{}_entailment.npy".format(base, i),
                "{}/sent_bert_{}_entailment.npy".format(base, i))

    @staticmethod
    def _load_shard(store, i, emb_path, len_path):
        """Memory-map one shard pair onto *store* as ``attr_{i}``/``sent_{i}``;
        return the shard's row count."""
        setattr(store, "attr_{}".format(i), np.load(emb_path, mmap_mode='r'))
        setattr(store, "sent_{}".format(i), np.load(len_path, mmap_mode='r'))
        return len(getattr(store, "attr_{}".format(i)))

    def __len__(self):
        # Two consecutive rows make one (text, hypothesis) example.
        return ((self.contra.len) + (self.entail.len)) // 2

    def get_emb(self, object, index):
        """Return embedding row *index* of the sharded store *object* as a
        tensor. (Parameter name ``object`` kept for caller compatibility,
        although it shadows the builtin.)"""
        shard, offset = divmod(index, self.CHUNK)
        return torch.from_numpy(getattr(object, "attr_{}".format(shard))[offset])

    def get_length(self, object, index):
        """Return the sentence length at row *index* as a 0-d tensor."""
        shard, offset = divmod(index, self.CHUNK)
        value = getattr(object, "sent_{}".format(shard))[offset]
        return torch.from_numpy(np.array(value))

    def _example(self, store, row, label):
        """Assemble one (text, hyp, label, text_len, hyp_len) example from
        rows *row* and *row + 1* of *store*."""
        text = self.get_emb(store, row)
        hyp = self.get_emb(store, row + 1)
        # Lengths are clamped into [1, 32] (32 presumably the model's max
        # sequence length — TODO confirm).
        text_len = torch.clamp(self.get_length(store, row), 1, 32)
        hyp_len = torch.clamp(self.get_length(store, row + 1), 1, 32)
        y = torch.tensor(label, dtype=torch.float32)
        return text, hyp, y, text_len, hyp_len

    def __getitem__(self, val):
        row = 2 * val
        if row < (self.contra.len):
            return self._example(self.contra, row, 1.0)
        return self._example(self.entail, row - self.contra.len, 0.0)