diff --git a/gilda/api.py b/gilda/api.py
index e46e257..b5aaa1e 100644
--- a/gilda/api.py
+++ b/gilda/api.py
@@ -112,7 +112,6 @@ def annotate(
sent_split_fun=None,
organisms=None,
namespaces=None,
- return_first: bool = True,
context_text: str = None,
):
"""Annotate a given text with Gilda (i.e., do named entity recognition).
@@ -132,18 +131,16 @@ def annotate(
namespaces : list[str], optional
A list of namespaces to pass to the grounder to restrict the matches
to. By default, no restriction is applied.
- return_first :
- If true, only returns the first result. Otherwise, returns all results.
context_text :
A longer span of text that serves as additional context for the text
being annotated for disambiguation purposes.
Returns
-------
- list[tuple[str, ScoredMatch, int, int]]
- A list of tuples of start and end character offsets of the text
- corresponding to the entity, the entity text, and the ScoredMatch
- object corresponding to the entity.
+ list[tuple[str, list[ScoredMatch], int, int]]
+ A list of matches where each match is a tuple consisting of
+ the matches text span, the list of ScoredMatches, and the
+ start and end character offsets of the text span.
"""
from .ner import annotate as _annotate
@@ -153,7 +150,6 @@ def annotate(
sent_split_fun=sent_split_fun,
organisms=organisms,
namespaces=namespaces,
- return_first=return_first,
context_text=context_text,
)
diff --git a/gilda/app/templates/ner_matches.html b/gilda/app/templates/ner_matches.html
index a9ff447..94614a4 100644
--- a/gilda/app/templates/ner_matches.html
+++ b/gilda/app/templates/ner_matches.html
@@ -25,9 +25,9 @@
NER Results
- {% for text, match, start, end in annotations %}
+ {% for text, matches, start, end in annotations %}
- {% set match_curie = match.term.get_curie() %}
+ {% set match_curie = matches[0].term.get_curie() %}
{{start}}-{{end}} |
diff --git a/gilda/ner.py b/gilda/ner.py
index 2340367..75fbe59 100644
--- a/gilda/ner.py
+++ b/gilda/ner.py
@@ -9,7 +9,8 @@
The results are a list of 4-tuples containing:
- the text string matched
-- a :class:`gilda.grounder.ScoredMatch` instance containing the _best_ match
+- a list of :class:`gilda.grounder.ScoredMatch` instances containing a sorted list of matches
+ for the given text span (first one is the best match)
- the position in the text string where the entity starts
- the position in the text string where the entity ends
@@ -69,7 +70,6 @@ def annotate(
sent_split_fun=None,
organisms=None,
namespaces=None,
- return_first: bool = True,
context_text: str = None,
) -> List[Annotation]:
"""Annotate a given text with Gilda.
@@ -91,18 +91,16 @@ def annotate(
namespaces : list[str], optional
A list of namespaces to pass to the grounder to restrict the matches
to. By default, no restriction is applied.
- return_first :
- If true, only returns the first result. Otherwise, returns all results.
context_text :
A longer span of text that serves as additional context for the text
being annotated for disambiguation purposes.
Returns
-------
- list[tuple[str, ScoredMatch, int, int]]
- A list of tuples of start and end character offsets of the text
- corresponding to the entity, the entity text, and the ScoredMatch
- object corresponding to the entity.
+ list[tuple[str, list[ScoredMatch], int, int]]
+ A list of matches where each match is a tuple consisting of
+ the matches text span, the list of ScoredMatches, and the
+ start and end character offsets of the text span.
"""
if grounder is None:
grounder = get_grounder()
@@ -136,22 +134,19 @@ def annotate(
# Find the largest matching span
for span in sorted(applicable_spans, reverse=True):
txt_span = ' '.join(raw_words[idx:idx+span])
- matches = grounder.ground(
- txt_span, context=text if context_text is None else context_text,
- organisms=organisms, namespaces=namespaces,
- )
+ context = text if context_text is None else context_text
+ matches = grounder.ground(txt_span,
+ context=context,
+ organisms=organisms,
+ namespaces=namespaces)
if matches:
start_coord = word_coords[idx]
end_coord = word_coords[idx+span-1] + \
len(raw_words[idx+span-1])
raw_span = ' '.join(raw_words[idx:idx+span])
-
- if return_first:
- matches = [matches[0]]
- for match in matches:
- entities.append(
- (raw_span, match, start_coord, end_coord)
- )
+ entities.append((
+ raw_span, matches, start_coord, end_coord
+ ))
skip_until = idx + span
break
@@ -163,7 +158,7 @@ def get_brat(entities, entity_type="Entity", ix_offset=1, include_text=True):
Parameters
----------
- entities : list[tuple[str, str | ScoredMatch, int, int]]
+ entities : list[tuple[str, str | list[str] | list[ScoredMatch], int, int]]
A list of tuples of entity text, grounded curie, start and end
character offsets in the text corresponding to an entity.
entity_type : str, optional
@@ -184,9 +179,14 @@ def get_brat(entities, entity_type="Entity", ix_offset=1, include_text=True):
"""
brat = []
ix_offset = max(1, ix_offset)
- for idx, (raw_span, curie, start, end) in enumerate(entities, ix_offset):
- if isinstance(curie, ScoredMatch):
- curie = curie.term.get_curie()
+ for idx, (raw_span, curies, start, end) in enumerate(entities, ix_offset):
+ if isinstance(curies, str):
+ curie = curies
+ # Note that here we always that the best match and ignore the rest
+ else:
+ curie = curies[0]
+ if isinstance(curie, ScoredMatch):
+ curie = curie.term.get_curie()
if entity_type != "Entity":
curie += f"; Reading system: {entity_type}"
row = f'T{idx}\t{entity_type} {start} {end}' + (
diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py
index 13ae67b..55060ac 100644
--- a/gilda/tests/test_ner.py
+++ b/gilda/tests/test_ner.py
@@ -28,14 +28,14 @@ def test_annotate():
assert annotations[6][2:4] == (56, 63) # protein
# Check that the curies are correct
- assert isinstance(annotations[0][1], gilda.ScoredMatch)
- assert annotations[0][1].term.get_curie() == "CHEBI:36080"
- assert annotations[1][1].term.get_curie() == "hgnc:1097"
- assert annotations[2][1].term.get_curie() == "mesh:D010770"
- assert annotations[3][1].term.get_curie() == "hgnc:1097"
- assert annotations[4][1].term.get_curie() == "mesh:D005796"
- assert annotations[5][1].term.get_curie() == "hgnc:1097"
- assert annotations[6][1].term.get_curie() == "CHEBI:36080"
+ assert isinstance(annotations[0][1][0], gilda.ScoredMatch)
+ assert annotations[0][1][0].term.get_curie() == "CHEBI:36080"
+ assert annotations[1][1][0].term.get_curie() == "hgnc:1097"
+ assert annotations[2][1][0].term.get_curie() == "mesh:D010770"
+ assert annotations[3][1][0].term.get_curie() == "hgnc:1097"
+ assert annotations[4][1][0].term.get_curie() == "mesh:D005796"
+ assert annotations[5][1][0].term.get_curie() == "hgnc:1097"
+ assert annotations[6][1][0].term.get_curie() == "CHEBI:36080"
def test_get_brat():
@@ -66,12 +66,12 @@ def test_get_brat():
def test_get_all():
full_text = "This is about ER."
- results = gilda.annotate(full_text, return_first=False)
- assert len(results) > 1
- curies = {
- scored_match.term.get_curie()
- for _, scored_match, _, _ in results
- }
+ results = gilda.annotate(full_text)
+ assert len(results) == 1
+ curies = set()
+ for _, scored_matches, _, _ in results:
+ for scored_match in scored_matches:
+ curies.add(scored_match.term.get_curie())
assert "hgnc:3467" in curies # ESR1
assert "fplx:ESR" in curies
assert "GO:0005783" in curies # endoplasmic reticulum
@@ -82,13 +82,13 @@ def test_context_test():
context_text = "Estrogen receptor (ER) is a protein family."
results = gilda.annotate(text, context_text=context_text)
assert len(results) == 1
- assert results[0][1].term.get_curie() == "fplx:ESR"
+ assert results[0][1][0].term.get_curie() == "fplx:ESR"
assert results[0][0] == "ER"
assert results[0][2:4] == (14, 16)
context_text = "Calcium is released from the ER."
results = gilda.annotate(text, context_text=context_text)
assert len(results) == 1
- assert results[0][1].term.get_curie() == "GO:0005783"
+ assert results[0][1][0].term.get_curie() == "GO:0005783"
assert results[0][0] == "ER"
assert results[0][2:4] == (14, 16)
|