Skip to content

Commit

Permalink
fix tests movie model, the lime output is now slightly different as well
Browse files Browse the repository at this point in the history
  • Loading branch information
loostrum committed May 29, 2024
1 parent b8fa45b commit 2fff22f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
10 changes: 5 additions & 5 deletions tests/methods/test_lime_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def test_lime_text(self):
def test_lime_text_special_chars(self):
"""Tests exact expected output given a text with special characters and model for Lime."""
review = 'such a bad movie "!?\'"'

This comment has been minimized.

Copy link
@elboyran

elboyran May 29, 2024

Contributor

Do we have a test somewhere with a "full stop"? I know the logic is the same as a "!" and "?", just wondering if we need a test with a full stop. Or are the examples in our tutorials enough?

This comment has been minimized.

Copy link
@loostrum

loostrum May 29, 2024

Author Member

The errors that popped up were due to special chars in general, not the full stop specifically. Still, it wouldn't hurt to have a test with a full stop of course.

expected_words = ['bad', '?', '!', 'movie', 'such', 'a', "'", '"', '"']
expected_word_indices = [2, 6, 5, 3, 0, 1, 7, 4, 8]
expected_words = ['bad', 'movie', '?', 'such', '!', "'", '"', 'a', '"']
expected_word_indices = [2, 3, 6, 0, 5, 7, 8, 1, 4]
expected_scores = [
0.50032869, 0.06458735, -0.05793979, 0.01413776, -0.01246357,
-0.00528022, 0.00305347, 0.00185159, -0.00165128
0.51140699, 0.02827488, 0.02657974, -0.02208464, -0.02140743,
0.00962419, 0.00746798, -0.00743376, -0.0012061
]

explanation = dianna.explain_text(self.runner,
Expand All @@ -44,7 +44,7 @@ def test_lime_text_special_chars(self):
labels=[0],
method='LIME',
random_state=42)[0]

print(explanation)
assert_explanation_satisfies_expectations(explanation, expected_scores,
expected_word_indices,
expected_words)
Expand Down
9 changes: 8 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import onnxruntime as ort
import spacy
from scipy.special import expit as sigmoid
from torchtext.vocab import Vectors
Expand Down Expand Up @@ -84,6 +85,10 @@ def __call__(self, sentences):
if isinstance(sentences, str):
sentences = [sentences]

sess = ort.InferenceSession(self.filename)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

output = []
for sentence in sentences:
# tokenize and pad to minimum length
Expand All @@ -99,7 +104,9 @@ def __call__(self, sentences):
]

# run the model, applying a sigmoid because the model outputs logits, remove any remaining batch axis
pred = float(sigmoid(self.run_model([tokens_numerical])))
onnx_input = {input_name: [tokens_numerical]}
logits = sess.run([output_name], onnx_input)[0]
pred = float(sigmoid(logits))
output.append(pred)

# output two classes
Expand Down

0 comments on commit 2fff22f

Please sign in to comment.