mysql_to_s3_operator.py
from airflow.models import BaseOperator
from airflow.hooks.S3_hook import S3Hook
from mysql_plugin.hooks.astro_mysql_hook import AstroMySqlHook
from airflow.utils.decorators import apply_defaults
import json
import logging


class MySQLToS3Operator(BaseOperator):
"""
MySQL to Spreadsheet Operator
NOTE: When using the MySQLToS3Operator, it is necessary to set the cursor
to "dictcursor" in the MySQL connection settings within "Extra"
(e.g.{"cursor":"dictcursor"}). To avoid invalid characters, it is also
recommended to specify the character encoding (e.g {"charset":"utf8"}).
NOTE: Because this operator accesses a single database via concurrent
connections, it is advised that a connection pool be used to control
requests. - https://airflow.incubator.apache.org/concepts.html#pools
:param mysql_conn_id: The input mysql connection id.
:type mysql_conn_id: string
:param mysql_table: The input MySQL table to pull data from.
:type mysql_table: string
:param aws_conn_id: The destination s3 connection id.
:type aws_conn_id: string
:param s3_bucket: The destination s3 bucket.
:type s3_bucket: string
:param s3_key: The destination s3 key.
:type s3_key: string
:param package_schema: *(optional)* Whether or not to pull the
schema information for the table as well as
the data.
:type package_schema: boolean
:param incremental_key: *(optional)* The incrementing key to filter
the source data with. Currently only
accepts a column with type of timestamp.
:type incremental_key: string
:param start: *(optional)* The start date to filter
records with based on the incremental_key.
Only required if using the incremental_key
field.
:type start: timestamp (YYYY-MM-DD HH:MM:SS)
:param end: *(optional)* The end date to filter
records with based on the incremental_key.
Only required if using the incremental_key
field.
:type end: timestamp (YYYY-MM-DD HH:MM:SS)
"""

    template_fields = ['start', 'end', 's3_key']

    @apply_defaults
    def __init__(self,
                 mysql_conn_id,
                 mysql_table,
                 aws_conn_id,
                 s3_bucket,
                 s3_key,
                 package_schema=False,
                 incremental_key=None,
                 start=None,
                 end=None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.mysql_conn_id = mysql_conn_id
        self.mysql_table = mysql_table
        self.aws_conn_id = aws_conn_id
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.package_schema = package_schema
        self.incremental_key = incremental_key
        self.start = start
        self.end = end

    def execute(self, context):
        hook = AstroMySqlHook(self.mysql_conn_id)
        self.get_records(hook)
        if self.package_schema:
            self.get_schema(hook, self.mysql_table)

    def get_schema(self, hook, table):
        logging.info('Initiating schema retrieval.')
        results = list(hook.get_schema(table))
        # Reduce each information_schema row to just the column name and type
        # before uploading the schema package to S3 as JSON.
        output_array = []
        for i in results:
            new_dict = {}
            new_dict['name'] = i['COLUMN_NAME']
            new_dict['type'] = i['COLUMN_TYPE']
            if len(new_dict) == 2:
                output_array.append(new_dict)
        self.s3_upload(json.dumps(output_array), schema=True)

    def get_records(self, hook):
        logging.info('Initiating record retrieval.')
        logging.info('Start Date: {0}'.format(self.start))
        logging.info('End Date: {0}'.format(self.end))
        # Build the WHERE clause from the incremental_key/start/end settings.
        # Default to no filter so query_filter is always defined.
        query_filter = ''
        if all([self.incremental_key, self.start, self.end]):
            query_filter = """ WHERE {0} >= '{1}' AND {0} < '{2}'
                """.format(self.incremental_key, self.start, self.end)
        elif all([self.incremental_key, self.start]):
            query_filter = """ WHERE {0} >= '{1}'
                """.format(self.incremental_key, self.start)
        query = \
            """
            SELECT *
            FROM {0}
            {1}
            """.format(self.mysql_table, query_filter)
        # Perform the query and convert the returned tuple to a list.
        results = list(hook.get_records(query))
        logging.info('Successfully performed query.')
        # Iterate through the list of dictionaries (one dict per row queried)
        # and convert all non-null values, including datetime and date
        # values, to strings so each row can be serialized as JSON.
        results = [dict([k, str(v)] if v is not None else [k, v]
                        for k, v in i.items()) for i in results]
        # Emit one JSON object per line (newline-delimited JSON).
        results = '\n'.join([json.dumps(i) for i in results])
        self.s3_upload(results)
        return results

    def s3_upload(self, results, schema=False):
        s3 = S3Hook(aws_conn_id=self.aws_conn_id)
        key = '{0}'.format(self.s3_key)
        # If the file being uploaded to S3 is a schema, append "_schema" to
        # the end of the file name.
        if schema and key[-5:] == '.json':
            key = key[:-5] + '_schema' + key[-5:]
        if schema and key[-4:] == '.csv':
            key = key[:-4] + '_schema' + key[-4:]
        s3.load_string(
            string_data=results,
            bucket_name=self.s3_bucket,
            key=key,
            replace=True
        )
        s3.connection.close()
        logging.info('File uploaded to s3')
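
A minimal usage sketch, not part of the module above, showing how this operator might be wired into a DAG per the docstring's notes (dictcursor/charset in the connection's "Extra", start/end paired with incremental_key). The DAG id, connection ids, table, bucket, key, and the mysql_plugin.operators import path are assumptions for illustration only.

# Hypothetical example DAG; all ids and names below are placeholders.
from datetime import datetime

from airflow import DAG
from mysql_plugin.operators.mysql_to_s3_operator import MySQLToS3Operator

dag = DAG(
    dag_id='example_mysql_to_s3',
    start_date=datetime(2018, 1, 1),
    schedule_interval='@daily',
)

orders_to_s3 = MySQLToS3Operator(
    task_id='orders_to_s3',
    # MySQL connection assumed to have Extra set to
    # {"cursor": "dictcursor", "charset": "utf8"} as required above.
    mysql_conn_id='mysql_default',
    mysql_table='orders',
    aws_conn_id='aws_default',
    s3_bucket='my-data-lake',
    # s3_key is templated, so one file is written per execution date;
    # package_schema=True also writes a companion "..._schema.json" file.
    s3_key='mysql/orders/{{ ds }}.json',
    package_schema=True,
    # Pull only rows whose timestamp falls within the execution window.
    incremental_key='updated_at',
    start='{{ ds }} 00:00:00',
    end='{{ macros.ds_add(ds, 1) }} 00:00:00',
    dag=dag,
)

Because 'start', 'end', and 's3_key' are template_fields, the Jinja values above are rendered at runtime, which keeps each daily run's extract bounded to its own window and written to its own key.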