From 5a161fd0ebc6050b59e03153b3cd39bb7be96749 Mon Sep 17 00:00:00 2001 From: Marcus P S Date: Thu, 31 Oct 2024 09:10:12 -0700 Subject: [PATCH] Fix #242 trace_distance to use nuclear norm Fix #242 by replacing call to `numpy.linalg.norm(x,1)` with `numpy.linalg.norm(x,'nuc')` --- forest/benchmarking/distance_measures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/forest/benchmarking/distance_measures.py b/forest/benchmarking/distance_measures.py index ab59b137..1bb06f6e 100644 --- a/forest/benchmarking/distance_measures.py +++ b/forest/benchmarking/distance_measures.py @@ -111,7 +111,7 @@ def trace_distance(rho: np.ndarray, sigma: np.ndarray) -> float: :param sigma: Is a dim by dim positive matrix with unit trace. :return: Trace distance which is a scalar. """ - return (0.5) * np.linalg.norm(rho - sigma, 1) + return (0.5) * np.linalg.norm(rho - sigma, 'nuc') def bures_distance(rho: np.ndarray, sigma: np.ndarray) -> float: