-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdata_reader.py
70 lines (60 loc) · 2.45 KB
/
data_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Single interface for accessing data via either SQLite or CSV into memory.
"""
import collections
from collections.abc import Mapping, Set
import time
import csv_iterate
import sqlite_reader
from sqlite_reader import UserNum
def load_connections(version: str,
                     include_parents: bool,
                     include_children: bool,
                     include_siblings: bool,
                     include_spouses: bool) -> Mapping[UserNum, Set[UserNum]]:
    """
    Build an in-memory adjacency map between user numbers.

    Walks every record yielded by csv_iterate.iterate_users(version=version)
    and, for each non-empty father/mother number, records the links enabled
    by the include_* flags:

      include_parents  -- person -> parent
      include_children -- parent -> person
      include_siblings -- person <-> every earlier-seen child of that parent

    When include_spouses is set, both directions of every pair returned by
    marriage.user_nums() from csv_iterate.iterate_marriages(version=version)
    are added as well.

    Progress lines with time.process_time() readings are printed to stdout.

    Returns the collections.defaultdict(set) that maps each user number to
    the set of user numbers it is linked to.
    """
    graph = collections.defaultdict(set)
    # parent number -> children encountered so far; used to discover siblings.
    offspring: dict[int, set[int]] = collections.defaultdict(set)

    print("Loading people", time.process_time())
    # Counts every add performed, including re-adds of an existing link, so it
    # is a progress indicator rather than an exact edge count.
    edge_total = 0
    for row_index, user in enumerate(csv_iterate.iterate_users(version=version)):
        child = user.user_num()
        for parent in (user.father_num(), user.mother_num()):
            if not parent:
                # Falsy parent number: nothing to link for this side.
                continue
            if include_parents:
                graph[child].add(parent)
                edge_total += 1
            if include_children:
                graph[parent].add(child)
                edge_total += 1
            if include_siblings:
                known_children = offspring[parent]
                for sibling in known_children:
                    graph[child].add(sibling)
                    graph[sibling].add(child)
                    edge_total += 2
                # Register only after linking, so later children of the same
                # parent will see this one.
                known_children.add(child)
        if row_index % 1000000 == 0:
            print(" ... {:,}".format(row_index), "{:,}".format(edge_total), time.process_time())

    if include_spouses:
        print("Loading marriages", time.process_time())
        couples = (marriage.user_nums()
                   for marriage in csv_iterate.iterate_marriages(version=version))
        for first, second in couples:
            graph[first].add(second)
            graph[second].add(first)

    print("All connections loaded", time.process_time())
    return graph
class Database(sqlite_reader.Database):
    """
    sqlite_reader.Database that can answer neighbor queries from memory.

    While self.connections is empty (the state after construction),
    neighbors_of() defers to the sqlite_reader.Database implementation.
    After load_connections() has filled it, lookups are served from the
    in-memory mapping instead.
    """

    def __init__(self, version: str) -> None:
        """Open the underlying database for `version`, with nothing preloaded."""
        super().__init__(version)
        self.version = version
        # Empty until load_connections() runs; neighbors_of() treats an empty
        # mapping as "not loaded".
        self.connections: Mapping[UserNum, Set[UserNum]] = {}

    def neighbors_of(self, person: UserNum):
        """
        Return the user numbers connected to `person`.

        Served from self.connections when that mapping is non-empty,
        otherwise delegated to the parent class. A person absent from a
        loaded mapping yields an empty set.
        """
        if self.connections:
            # Use .get() rather than indexing: the loaded mapping is a
            # defaultdict, and indexing an unknown person would insert a new
            # empty entry, mutating the shared graph on what should be a
            # read-only lookup.
            found = self.connections.get(person)
            return found if found is not None else set()
        return super().neighbors_of(person)

    def load_connections(self):
        """Load parent, child, sibling and spouse links into self.connections."""
        # Inside a method the bare name resolves to the module-level
        # load_connections() function, not to this method.
        self.connections = load_connections(version=self.version,
                                            include_parents=True,
                                            include_children=True,
                                            include_siblings=True,
                                            include_spouses=True)