Skip to content

Commit

Permalink
Merge pull request #120 from cthoyt/return-first
Browse files (browse the repository at this point in the history)
Enable returning all NER results
  • Loading branch information
bgyori authored Jul 26, 2023
2 parents 125ab65 + 40e1936 commit ddea0cc
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
6 changes: 5 additions & 1 deletion gilda/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def annotate(
sent_split_fun=None,
organisms=None,
namespaces=None,
return_first: bool = True,
):
"""Annotate a given text with Gilda (i.e., do named entity recognition).
Expand All @@ -130,6 +131,8 @@ def annotate(
namespaces : list[str], optional
A list of namespaces to pass to the grounder to restrict the matches
to. By default, no restriction is applied.
return_first:
If true, only returns the first result. Otherwise, returns all results.
Returns
-------
Expand All @@ -145,7 +148,8 @@ def annotate(
grounder=grounder,
sent_split_fun=sent_split_fun,
organisms=organisms,
namespaces=namespaces
namespaces=namespaces,
return_first=return_first,
)


Expand Down
14 changes: 9 additions & 5 deletions gilda/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def annotate(
sent_split_fun=None,
organisms=None,
namespaces=None,
return_first: bool = True,
) -> List[Annotation]:
"""Annotate a given text with Gilda.
Expand All @@ -89,6 +90,8 @@ def annotate(
namespaces : list[str], optional
A list of namespaces to pass to the grounder to restrict the matches
to. By default, no restriction is applied.
return_first:
If true, only returns the first result. Otherwise, returns all results.
Returns
-------
Expand Down Expand Up @@ -139,11 +142,12 @@ def annotate(
len(raw_words[idx+span-1])
raw_span = ' '.join(raw_words[idx:idx+span])

# Append raw_span, (best) match, start, end
match = matches[0]
entities.append(
(raw_span, match, start_coord, end_coord)
)
if return_first:
matches = [matches[0]]
for match in matches:
entities.append(
(raw_span, match, start_coord, end_coord)
)

skip_until = idx + span
break
Expand Down
15 changes: 14 additions & 1 deletion gilda/tests/test_ner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from textwrap import dedent

import gilda
from gilda.ner import annotate, get_brat
from gilda.ner import get_brat


def test_annotate():
Expand Down Expand Up @@ -62,3 +62,16 @@ def test_get_brat():
#7\tAnnotatorNotes T7\tCHEBI:36080
""").lstrip()
assert brat_str == match_str


def test_get_all():
    """Check that ``return_first=False`` yields more than one annotation.

    Annotates the text "This is about ER." and verifies that the CURIEs
    collected from the returned matches include all three expected groundings.
    """
    text = "This is about ER."
    annotations = gilda.annotate(text, return_first=False)
    assert len(annotations) > 1

    # Each annotation is unpacked as a 4-tuple; only the scored match
    # (second element) is needed to collect the CURIEs.
    found_curies = set()
    for _span, scored, _start, _end in annotations:
        found_curies.add(scored.term.get_curie())

    expected_curies = (
        "hgnc:3467",   # ESR1
        "fplx:ESR",
        "GO:0005783",  # endoplasmic reticulum
    )
    for curie in expected_curies:
        assert curie in found_curies

0 comments on commit ddea0cc

Please sign in to comment.