I encountered a KeyError when running the data_redundancy.duplicates function on the WN18 dataset. The code works correctly with other datasets (e.g., FB13), but fails with the same error on WN18 and on my own KG.
Here is the code that triggers the error:
from torchkge.utils.datasets import load_wn18
from torchkge.utils import data_redundancy
kg_train, kg_val, kg_test = load_wn18()
dups, reverse_dups = data_redundancy.duplicates(kg_train, kg_val, kg_test)
And the error:
{
"name": "KeyError",
"message": "18",
"stack": "---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[5], line 5
2 from torchkge.utils import data_redundancy
4 kg_train, kg_val, kg_test = load_wn18()
----> 5 dups, reverse_dups = data_redundancy.duplicates(kg_train, kg_val, kg_test)
File ~/anaconda3/envs/benchmark/lib/python3.9/site-packages/torchkge/utils/data_redundancy.py:150, in duplicates(kg_tr, kg_val, kg_te, theta1, theta2, verbose, counts, reverses)
147 iter_ = list(combinations(range(1345), 2))
149 for r1, r2 in tqdm(iter_):
--> 150 a = len(T[r1].intersection(T[r2])) / lengths[r1]
151 b = len(T[r1].intersection(T[r2])) / lengths[r2]
153 if a > theta1 and b > theta2:
KeyError: 18"
}
Potential solution
To my understanding, this can be fixed by modifying:
iter_ = list(combinations(range(1345), 2))
to
iter_ = list(combinations(range(kg_tr.n_rel), 2))
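To illustrate why the hard-coded bound would produce exactly this error, here is a minimal standalone sketch. It does not use torchkge; T, lengths and first_missing_key are hypothetical stand-ins for the per-relation lookups seen in the traceback, and I am assuming they hold one entry per relation index. WN18 has 18 relations (indices 0 to 17), which would match the KeyError: 18.

```python
from itertools import combinations

# Hypothetical stand-in for the per-relation sets built inside duplicates():
# one entry per relation index, with 18 relations as in WN18.
n_rel = 18
T = {r: {(r, r + 1)} for r in range(n_rel)}
lengths = {r: len(T[r]) for r in T}


def first_missing_key(n_iter):
    """Return the first relation index that raises KeyError, or None."""
    for r1, r2 in combinations(range(n_iter), 2):
        try:
            _ = len(T[r1].intersection(T[r2])) / lengths[r1]
        except KeyError as err:
            return err.args[0]
    return None


print(first_missing_key(1345))   # hard-coded bound from the traceback
print(first_missing_key(n_rel))  # bound taken from the data
```

With the hard-coded bound, the loop reaches the pair (0, 18) and fails on the missing index; with the bound taken from the number of relations, every pair stays inside the available indices.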