Skip to content

Commit

Permalink
1) Switch distances.py to use bfs_tools
Browse files Browse the repository at this point in the history
2) Add ignore_people option to bfs_tools
3) Fix greedy_around.
  • Loading branch information
sligocki committed Dec 3, 2021
1 parent 18d794b commit 1e2492a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 28 deletions.
4 changes: 2 additions & 2 deletions bfs_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import data_reader


def ConnectionBfs(db, start):
def ConnectionBfs(db, start, ignore_people=frozenset()):
todos = collections.deque()
dists = {}

Expand All @@ -18,7 +18,7 @@ def ConnectionBfs(db, start):
dist = dists[person]
yield (person, dist)
for neigh in db.neighbors_of(person):
if neigh not in dists:
if neigh not in ignore_people and neigh not in dists:
todos.append(neigh)
dists[neigh] = dist + 1

Expand Down
36 changes: 17 additions & 19 deletions distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,24 @@
import random
import time

import bfs_tools
import data_reader
import utils

def get_distances(db, start):
def get_distances(db, start, ignore_people=frozenset()):
"""Get distances to all other items in graph via breadth-first search."""
dists = {start: 0}
queue = collections.deque()
queue.append(start)
dists = {}
total_dist = 0
max_dist = 0
hist_dist = [1]
while queue:
person = queue.popleft()
dist = dists[person]
for neigh in db.neighbors_of(person):
if neigh not in dists:
dists[neigh] = dist + 1
total_dist += dist + 1
max_dist = dist + 1
while len(hist_dist) <= dist + 1:
hist_dist.append(0)
hist_dist[dist + 1] += 1
queue.append(neigh)
hist_dist = collections.defaultdict(int)
for (person, dist) in bfs_tools.ConnectionBfs(db, start, ignore_people):
dists[person] = dist
hist_dist[dist] += 1
max_dist = dist
total_dist += dist
mean_dist = float(total_dist) / len(dists)
return dists, hist_dist, mean_dist, max_dist
hist_dist_list = [hist_dist[i] for i in range(max(hist_dist.keys()) + 1)]
return dists, hist_dist_list, mean_dist, max_dist

def get_mean_dists(db, start):
_, _, mean_dist, max_dist = get_distances(db, start)
Expand All @@ -53,16 +46,21 @@ def enum_user_nums(db, args):
parser.add_argument("--version", help="Data version (defaults to most recent).")
parser.add_argument("--random", action="store_true")
parser.add_argument("--save-distribution-json", help="Save Circle sizes to file.")
parser.add_argument("--ignore-people",
help="Comma separated list of people to ignore in BFS.")
parser.add_argument("wikitree_id", nargs="*")
args = parser.parse_args()

db = data_reader.Database(args.version)
db.load_connections()

# --ignore-people is optional: argparse leaves it as None when the flag is
# omitted, so fall back to "" rather than calling .split() on None. Empty
# entries (from "" or a stray comma) are dropped so they are never looked up.
ignore_ids = [wt_id for wt_id in (args.ignore_people or "").split(",") if wt_id]
ignore_nums = frozenset(db.id2num(wt_id) for wt_id in ignore_ids)

circle_sizes = {}
for user_num in enum_user_nums(db, args):
utils.log("Loading distances from", db.num2id(user_num))
dists, hist_dist, mean_dist, max_dist = get_distances(db, user_num)
dists, hist_dist, mean_dist, max_dist = get_distances(db, user_num, ignore_nums)
utils.log(db.num2id(user_num), mean_dist, max_dist)
circle_sizes[db.num2id(user_num)] = hist_dist
utils.log(hist_dist)
Expand Down
19 changes: 12 additions & 7 deletions greedy_around.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


import argparse
import collections
import sys
import time
Expand Down Expand Up @@ -36,10 +35,16 @@ def greedy_path(db, start, visited):


if __name__ == "__main__":
db = data_reader.Database()
start_id = sys.argv[1]
start = db.id2num(start_id)
parser = argparse.ArgumentParser()
parser.add_argument("start_id", nargs="+")
parser.add_argument("--version", help="Data version (defaults to most recent).")
args = parser.parse_args()

db = data_reader.Database(args.version)
# Load connections into memory so that it's faster to do BFS.
db.load_connections()
print("Searching around", start_id, time.process_time())
greedy_path(db, start, visited=set())

for start_id in args.start_id:
start = db.id2num(start_id)
print("Searching around", start_id, time.process_time())
greedy_path(db, start, visited=set())

0 comments on commit 1e2492a

Please sign in to comment.