-
Notifications
You must be signed in to change notification settings - Fork 129
/
Copy pathreTranE.py
23 lines (21 loc) · 859 Bytes
/
reTranE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from tranE import *
def loadData(str):
    """Load a name -> vector mapping from a tab-separated embedding file.

    Each non-blank line is expected to look like ``name<TAB>[v1, v2, ...]``:
    the first tab-separated field is the key and the second is a bracketed,
    comma-separated list of numbers. Any fields after the second are
    ignored. Blank lines are skipped. If a name appears on more than one
    line, the last occurrence wins.

    Args:
        str: Path of the file to read. (The name shadows the builtin
            ``str``; it is kept so existing keyword callers keep working.)

    Returns:
        dict mapping each name (a string) to its vector as a list of floats.

    Raises:
        ValueError: if a non-blank line has no tab-separated vector field,
            or a vector component is not a valid number.
    """
    path = str  # alias so the body does not read like the builtin ``str``
    dic = {}
    # ``with`` guarantees the handle is closed, even if parsing raises.
    with open(path) as fr:
        for lineNo, rawLine in enumerate(fr, 1):
            line = rawLine.strip()
            if not line:
                continue  # e.g. a trailing empty line at end of file
            fields = line.split("\t")
            if len(fields) < 2:
                raise ValueError(
                    "%s:%d: expected 'name<TAB>[vector]', got %r"
                    % (path, lineNo, rawLine)
                )
            name, vecText = fields[0], fields[1]
            # Drop the surrounding brackets, then split on commas. float()
            # tolerates the spaces left by a ", " separator, and the filter
            # lets an empty vector "[]" parse to [] instead of failing.
            inner = vecText.strip().strip("[]")
            dic[name] = [float(s) for s in inner.split(",") if s.strip()]
    return dic
if __name__ == '__main__':
    # Re-train: embeddings are loaded from the vector files (presumably
    # written by an earlier training run), trained further, and then written
    # back to the SAME files. Each path is named once so the read and write
    # locations cannot drift apart.
    dirEntityVector = r"c:\entityVector.txt"
    dirRelationVector = r"c:\relationVector.txt"
    dirTrain = r"C:\data\train.txt"

    entityList = loadData(dirEntityVector)      # dict: entity name -> list of floats
    relationList = loadData(dirRelationVector)  # dict: relation name -> list of floats
    # openTrain returns a pair; only the second element (the triples) is used.
    # The first is presumably the triple count - confirm in tranE.
    _, tripleList = openTrain(dirTrain)

    # Keyword spellings (e.g. ``learingRate``) must match TransE's signature
    # in tranE, so they are kept verbatim.
    transE = TransE(entityList, relationList, tripleList, learingRate = 0.001, dim = 30)
    transE.transE(100000)  # presumably the number of training steps - confirm in tranE
    # NOTE: overwrites the input vector files loaded above.
    transE.writeRelationVector(dirRelationVector)
    transE.writeEntilyVector(dirEntityVector)