diff --git a/forest/benchmarking/distance_measures.py b/forest/benchmarking/distance_measures.py index ab59b137..1bb06f6e 100644 --- a/forest/benchmarking/distance_measures.py +++ b/forest/benchmarking/distance_measures.py @@ -111,7 +111,7 @@ def trace_distance(rho: np.ndarray, sigma: np.ndarray) -> float: :param sigma: Is a dim by dim positive matrix with unit trace. :return: Trace distance which is a scalar. """ - return (0.5) * np.linalg.norm(rho - sigma, 1) + return (0.5) * np.linalg.norm(rho - sigma, 'nuc') def bures_distance(rho: np.ndarray, sigma: np.ndarray) -> float: