KeyError when running data_redundancy.duplicates on WN18 dataset #265

Open
galadrielbriere opened this issue Aug 29, 2024 · 0 comments
galadrielbriere commented Aug 29, 2024

  • TorchKGE version: 0.17.7
  • Python version: 3.9.19
  • Operating System: Ubuntu

Description

I encountered a KeyError when running the data_redundancy.duplicates function on the WN18 dataset. The same call works with other datasets (e.g., FB13) but fails with WN18 and with my own KG, raising the same error in both cases.

Here is the code that triggers the error:

from torchkge.utils.datasets import load_wn18
from torchkge.utils import data_redundancy

kg_train, kg_val, kg_test = load_wn18()
dups, reverse_dups = data_redundancy.duplicates(kg_train, kg_val, kg_test)

And the error:

{
	"name": "KeyError",
	"message": "18",
	"stack": "---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[5], line 5
      2 from  torchkge.utils import data_redundancy
      4 kg_train, kg_val, kg_test = load_wn18()
----> 5 dups, reverse_dups = data_redundancy.duplicates(kg_train, kg_val, kg_test)

File ~/anaconda3/envs/benchmark/lib/python3.9/site-packages/torchkge/utils/data_redundancy.py:150, in duplicates(kg_tr, kg_val, kg_te, theta1, theta2, verbose, counts, reverses)
    147 iter_ = list(combinations(range(1345), 2))
    149 for r1, r2 in tqdm(iter_):
--> 150     a = len(T[r1].intersection(T[r2])) / lengths[r1]
    151     b = len(T[r1].intersection(T[r2])) / lengths[r2]
    153     if a > theta1 and b > theta2:

KeyError: 18"
}

Potential solution

As far as I can tell, the relation count is hard-coded in data_redundancy.py, and this can be fixed by replacing:

iter_ = list(combinations(range(1345), 2))
with
iter_ = list(combinations(range(kg_tr.n_rel), 2))
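
To illustrate why the hard-coded bound would fail, here is a minimal sketch that does not use TorchKGE at all. It assumes that T and lengths are plain dicts keyed by relation index 0..n_rel-1, which is my reading of the traceback and not something I verified against the library source. The helper pairwise_overlaps and the toy data are hypothetical stand-ins for the loop in duplicates. With 18 relations (as in WN18), the first pair that goes out of range would be (0, 18), which would match the "KeyError: 18" above.

```python
from itertools import combinations


def pairwise_overlaps(T, lengths, n_rel):
    """Hypothetical stand-in for the loop in duplicates().

    T maps a relation index to its set of (head, tail) pairs and
    lengths maps a relation index to its number of facts (assumed layout).
    """
    overlaps = {}
    for r1, r2 in combinations(range(n_rel), 2):
        shared = len(T[r1].intersection(T[r2]))
        overlaps[(r1, r2)] = (shared / lengths[r1], shared / lengths[r2])
    return overlaps


# Toy graph with 3 relations; relations 0 and 1 are exact duplicates.
T = {0: {(1, 2), (3, 4)}, 1: {(1, 2), (3, 4)}, 2: {(5, 6)}}
lengths = {r: len(pairs) for r, pairs in T.items()}

# A bound larger than the number of relations indexes a missing key.
try:
    pairwise_overlaps(T, lengths, 1345)
    failed_key = None
except KeyError as err:
    failed_key = err.args[0]

# Using the actual number of relations stays within the dict keys.
ok = pairwise_overlaps(T, lengths, len(T))
```

In this toy case the failing key is the first index past the last relation, and the bounded version returns one entry per pair of relations.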
