-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathcopy_state_db.py
135 lines (109 loc) · 4.19 KB
/
copy_state_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from bson import Timestamp
import itertools
import json
import sqlite3
import time
# DDL for the copy-state table; {table_name} is substituted with str.format()
# (CopyStateDB.drop_and_create passes CopyStateDB.STATE_TABLE).
# The composite primary key allows exactly one row per (source, dest) pair.
# updated_at holds a time.time() value; oplog_ts stays NULL until a timestamp
# is recorded, after which it holds a JSON object with 'time' and 'inc' keys.
CREATE_STATE_SQL = (
"""
CREATE TABLE {table_name}
(
source TEXT NOT NULL,
dest TEXT NOT NULL,
updated_at REAL NOT NULL,
state TEXT NOT NULL,
oplog_ts TEXT DEFAULT NULL,
PRIMARY KEY(source, dest)
)
""")
def _mongo_dict_to_str(d):
if 'id_source' in d:
return d['id_source']['shard_name']
return "%s:%d/%s/%s" % (d['host'], d['port'], d['db'], d['collection'])
def _results_as_dicts(cursor):
"""
given a sqlite cursor, yields results as a dictionary mapping column names to
column values
probably slightly overengineered
"""
results = []
col_names = [d[0] for d in cursor.description]
while True:
rows = cursor.fetchmany()
if not rows:
break
for row in rows:
results.append(dict(itertools.izip(col_names, row)))
return results
class CopyStateDB(object):
    """
    Contains the state of a collection copy in a sqlite3 database, for ease of
    use in other code.

    Each (source, dest) pair has one row recording which phase of the copy it
    is in, when the row was last changed, and (once recorded) the last oplog
    timestamp applied.

    A separate state file should be used for each sharded collection being
    copied, to avoid deleting state should copy_collection.py be run with
    --restart; if that's not a concern, share away!
    """

    STATE_TABLE = 'state'

    # values stored in the table's 'state' column
    STATE_INITIAL_COPY = 'initial copy'
    STATE_WAITING_FOR_INDICES = 'waiting for indices'
    STATE_APPLYING_OPLOG = 'applying oplog'

    def __init__(self, path):
        """
        Open a sqlite3 connection to the database file at `path`.
        """
        self._conn = sqlite3.connect(path)
        self._path = path

    def _pair_key(self, source, dest):
        """Return the (source_str, dest_str) primary-key strings for a pair."""
        return _mongo_dict_to_str(source), _mongo_dict_to_str(dest)

    def drop_and_create(self):
        """
        Drop the state table if it exists and recreate it empty, discarding
        all recorded state.
        """
        with self._conn:
            cursor = self._conn.cursor()
            cursor.execute("DROP TABLE IF EXISTS %s" % self.STATE_TABLE)
            cursor.execute(CREATE_STATE_SQL.format(table_name=self.STATE_TABLE))

    def add_source_and_dest(self, source, dest):
        """
        Add a state entry for the given source and destination in the
        initial-copy state, doing nothing if an entry already exists.

        source and dest are dicts with 'host', 'port', 'db' and 'collection'
        keys, or with an 'id_source' dict carrying a 'shard_name'.
        """
        source_str, dest_str = self._pair_key(source, dest)
        query = ("INSERT OR IGNORE INTO %s "
                 "(source, dest, updated_at, state, oplog_ts) "
                 "VALUES (?, ?, ?, ?, ?)" % self.STATE_TABLE)
        with self._conn:
            cursor = self._conn.cursor()
            cursor.execute(query, (source_str, dest_str, time.time(),
                                   self.STATE_INITIAL_COPY, None))

    def select_by_state(self, state):
        """
        Return every row whose state column equals `state`, as a list of dicts
        keyed by column name.
        """
        cursor = self._conn.cursor()
        query = "SELECT * FROM %s WHERE state=?" % self.STATE_TABLE
        cursor.execute(query, (state,))
        return _results_as_dicts(cursor)

    def update_oplog_ts(self, source, dest, oplog_ts):
        """
        Record where we are in applying oplog entries for (source, dest), and
        refresh the row's updated_at.

        Raises TypeError if oplog_ts is not a bson Timestamp.
        """
        # explicit check rather than assert, which is stripped under python -O
        if not isinstance(oplog_ts, Timestamp):
            raise TypeError("oplog_ts must be a bson Timestamp, got %r"
                            % (oplog_ts,))
        source_str, dest_str = self._pair_key(source, dest)
        oplog_ts_json = json.dumps({'time': oplog_ts.time, 'inc': oplog_ts.inc})
        query = ("UPDATE %s SET oplog_ts = ?, updated_at = ? "
                 "WHERE source = ? AND dest = ?" % self.STATE_TABLE)
        with self._conn:
            cursor = self._conn.cursor()
            cursor.execute(query,
                           (oplog_ts_json, time.time(), source_str, dest_str))

    def update_state(self, source, dest, state):
        """
        Set the state column for (source, dest) and refresh its updated_at.
        """
        source_str, dest_str = self._pair_key(source, dest)
        query = ("UPDATE %s SET state = ?, updated_at = ? "
                 "WHERE source = ? AND dest = ?" % self.STATE_TABLE)
        with self._conn:
            cursor = self._conn.cursor()
            cursor.execute(query, (state, time.time(), source_str, dest_str))

    def get_oplog_ts(self, source, dest):
        """
        Return the last recorded oplog timestamp for (source, dest) as a bson
        Timestamp.

        Returns None if there is no entry for the pair, or if no timestamp has
        been recorded yet (oplog_ts still holds its NULL default).
        """
        source_str, dest_str = self._pair_key(source, dest)
        query = ("SELECT oplog_ts FROM %s "
                 "WHERE source = ? AND dest = ?" % self.STATE_TABLE)
        with self._conn:
            cursor = self._conn.cursor()
            cursor.execute(query, (source_str, dest_str))
            row = cursor.fetchone()
        # previously these cases crashed with an opaque TypeError
        if row is None or row[0] is None:
            return None
        result = json.loads(row[0])
        return Timestamp(time=result['time'], inc=result['inc'])