Skip to content

Commit

Permalink
Further refines progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
songololo committed May 19, 2019
1 parent a20aa8b commit 6cc3a39
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
9 changes: 7 additions & 2 deletions cityseer/algos/centrality.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,11 @@ def shortest_path_tree(
return map_impedance, map_distance, map_pred, cycles


@njit(cache=True)
# cache has to be set to false per Numba issue:
# https://github.com/numba/numba/issues/3555
# which prevents nested print function from working as intended
# TODO: set to True once resolved
@njit(cache=False)
def local_centrality(node_map: np.ndarray,
edge_map: np.ndarray,
distances: np.ndarray,
Expand Down Expand Up @@ -236,10 +240,11 @@ def local_centrality(node_map: np.ndarray,
betweenness_data = np.full((2, d_n, n), 0.0)

# iterate through each vert and calculate the shortest path tree
progress_chunks = int(n / 2000)
for src_idx in range(n):

# numba no object mode can only handle basic printing
checks.progress_bar(src_idx, n, 20)
checks.progress_bar(src_idx, n, progress_chunks)

# only compute for live nodes
if not nodes_live[src_idx]:
Expand Down
14 changes: 10 additions & 4 deletions cityseer/algos/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
def_min_thresh_wt = 0.01831563888873418



# cache for parent functions has to be set to false per Numba issue:
# https://github.com/numba/numba/issues/3555
# which prevents nested print function from working as intended
# TODO: resolve once fixed
@njit(cache=True)
def progress_bar(current: int, total: int, chunks: int):

if (chunks > total):
raise ValueError('The number of chunks should not exceed the total.')
if chunks < 10:
chunks = 10

if chunks > total:
chunks = total

def print_msg(hash_count, void_count, percentage):
msg = '|'
Expand All @@ -30,7 +36,7 @@ def print_msg(hash_count, void_count, percentage):
print_msg(int(total / step_size), 0, 100)

elif (current + 1) % step_size == 0:
percentage = int((current + 1) / total * 100)
percentage = np.round((current + 1) / total * 100, 2)
hash_count = int((current + 1) / step_size)
void_count = int(total / step_size - hash_count)
print_msg(hash_count, void_count, percentage)
Expand Down
18 changes: 14 additions & 4 deletions cityseer/algos/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ def find_nearest(src_x: float, src_y: float, x_arr: np.ndarray, y_arr: np.ndarra
return min_idx, min_dist


@njit(cache=True)
# cache has to be set to false per Numba issue:
# https://github.com/numba/numba/issues/3555
# which prevents nested print function from working as intended
# TODO: set to True once resolved
@njit(cache=False)
def assign_to_network(data_map: np.ndarray,
node_map: np.ndarray,
edge_map: np.ndarray,
Expand Down Expand Up @@ -196,9 +200,10 @@ def closest_intersections(d_coords, pr_map, end_node):
data_y_arr = data_map[:, 1]

total_count = len(data_map)
progress_chunks = int(total_count / 2000)
for data_idx in range(total_count):

checks.progress_bar(data_idx, total_count, 20)
checks.progress_bar(data_idx, total_count, progress_chunks)

# find the nearest network node
min_idx, min_dist = find_nearest(data_x_arr[data_idx], data_y_arr[data_idx], netw_x_arr, netw_y_arr, max_dist)
Expand Down Expand Up @@ -416,7 +421,11 @@ def aggregate_to_src_idx(src_idx: int,
return reachable_data_idx, reachable_data_dist, data_trim_to_full_idx_map


@njit(cache=True)
# cache has to be set to false per Numba issue:
# https://github.com/numba/numba/issues/3555
# which prevents nested print function from working as intended
# TODO: set to True once resolved
@njit(cache=False)
def local_aggregator(node_map: np.ndarray,
edge_map: np.ndarray,
data_map: np.ndarray,
Expand Down Expand Up @@ -570,9 +579,10 @@ def disp_check(disp_matrix):
stats_min = np.full((n_n, d_n, netw_n), np.nan)

# iterate through each vert and aggregate
progress_chunks = int(netw_n / 2000)
for src_idx in range(netw_n):

checks.progress_bar(src_idx, netw_n, 20)
checks.progress_bar(src_idx, netw_n, progress_chunks)

# only compute for live nodes
if not netw_nodes_live[src_idx]:
Expand Down
5 changes: 2 additions & 3 deletions tests/algos/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ def test_progress_bar():
for n, chunks in zip([1, 10, 100], [1, 3, 10]):
for i in range(n):
checks.progress_bar(i, n, chunks)
# check that chunks > total raises
with pytest.raises(ValueError):
checks.progress_bar(i, 10, 20)
# check that chunks > total doesn't raise
checks.progress_bar(10, 10, 20)


def test_check_numerical_data():
Expand Down

0 comments on commit 6cc3a39

Please sign in to comment.