Skip to content

Commit

Permalink
mypy circles.
Browse files Browse the repository at this point in the history
  • Loading branch information
sligocki committed Feb 27, 2024
1 parent c6808c3 commit c04014a
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 163 deletions.
27 changes: 16 additions & 11 deletions bfs_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,33 @@
Breadth-first search tools for enumerating connections or relatives of a person.
"""
import collections
from collections.abc import Set
from typing import Iterator

import data_reader
from data_reader import UserNum


class BfsNode:
def __init__(self, person, prevs, dist):
def __init__(self, person : UserNum, prevs : list[UserNum],
dist : int) -> None:
# Currently enumerated person.
self.person = person
# List of previous nodes we reached person from.
self.prevs = prevs
# Dist from start to person.
self.dist = dist

def __repr__(self):
def __repr__(self) -> str:
return f"BfsNode({self.person}, {self.prevs}, {self.dist})"

def ConnectionBfs(db, start, ignore_people=frozenset()):
todos = collections.deque()
def ConnectionBfs(db : data_reader.Database,
start : UserNum,
ignore_people : Set[UserNum] = frozenset()
) -> Iterator[BfsNode]:
todos : collections.deque[UserNum] = collections.deque()
todos.append(start)
nodes = {}
nodes[start] = BfsNode(start, [], 0)
nodes = {start: BfsNode(start, [], 0)}
while todos:
person = todos.popleft()
yield nodes[person]
Expand All @@ -39,12 +45,11 @@ def ConnectionBfs(db, start, ignore_people=frozenset()):
nodes[neigh] = BfsNode(neigh, [person], dist + 1)


def RelativeBfs(db, start):
todos = collections.deque()
dists = {}

def RelativeBfs(db : data_reader.Database, start : UserNum
) -> Iterator[tuple[UserNum, tuple[int, int]]]:
todos : collections.deque[UserNum] = collections.deque()
todos.append(start)
dists[start] = (0, 0)
dists = {start: (0, 0)}

while todos:
person = todos.popleft()
Expand Down
23 changes: 15 additions & 8 deletions circles_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,31 @@

import argparse
import collections
from collections.abc import Collection, Set
import itertools
from pathlib import Path
import re
from typing import Iterator

from unidecode import unidecode

import bfs_tools
import category_tools
import circles_tools
import data_reader
from data_reader import UserNum
import utils


def try_id(db, person_num) -> str:
def try_id(db, person_num : UserNum) -> str:
id = db.num2id(person_num)
if id:
return id
else:
return str(person_num)


def get_locations(db, user_num):
def get_locations(db, user_num : UserNum) -> set[str]:
"""Return set of locations referenced by user's birth and death fields."""
locs = set()
for attribute in ["birth_location", "death_location"]:
Expand All @@ -43,10 +47,11 @@ def get_locations(db, user_num):
return locs


def summarize_group(db, category_db, people):
def summarize_group(db : data_reader.Database, category_db : category_tools.CategoryDb,
people : Collection[UserNum]) -> None:
num_people = len(people)
print(f"Summarizing over {num_people} people")
counts = {
counts : dict[str, collections.Counter[str]] = {
"location": collections.Counter(),
"category": collections.Counter(),
"manager": collections.Counter(),
Expand Down Expand Up @@ -76,12 +81,14 @@ def summarize_group(db, category_db, people):
by_index = round(percentile * (len(birth_years) - 1))
print(f" - {percentile:4.0%}-ile: {birth_years[by_index]}")

def load_locs(filename):
def load_locs(filename : Path) -> list[str]:
with open(filename, "r") as f:
return list(line.strip() for line in f)

def iter_closest_each_loc(db, focus, locs):
focus_num = db.get_person_num(focus)
def iter_closest_each_loc(db : data_reader.Database, focus_id : str,
locs : Collection[str]
) -> Iterator[tuple[str, int, UserNum]]:
focus_num = db.get_person_num(focus_id)
remaining_locs = set(locs)
for node in bfs_tools.ConnectionBfs(db, focus_num):
hits = get_locations(db, node.person) & remaining_locs
Expand Down Expand Up @@ -109,7 +116,7 @@ def main():

if args.state:
print("Finding closest person from every US State:")
states = load_locs("data/us_states.txt")
states = load_locs(Path("data/us_states.txt"))
n = 1
for loc, dist, id in iter_closest_each_loc(db, args.focus_id, states):
print(f" {n:3d} {dist:3d} {loc:20s} {id:20}")
Expand Down
87 changes: 45 additions & 42 deletions circles_api_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,48 @@
import utils


parser = argparse.ArgumentParser()
parser.add_argument("wikitree_id")
parser.add_argument("--time", action="store_true",
help="Include time-of-day in filename.")
args = parser.parse_args()

utils.log("Loading URL")
params = urllib.parse.urlencode({"WikiTreeID": args.wikitree_id})
url = f"https://wikitree.sdms.si/function/WT100Circles/Tree.json?{params}"
with urllib.request.urlopen(url) as resp:
data_text = resp.read().decode("ascii")

utils.log("Parsing JSON")
try:
data = json.loads(data_text)
except json.decoder.JSONDecodeError:
print("Error while parsing JSON response:")
print(data_text)
raise

try:
utils.log("Extracting circle sizes")
circle_sizes = []
for step in data["debug"]["steps"]:
circle_num, circle_size, cumulative_size, _ = step
circle_sizes.append(circle_size)

utils.log("Writing results")
if args.time:
date = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
else:
date = datetime.date.today().strftime("%Y-%m-%d")
filename = Path("results", "circles", f"{args.wikitree_id}_{date}.json")
with open(filename, "w") as f:
json.dump({args.wikitree_id: circle_sizes}, f)

utils.log(f"Wrote {filename}")

except:
print("Error while processing data")
pprint(data)
raise
def main():
parser = argparse.ArgumentParser()
parser.add_argument("wikitree_id")
parser.add_argument("--time", action="store_true",
help="Include time-of-day in filename.")
args = parser.parse_args()

utils.log("Loading URL")
params = urllib.parse.urlencode({"WikiTreeID": args.wikitree_id})
url = f"https://wikitree.sdms.si/function/WT100Circles/Tree.json?{params}"
with urllib.request.urlopen(url) as resp:
data_text = resp.read().decode("ascii")

utils.log("Parsing JSON")
try:
data = json.loads(data_text)
except json.decoder.JSONDecodeError:
print("Error while parsing JSON response:")
print(data_text)
raise

try:
utils.log("Extracting circle sizes")
circle_sizes = []
for step in data["debug"]["steps"]:
circle_num, circle_size, cumulative_size, _ = step
circle_sizes.append(circle_size)

utils.log("Writing results")
if args.time:
date = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M")
else:
date = datetime.date.today().strftime("%Y-%m-%d")
filename = Path("results", "circles", f"{args.wikitree_id}_{date}.json")
with open(filename, "w") as f:
json.dump({args.wikitree_id: circle_sizes}, f)

utils.log(f"Wrote {filename}")

except:
print("Error while processing data")
pprint(data)
raise

main()
5 changes: 3 additions & 2 deletions circles_communities.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def main():
if size >= args.min_community_size]

utils.log("Find distances to all nodes from focus")
comm_circles = {index: collections.Counter()
for index in sorted_comms}
comm_circles : dict[int | str, collections.Counter[int]] = {
index: collections.Counter()
for index in sorted_comms}
comm_circles["all"] = collections.Counter()

focus_index = names_db.name2index(args.focus)
Expand Down
22 changes: 14 additions & 8 deletions circles_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,47 @@
"""

import argparse
from collections.abc import Collection
from pathlib import Path

import networkx as nx

import circles_tools
import data_reader
from data_reader import UserNum
import graph_tools
import utils


def make_bipartite(db, people):
NodeType = str
EdgeType = tuple[NodeType, NodeType]

def make_bipartite(db : data_reader.Database, people : Collection[UserNum]
) -> tuple[list[NodeType], list[EdgeType]]:
# Like graph_make_bipartite ... but only for a subset of people.
people_nodes = frozenset(people)
union_nodes = set()
edges = set()
people_nodes : frozenset[NodeType] = frozenset(str(p) for p in people)
union_nodes : set[NodeType] = set()
edges : set[EdgeType] = set()

for person in people:
# Add union node for parents if they are known.
parents = db.parents_of(person)
if parents:
union = "Union/" + "/".join(str(p) for p in sorted(parents))
union_nodes.add(union)
edges.add((person, union))
edges.add((str(person), union))
# Make sure parents are also connected to the union.
for parent in parents:
if parent in people_nodes:
edges.add((parent, union))
edges.add((str(parent), union))

# Add union node for all "partners" (spouses / coparents).
for partner in db.partners_of(person):
if partner in people_nodes:
union = "Union/" + "/".join(str(p) for p in sorted([person, partner]))
union_nodes.add(union)
edges.add((person, union))
edges.add((partner, union))
edges.add((str(person), union))
edges.add((str(partner), union))
# Note: We don't explicitly connect all children here.
# If they are in `people`, they will be connected above.

Expand Down
Loading

0 comments on commit c04014a

Please sign in to comment.