Skip to content

Commit

Permalink
Make annotation return a list of scored matches
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Jul 14, 2024
1 parent 52c8b79 commit b5487ba
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 49 deletions.
12 changes: 4 additions & 8 deletions gilda/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def annotate(
sent_split_fun=None,
organisms=None,
namespaces=None,
return_first: bool = True,
context_text: str = None,
):
"""Annotate a given text with Gilda (i.e., do named entity recognition).
Expand All @@ -132,18 +131,16 @@ def annotate(
namespaces : list[str], optional
A list of namespaces to pass to the grounder to restrict the matches
to. By default, no restriction is applied.
return_first :
If true, only returns the first result. Otherwise, returns all results.
context_text :
A longer span of text that serves as additional context for the text
being annotated for disambiguation purposes.
Returns
-------
list[tuple[str, ScoredMatch, int, int]]
A list of tuples of start and end character offsets of the text
corresponding to the entity, the entity text, and the ScoredMatch
object corresponding to the entity.
list[tuple[str, list[ScoredMatch], int, int]]
A list of matches where each match is a tuple consisting of
the matches text span, the list of ScoredMatches, and the
start and end character offsets of the text span.
"""
from .ner import annotate as _annotate

Expand All @@ -153,7 +150,6 @@ def annotate(
sent_split_fun=sent_split_fun,
organisms=organisms,
namespaces=namespaces,
return_first=return_first,
context_text=context_text,
)

Expand Down
4 changes: 2 additions & 2 deletions gilda/app/templates/ner_matches.html
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ <h3 class="panel-title">NER Results</h3>
</tr>
</thead>
<tbody>
{% for text, match, start, end in annotations %}
{% for text, matches, start, end in annotations %}
<tr>
{% set match_curie = match.term.get_curie() %}
{% set match_curie = matches[0].term.get_curie() %}
<td>{{start}}-{{end}}</td>
<td>
<a class="label label-primary" href="{{ match['url'] }}">
Expand Down
46 changes: 23 additions & 23 deletions gilda/ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
The results are a list of 4-tuples containing:
- the text string matched
- a :class:`gilda.grounder.ScoredMatch` instance containing the _best_ match
- a list of :class:`gilda.grounder.ScoredMatch` instances containing a sorted list of matches
for the given text span (first one is the best match)
- the position in the text string where the entity starts
- the position in the text string where the entity ends
Expand Down Expand Up @@ -69,7 +70,6 @@ def annotate(
sent_split_fun=None,
organisms=None,
namespaces=None,
return_first: bool = True,
context_text: str = None,
) -> List[Annotation]:
"""Annotate a given text with Gilda.
Expand All @@ -91,18 +91,16 @@ def annotate(
namespaces : list[str], optional
A list of namespaces to pass to the grounder to restrict the matches
to. By default, no restriction is applied.
return_first :
If true, only returns the first result. Otherwise, returns all results.
context_text :
A longer span of text that serves as additional context for the text
being annotated for disambiguation purposes.
Returns
-------
list[tuple[str, ScoredMatch, int, int]]
A list of tuples of start and end character offsets of the text
corresponding to the entity, the entity text, and the ScoredMatch
object corresponding to the entity.
list[tuple[str, list[ScoredMatch], int, int]]
A list of matches where each match is a tuple consisting of
the matches text span, the list of ScoredMatches, and the
start and end character offsets of the text span.
"""
if grounder is None:
grounder = get_grounder()
Expand Down Expand Up @@ -136,22 +134,19 @@ def annotate(
# Find the largest matching span
for span in sorted(applicable_spans, reverse=True):
txt_span = ' '.join(raw_words[idx:idx+span])
matches = grounder.ground(
txt_span, context=text if context_text is None else context_text,
organisms=organisms, namespaces=namespaces,
)
context = text if context_text is None else context_text
matches = grounder.ground(txt_span,
context=context,
organisms=organisms,
namespaces=namespaces)
if matches:
start_coord = word_coords[idx]
end_coord = word_coords[idx+span-1] + \
len(raw_words[idx+span-1])
raw_span = ' '.join(raw_words[idx:idx+span])

if return_first:
matches = [matches[0]]
for match in matches:
entities.append(
(raw_span, match, start_coord, end_coord)
)
entities.append((
raw_span, matches, start_coord, end_coord
))

skip_until = idx + span
break
Expand All @@ -163,7 +158,7 @@ def get_brat(entities, entity_type="Entity", ix_offset=1, include_text=True):
Parameters
----------
entities : list[tuple[str, str | ScoredMatch, int, int]]
entities : list[tuple[str, str | list[str] | list[ScoredMatch], int, int]]
A list of tuples of entity text, grounded curie, start and end
character offsets in the text corresponding to an entity.
entity_type : str, optional
Expand All @@ -184,9 +179,14 @@ def get_brat(entities, entity_type="Entity", ix_offset=1, include_text=True):
"""
brat = []
ix_offset = max(1, ix_offset)
for idx, (raw_span, curie, start, end) in enumerate(entities, ix_offset):
if isinstance(curie, ScoredMatch):
curie = curie.term.get_curie()
for idx, (raw_span, curies, start, end) in enumerate(entities, ix_offset):
if isinstance(curies, str):
curie = curies
# Note that here we always that the best match and ignore the rest
else:
curie = curies[0]
if isinstance(curie, ScoredMatch):
curie = curie.term.get_curie()
if entity_type != "Entity":
curie += f"; Reading system: {entity_type}"
row = f'T{idx}\t{entity_type} {start} {end}' + (
Expand Down
32 changes: 16 additions & 16 deletions gilda/tests/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ def test_annotate():
assert annotations[6][2:4] == (56, 63) # protein

# Check that the curies are correct
assert isinstance(annotations[0][1], gilda.ScoredMatch)
assert annotations[0][1].term.get_curie() == "CHEBI:36080"
assert annotations[1][1].term.get_curie() == "hgnc:1097"
assert annotations[2][1].term.get_curie() == "mesh:D010770"
assert annotations[3][1].term.get_curie() == "hgnc:1097"
assert annotations[4][1].term.get_curie() == "mesh:D005796"
assert annotations[5][1].term.get_curie() == "hgnc:1097"
assert annotations[6][1].term.get_curie() == "CHEBI:36080"
assert isinstance(annotations[0][1][0], gilda.ScoredMatch)
assert annotations[0][1][0].term.get_curie() == "CHEBI:36080"
assert annotations[1][1][0].term.get_curie() == "hgnc:1097"
assert annotations[2][1][0].term.get_curie() == "mesh:D010770"
assert annotations[3][1][0].term.get_curie() == "hgnc:1097"
assert annotations[4][1][0].term.get_curie() == "mesh:D005796"
assert annotations[5][1][0].term.get_curie() == "hgnc:1097"
assert annotations[6][1][0].term.get_curie() == "CHEBI:36080"


def test_get_brat():
Expand Down Expand Up @@ -66,12 +66,12 @@ def test_get_brat():

def test_get_all():
full_text = "This is about ER."
results = gilda.annotate(full_text, return_first=False)
assert len(results) > 1
curies = {
scored_match.term.get_curie()
for _, scored_match, _, _ in results
}
results = gilda.annotate(full_text)
assert len(results) == 1
curies = set()
for _, scored_matches, _, _ in results:
for scored_match in scored_matches:
curies.add(scored_match.term.get_curie())
assert "hgnc:3467" in curies # ESR1
assert "fplx:ESR" in curies
assert "GO:0005783" in curies # endoplasmic reticulum
Expand All @@ -82,13 +82,13 @@ def test_context_test():
context_text = "Estrogen receptor (ER) is a protein family."
results = gilda.annotate(text, context_text=context_text)
assert len(results) == 1
assert results[0][1].term.get_curie() == "fplx:ESR"
assert results[0][1][0].term.get_curie() == "fplx:ESR"
assert results[0][0] == "ER"
assert results[0][2:4] == (14, 16)

context_text = "Calcium is released from the ER."
results = gilda.annotate(text, context_text=context_text)
assert len(results) == 1
assert results[0][1].term.get_curie() == "GO:0005783"
assert results[0][1][0].term.get_curie() == "GO:0005783"
assert results[0][0] == "ER"
assert results[0][2:4] == (14, 16)

0 comments on commit b5487ba

Please sign in to comment.