TFSimilarity.classification_metrics.F1Score

Calculates the harmonic mean of precision and recall.

Inherits From: ClassificationMetric, ABC

TFSimilarity.classification_metrics.F1Score(
    name: str = 'f1'
) -> None

Computes the F-1 Score given the query classification counts. The metric is computed as follows:

$$ F_1 = 2 \cdot \frac{\textrm{precision} \cdot \textrm{recall}}{\textrm{precision} + \textrm{recall}} $$
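
For reference, the same quantity can be written directly in terms of the counts that compute() receives. This is a short derivation, assuming the standard definitions precision = tp / (tp + fp) and recall = tp / (tp + fn), which this page does not state itself:

$$ F_1 = 2 \cdot \frac{\frac{tp}{tp + fp} \cdot \frac{tp}{tp + fn}}{\frac{tp}{tp + fp} + \frac{tp}{tp + fn}} = \frac{2 \cdot tp}{2 \cdot tp + fp + fn} $$

As a worked example with made-up counts tp = 8, fp = 2, fn = 8: precision is 8 / 10 = 0.8, recall is 8 / 16 = 0.5, and

$$ F_1 = \frac{2 \cdot 8}{2 \cdot 8 + 2 + 8} = \frac{16}{26} \approx 0.615 $$

Under these definitions the count of true negatives does not appear in the result.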

Args:
    name: Name associated with a specific metric object, e.g., [email protected]

Usage with tf.similarity.models.SimilarityModel():

model.calibrate(x=query_examples,
                y=query_labels,
                calibration_metric='f1')

Methods

compute

compute(
    tp: <a href="../../TFSimilarity/callbacks/FloatTensor.md">TFSimilarity.callbacks.FloatTensor</a>,
    fp: <a href="../../TFSimilarity/callbacks/FloatTensor.md">TFSimilarity.callbacks.FloatTensor</a>,
    tn: <a href="../../TFSimilarity/callbacks/FloatTensor.md">TFSimilarity.callbacks.FloatTensor</a>,
    fn: <a href="../../TFSimilarity/callbacks/FloatTensor.md">TFSimilarity.callbacks.FloatTensor</a>,
    count: int
) -> <a href="../../TFSimilarity/callbacks/FloatTensor.md">TFSimilarity.callbacks.FloatTensor</a>

Compute the classification metric.

The compute() method supports computing the metric for a set of values, where each value represents the counts at a specific distance threshold.

Args:
    tp: A 1D FloatTensor containing the count of True Positives at each distance threshold.
    fp: A 1D FloatTensor containing the count of False Positives at each distance threshold.
    tn: A 1D FloatTensor containing the count of True Negatives at each distance threshold.
    fn: A 1D FloatTensor containing the count of False Negatives at each distance threshold.
    count: The total number of queries.

Returns:
    A 1D FloatTensor containing the metric at each distance threshold.
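
To make the "one value per distance threshold" behaviour concrete, here is a minimal pure-Python sketch of an F1 computation over per-threshold counts. It is not the library's implementation: the function name `f1_per_threshold` is hypothetical, plain lists stand in for the 1D FloatTensors, and returning 0.0 when a denominator is zero is an assumption made for this example.

```python
from typing import List, Sequence


def f1_per_threshold(
    tp: Sequence[float],
    fp: Sequence[float],
    fn: Sequence[float],
) -> List[float]:
    """Return one F1 score per distance threshold (illustrative sketch only)."""
    scores = []
    for tp_i, fp_i, fn_i in zip(tp, fp, fn):
        # Assumption: a zero denominator yields 0.0 instead of raising or NaN.
        precision = tp_i / (tp_i + fp_i) if (tp_i + fp_i) > 0 else 0.0
        recall = tp_i / (tp_i + fn_i) if (tp_i + fn_i) > 0 else 0.0
        denom = precision + recall
        # Harmonic mean of precision and recall at this threshold.
        scores.append(2 * precision * recall / denom if denom > 0 else 0.0)
    return scores


# Made-up counts at three distance thresholds (index i = threshold i).
tp = [0.0, 4.0, 8.0]
fp = [0.0, 1.0, 2.0]
fn = [10.0, 5.0, 0.0]

scores = f1_per_threshold(tp, fp, fn)
print(scores)
```

The sketch takes only tp, fp, and fn because the F1 formula does not use the true-negative count or the total query count; the real compute() method still accepts tn and count as part of its signature.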

get_config

get_config()