From 6cc3a39b385bd308a6b13b1d115492fbc48c35da Mon Sep 17 00:00:00 2001 From: gareth Date: Sun, 19 May 2019 17:10:46 +0100 Subject: [PATCH] Further refines progress bar --- cityseer/algos/centrality.py | 9 +++++++-- cityseer/algos/checks.py | 14 ++++++++++---- cityseer/algos/data.py | 18 ++++++++++++++---- tests/algos/test_checks.py | 5 ++--- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/cityseer/algos/centrality.py b/cityseer/algos/centrality.py index 18ebdf12..c16be5bc 100644 --- a/cityseer/algos/centrality.py +++ b/cityseer/algos/centrality.py @@ -166,7 +166,11 @@ def shortest_path_tree( return map_impedance, map_distance, map_pred, cycles -@njit(cache=True) +# cache has to be set to false per Numba issue: +# https://github.com/numba/numba/issues/3555 +# which prevents nested print function from working as intended +# TODO: set to True once resolved +@njit(cache=False) def local_centrality(node_map: np.ndarray, edge_map: np.ndarray, distances: np.ndarray, @@ -236,10 +240,11 @@ def local_centrality(node_map: np.ndarray, betweenness_data = np.full((2, d_n, n), 0.0) # iterate through each vert and calculate the shortest path tree + progress_chunks = int(n / 2000) for src_idx in range(n): # numba no object mode can only handle basic printing - checks.progress_bar(src_idx, n, 20) + checks.progress_bar(src_idx, n, progress_chunks) # only compute for live nodes if not nodes_live[src_idx]: diff --git a/cityseer/algos/checks.py b/cityseer/algos/checks.py index 97b865e4..14b99de6 100644 --- a/cityseer/algos/checks.py +++ b/cityseer/algos/checks.py @@ -5,12 +5,18 @@ def_min_thresh_wt = 0.01831563888873418 - +# cache for parent functions has to be set to false per Numba issue: +# https://github.com/numba/numba/issues/3555 +# which prevents nested print function from working as intended +# TODO: resolve once fixed @njit(cache=True) def progress_bar(current: int, total: int, chunks: int): - if (chunks > total): - raise ValueError('The number of chunks should not exceed the total.') + if chunks < 10: + chunks = 10 + + if chunks > total: + chunks = total def print_msg(hash_count, void_count, percentage): msg = '|' @@ -30,7 +36,7 @@ def print_msg(hash_count, void_count, percentage): print_msg(int(total / step_size), 0, 100) elif (current + 1) % step_size == 0: - percentage = int((current + 1) / total * 100) + percentage = np.round((current + 1) / total * 100, 2) hash_count = int((current + 1) / step_size) void_count = int(total / step_size - hash_count) print_msg(hash_count, void_count, percentage) diff --git a/cityseer/algos/data.py b/cityseer/algos/data.py index 6e94e85f..31ef857b 100644 --- a/cityseer/algos/data.py +++ b/cityseer/algos/data.py @@ -68,7 +68,11 @@ def find_nearest(src_x: float, src_y: float, x_arr: np.ndarray, y_arr: np.ndarra return min_idx, min_dist -@njit(cache=True) +# cache has to be set to false per Numba issue: +# https://github.com/numba/numba/issues/3555 +# which prevents nested print function from working as intended +# TODO: set to True once resolved +@njit(cache=False) def assign_to_network(data_map: np.ndarray, node_map: np.ndarray, edge_map: np.ndarray, @@ -196,9 +200,10 @@ def closest_intersections(d_coords, pr_map, end_node): data_y_arr = data_map[:, 1] total_count = len(data_map) + progress_chunks = int(total_count / 2000) for data_idx in range(total_count): - checks.progress_bar(data_idx, total_count, 20) + checks.progress_bar(data_idx, total_count, progress_chunks) # find the nearest network node min_idx, min_dist = find_nearest(data_x_arr[data_idx], data_y_arr[data_idx], netw_x_arr, netw_y_arr, max_dist) @@ -416,7 +421,11 @@ def aggregate_to_src_idx(src_idx: int, return reachable_data_idx, reachable_data_dist, data_trim_to_full_idx_map -@njit(cache=True) +# cache has to be set to false per Numba issue: +# https://github.com/numba/numba/issues/3555 +# which prevents nested print function from working as intended +# TODO: set to True once resolved +@njit(cache=False) def local_aggregator(node_map: np.ndarray, edge_map: np.ndarray, data_map: np.ndarray, @@ -570,9 +579,10 @@ def disp_check(disp_matrix): stats_min = np.full((n_n, d_n, netw_n), np.nan) # iterate through each vert and aggregate + progress_chunks = int(netw_n / 2000) for src_idx in range(netw_n): - checks.progress_bar(src_idx, netw_n, 20) + checks.progress_bar(src_idx, netw_n, progress_chunks) # only compute for live nodes if not netw_nodes_live[src_idx]: diff --git a/tests/algos/test_checks.py b/tests/algos/test_checks.py index ffef4ec8..e8883bff 100644 --- a/tests/algos/test_checks.py +++ b/tests/algos/test_checks.py @@ -10,9 +10,8 @@ def test_progress_bar(): for n, chunks in zip([1, 10, 100], [1, 3, 10]): for i in range(n): checks.progress_bar(i, n, chunks) - # check that chunks > total raises - with pytest.raises(ValueError): - checks.progress_bar(i, 10, 20) + # check that chunks > total doesn't raise + checks.progress_bar(10, 10, 20) def test_check_numerical_data():