Skip to content

Commit

Permalink
Merge pull request #24 from anilavakundu/openKE-patch
Browse files Browse the repository at this point in the history
Upgrading OpenKE TensorFlow dependency
  • Loading branch information
svkeerthy authored May 23, 2021
2 parents 84462ee + e26fab8 commit 712a116
Show file tree
Hide file tree
Showing 20 changed files with 858 additions and 449 deletions.
56 changes: 40 additions & 16 deletions seed_embeddings/OpenKE/base/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "Reader.h"
#include "Setting.h"
#include "Test.h"
#include "Valid.h"
#include <cstdlib>
#include <pthread.h>

Expand All @@ -14,6 +15,10 @@ extern "C" void setWorkThreads(INT threads);

extern "C" void setBern(INT con);

extern "C" void setHeadTailCrossSampling(INT con);

extern "C" bool judgeHeadBatch();

extern "C" INT getWorkThreads();

extern "C" INT getEntityTotal();
Expand Down Expand Up @@ -41,6 +46,7 @@ struct Parameter {
INT batchSize;
INT negRate;
INT negRelRate;
INT headBatchFlag;
};

void *getBatch(void *con) {
Expand All @@ -53,6 +59,7 @@ void *getBatch(void *con) {
INT batchSize = para->batchSize;
INT negRate = para->negRate;
INT negRelRate = para->negRelRate;
INT headBatchFlag = para->headBatchFlag;
INT lef, rig;
if (batchSize % workThreads == 0) {
lef = id * (batchSize / workThreads);
Expand All @@ -72,23 +79,39 @@ void *getBatch(void *con) {
batch_y[batch] = 1;
INT last = batchSize;
for (INT times = 0; times < negRate; times++) {
if (bernFlag)
prob = 1000 * right_mean[trainList[i].r] /
(right_mean[trainList[i].r] + left_mean[trainList[i].r]);
if (randd(id) % 1000 < prob) {
batch_h[batch + last] = trainList[i].h;
batch_t[batch + last] =
corrupt_head(id, trainList[i].h, trainList[i].r);
batch_r[batch + last] = trainList[i].r;
if (!crossSamplingFlag) {
if (bernFlag)
prob = 1000 * right_mean[trainList[i].r] /
(right_mean[trainList[i].r] + left_mean[trainList[i].r]);
if (randd(id) % 1000 < prob) {
batch_h[batch + last] = trainList[i].h;
batch_t[batch + last] =
corrupt_head(id, trainList[i].h, trainList[i].r);
batch_r[batch + last] = trainList[i].r;
} else {
batch_h[batch + last] =
corrupt_tail(id, trainList[i].t, trainList[i].r);
;
batch_t[batch + last] = trainList[i].t;
batch_r[batch + last] = trainList[i].r;
}
batch_y[batch + last] = -1;
last += batchSize;
} else {
batch_h[batch + last] =
corrupt_tail(id, trainList[i].t, trainList[i].r);
;
batch_t[batch + last] = trainList[i].t;
batch_r[batch + last] = trainList[i].r;
if (headBatchFlag) {
batch_h[batch + last] =
corrupt_tail(id, trainList[i].t, trainList[i].r);
batch_t[batch + last] = trainList[i].t;
batch_r[batch + last] = trainList[i].r;
} else {
batch_h[batch + last] = trainList[i].h;
batch_t[batch + last] =
corrupt_head(id, trainList[i].h, trainList[i].r);
batch_r[batch + last] = trainList[i].r;
}
batch_y[batch + last] = -1;
last += batchSize;
}
batch_y[batch + last] = -1;
last += batchSize;
}
for (INT times = 0; times < negRelRate; times++) {
batch_h[batch + last] = trainList[i].h;
Expand All @@ -103,7 +126,7 @@ void *getBatch(void *con) {

extern "C" void sampling(INT *batch_h, INT *batch_t, INT *batch_r,
REAL *batch_y, INT batchSize, INT negRate = 1,
INT negRelRate = 0) {
INT negRelRate = 0, INT headBatchFlag = 0) {
pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t));
Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter));
for (INT threads = 0; threads < workThreads; threads++) {
Expand All @@ -115,6 +138,7 @@ extern "C" void sampling(INT *batch_h, INT *batch_t, INT *batch_r,
para[threads].batchSize = batchSize;
para[threads].negRate = negRate;
para[threads].negRelRate = negRelRate;
para[threads].headBatchFlag = headBatchFlag;
pthread_create(&pt[threads], NULL, getBatch, (void *)(para + threads));
}
for (INT threads = 0; threads < workThreads; threads++)
Expand Down
1 change: 1 addition & 0 deletions seed_embeddings/OpenKE/base/Reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ INT *testLef, *testRig;
INT *validLef, *validRig;

extern "C" void importTrainFiles() {

printf("The toolkit is importing datasets.\n");
FILE *fin;
int tmp;
Expand Down
3 changes: 3 additions & 0 deletions seed_embeddings/OpenKE/base/Setting.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ extern "C" INT getValidTotal() { return validTotal; }
*/

INT bernFlag = 0;
INT crossSamplingFlag = 0;

extern "C" void setBern(INT con) { bernFlag = con; }

extern "C" void setHeadTailCrossSampling(INT con) { crossSamplingFlag = con; }

#endif
46 changes: 22 additions & 24 deletions seed_embeddings/OpenKE/base/Test.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,27 @@ REAL l3_filter_tot_constrain = 0, l3_tot_constrain = 0, r3_tot_constrain = 0,
r_filter_tot_constrain = 0, r_filter_rank_constrain = 0,
r_rank_constrain = 0, r_filter_reci_rank_constrain = 0,
r_reci_rank_constrain = 0;

extern "C" void initTest() {
lastHead = 0;
lastTail = 0;
l1_filter_tot = 0, l1_tot = 0, r1_tot = 0, r1_filter_tot = 0, l_tot = 0,
r_tot = 0, l_filter_rank = 0, l_rank = 0, l_filter_reci_rank = 0,
l_reci_rank = 0;
l3_filter_tot = 0, l3_tot = 0, r3_tot = 0, r3_filter_tot = 0,
l_filter_tot = 0, r_filter_tot = 0, r_filter_rank = 0, r_rank = 0,
r_filter_reci_rank = 0, r_reci_rank = 0;

l1_filter_tot_constrain = 0, l1_tot_constrain = 0, r1_tot_constrain = 0,
r1_filter_tot_constrain = 0, l_tot_constrain = 0, r_tot_constrain = 0,
l_filter_rank_constrain = 0, l_rank_constrain = 0,
l_filter_reci_rank_constrain = 0, l_reci_rank_constrain = 0;
l3_filter_tot_constrain = 0, l3_tot_constrain = 0, r3_tot_constrain = 0,
r3_filter_tot_constrain = 0, l_filter_tot_constrain = 0,
r_filter_tot_constrain = 0, r_filter_rank_constrain = 0, r_rank_constrain = 0,
r_filter_reci_rank_constrain = 0, r_reci_rank_constrain = 0;
}

extern "C" void getHeadBatch(INT *ph, INT *pt, INT *pr) {
for (INT i = 0; i < entityTotal; i++) {
ph[i] = i;
Expand Down Expand Up @@ -111,10 +132,6 @@ extern "C" void testHead(REAL *con) {
l_reci_rank_constrain += 1.0 / (l_s_constrain + 1);

lastHead++;

printf("l_filter_s: %ld\n", l_filter_s);
printf("%f %f %f %f \n", l_tot / lastHead, l_filter_tot / lastHead,
l_rank / lastHead, l_filter_rank / lastHead);
}

extern "C" void testTail(REAL *con) {
Expand Down Expand Up @@ -185,9 +202,6 @@ extern "C" void testTail(REAL *con) {
r_reci_rank_constrain += 1.0 / (1 + r_s_constrain);

lastTail++;
printf("r_filter_s: %ld\n", r_filter_s);
printf("%f %f %f %f\n", r_tot / lastTail, r_filter_tot / lastTail,
r_rank / lastTail, r_filter_rank / lastTail);
}

extern "C" void test_link_prediction() {
Expand Down Expand Up @@ -307,14 +321,6 @@ extern "C" void getNegTest() {
negTestList[i] = testList[i];
negTestList[i].t = corrupt(testList[i].h, testList[i].r);
}
FILE *fout = fopen((inPath + "test_neg.txt").c_str(), "w");
for (INT i = 0; i < testTotal; i++) {
fprintf(fout, "%ld\t%ld\t%ld\t%ld\n", testList[i].h, testList[i].t,
testList[i].r, INT(1));
fprintf(fout, "%ld\t%ld\t%ld\t%ld\n", negTestList[i].h, negTestList[i].t,
negTestList[i].r, INT(-1));
}
fclose(fout);
}

Triple *negValidList;
Expand All @@ -324,14 +330,6 @@ extern "C" void getNegValid() {
negValidList[i] = validList[i];
negValidList[i].t = corrupt(validList[i].h, validList[i].r);
}
FILE *fout = fopen((inPath + "valid_neg.txt").c_str(), "w");
for (INT i = 0; i < validTotal; i++) {
fprintf(fout, "%ld\t%ld\t%ld\t%ld\n", validList[i].h, validList[i].t,
validList[i].r, INT(1));
fprintf(fout, "%ld\t%ld\t%ld\t%ld\n", negValidList[i].h, negValidList[i].t,
negValidList[i].r, INT(-1));
}
fclose(fout);
}

extern "C" void getTestBatch(INT *ph, INT *pt, INT *pr, INT *nh, INT *nt,
Expand Down Expand Up @@ -359,7 +357,7 @@ extern "C" void getValidBatch(INT *ph, INT *pt, INT *pr, INT *nh, INT *nt,
nr[i] = negValidList[i].r;
}
}
// REAL* relThresh;

REAL threshEntire;
extern "C" void getBestThreshold(REAL *relThresh, REAL *score_pos,
REAL *score_neg) {
Expand Down
78 changes: 78 additions & 0 deletions seed_embeddings/OpenKE/base/Valid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#ifndef VALID_H
#define VALID_H
#include "Corrupt.h"
#include "Reader.h"
#include "Setting.h"

INT lastValidHead = 0;
INT lastValidTail = 0;

REAL l_valid_filter_tot = 0;
REAL r_valid_filter_tot = 0;

extern "C" void validInit() {
lastValidHead = 0;
lastValidTail = 0;
l_valid_filter_tot = 0;
r_valid_filter_tot = 0;
}

extern "C" void getValidHeadBatch(INT *ph, INT *pt, INT *pr) {
for (INT i = 0; i < entityTotal; i++) {
ph[i] = i;
pt[i] = validList[lastValidHead].t;
pr[i] = validList[lastValidHead].r;
}
}

extern "C" void getValidTailBatch(INT *ph, INT *pt, INT *pr) {
for (INT i = 0; i < entityTotal; i++) {
ph[i] = validList[lastValidTail].h;
pt[i] = i;
pr[i] = validList[lastValidTail].r;
}
}

extern "C" void validHead(REAL *con) {
INT h = validList[lastValidHead].h;
INT t = validList[lastValidHead].t;
INT r = validList[lastValidHead].r;
REAL minimal = con[h];
INT l_filter_s = 0;
for (INT j = 0; j < entityTotal; j++) {
if (j != h) {
REAL value = con[j];
if (value < minimal && !_find(j, t, r))
l_filter_s += 1;
}
}
if (l_filter_s < 10)
l_valid_filter_tot += 1;
lastValidHead++;
}

extern "C" void validTail(REAL *con) {
INT h = validList[lastValidTail].h;
INT t = validList[lastValidTail].t;
INT r = validList[lastValidTail].r;
REAL minimal = con[t];
INT r_filter_s = 0;
for (INT j = 0; j < entityTotal; j++) {
if (j != t) {
REAL value = con[j];
if (value < minimal && !_find(h, j, r))
r_filter_s += 1;
}
}
if (r_filter_s < 10)
r_valid_filter_tot += 1;
lastValidTail++;
}

extern "C" REAL getValidHit10() {
return (l_valid_filter_tot / validTotal + r_valid_filter_tot / validTotal) /
2;
;
}

#endif
Loading

0 comments on commit 712a116

Please sign in to comment.