From 2f5782947d937f88ba8a59fb7c7885ee909059ac Mon Sep 17 00:00:00 2001 From: Trutnev Aleksei Date: Sat, 25 May 2024 22:42:06 +0300 Subject: [PATCH] more metrics --- lkmeans/examples/main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lkmeans/examples/main.py b/lkmeans/examples/main.py index d175adc..caa61ec 100644 --- a/lkmeans/examples/main.py +++ b/lkmeans/examples/main.py @@ -5,7 +5,7 @@ import numpy as np from numpy.typing import NDArray -from sklearn.metrics import accuracy_score, adjusted_mutual_info_score, adjusted_rand_score +from sklearn.metrics import accuracy_score, adjusted_mutual_info_score, adjusted_rand_score, completeness_score, homogeneity_score, normalized_mutual_info_score, v_measure_score from tap import Tap from lkmeans.clustering import HardSSLKMeans, LKMeans, SoftSSLKMeans @@ -46,6 +46,10 @@ def calculate_metrics(labels: NDArray, generated_labels: NDArray) -> Dict[str, f return { 'ari': float(adjusted_rand_score(labels, generated_labels)), 'ami': float(adjusted_mutual_info_score(labels, generated_labels)), + 'completeness': float(completeness_score(labels, generated_labels)), + 'homogeneity': float(homogeneity_score(labels, generated_labels)), + 'nmi': float(normalized_mutual_info_score(labels, generated_labels)), + 'v_measure': float(v_measure_score(labels, generated_labels)), 'accuracy': float(accuracy_score(labels, generated_labels)), }