forked from populationgenomics/analysis-runner
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
307 lines (245 loc) · 9.37 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
"""The analysis-runner server, running Hail Batch pipelines on users' behalf."""
# pylint: disable=wrong-import-order
import datetime
import json
import logging
from shlex import quote
import hailtop.batch as hb
from aiohttp import web
from cpg_utils.config import update_dict
from cpg_utils.deploy_config import get_server_config
from cpg_utils.git import prepare_git_job
from cpg_utils.hail_batch import remote_tmpdir
from cpg_utils.storage import get_dataset_bucket_url
from cromwell import add_cromwell_routes
from util import (
DRIVER_IMAGE,
_get_hail_version,
get_analysis_runner_metadata,
get_baseline_run_config,
get_email_from_request,
run_batch_job_and_print_url,
validate_dataset_access,
validate_image,
validate_output_dir,
write_config,
)
# Configure root logging before any handlers run so library logs are captured.
logging.basicConfig(level=logging.INFO)
# do it like this so it's easy to disable
USE_GCP_LOGGING = False
if USE_GCP_LOGGING:
    # Imported lazily so the GCP client library is only required when enabled.
    import google.cloud.logging  # pylint: disable=import-error,no-name-in-module,c-extension-no-member

    client = google.cloud.logging.Client()
    client.get_default_handler()
    client.setup_logging()

# Route table that every handler in this module registers against.
routes = web.RouteTableDef()
# Cloud environments this server knows how to drive.
# NOTE(review): the handlers below hard-code 'azure', which is not in this
# set — confirm which environment this fork is meant to target.
SUPPORTED_CLOUD_ENVIRONMENTS = {'gcp'}
# pylint: disable=too-many-statements
@routes.post('/')
async def index(request):
    """Main entry point, responds to the web root.

    Expects a JSON body with at least: repo, commit, cwd, script, dataset,
    output, accessLevel, description. Optional keys: image, cpu, memory,
    preemptible, environmentVariables, config, wait.

    Submits a Hail Batch "driver" job that checks out the requested commit
    and runs the user's script, then returns the batch URL as plain text.

    Raises:
        web.HTTPBadRequest: on invalid access level, image, commit or script.
        KeyError: on missing request parameters (translated to a 400 by
            error_middleware).
    """
    email = get_email_from_request(request)
    # When accessing a missing entry in the params dict, the resulting KeyError
    # exception gets translated to a Bad Request error by error_middleware.
    params = await request.json()

    repo = params['repo']
    output_prefix = validate_output_dir(params['output'])
    dataset = params['dataset']
    image = params.get('image') or DRIVER_IMAGE
    cpu = params.get('cpu', 1)
    memory = params.get('memory', '1G')
    preemptible = params.get('preemptible', True)
    environment_variables = params.get('environmentVariables')
    # NOTE(review): hard-coded to 'azure' even though
    # SUPPORTED_CLOUD_ENVIRONMENTS is {'gcp'} — confirm intent. This also
    # means the 'gcp' branch further below is currently unreachable.
    cloud_environment = 'azure'

    # Checks both dataset membership for the user and repo allow-listing.
    ds_config = validate_dataset_access(dataset, email, repo)

    access_level = params['accessLevel']
    hail_token = ds_config.get(f'{access_level}Token')
    if not hail_token:
        raise web.HTTPBadRequest(reason=f'Invalid access level "{access_level}"')

    # Test-level runs may use arbitrary images; other levels must pass
    # image validation.
    is_test = access_level == 'test'
    if not is_test and not validate_image(image):
        raise web.HTTPBadRequest(reason=f'Invalid image "{image}"')

    hail_bucket = get_dataset_bucket_url(dataset, 'hail')
    backend = hb.ServiceBackend(
        billing_project=dataset,
        remote_tmpdir=remote_tmpdir(hail_bucket),
        token=hail_token,
    )

    # Require an exact commit so runs are reproducible; 'HEAD' is ambiguous.
    commit = params['commit']
    if not commit or commit == 'HEAD':
        raise web.HTTPBadRequest(reason='Invalid commit parameter')

    cwd = params['cwd']
    script = params['script']
    if not script:
        raise web.HTTPBadRequest(reason='Invalid script parameter')
    if not isinstance(script, list):
        raise web.HTTPBadRequest(reason='Script parameter expects an array')

    # This metadata dictionary gets stored in the metadata bucket, at the output_dir location.
    hail_version = await _get_hail_version()
    timestamp = datetime.datetime.now().astimezone().isoformat()

    # Prepare the job's configuration and write it to a blob.
    run_config = get_baseline_run_config(
        environment=cloud_environment,
        # NOTE(review): 'projectId' looks like a placeholder, not a real GCP
        # project id — confirm the intended value.
        gcp_project_id='projectId',
        dataset=dataset,
        access_level=access_level,
        output_prefix=output_prefix,
        driver=image,
    )
    if user_config := params.get('config'):  # Update with user-specified configs.
        update_dict(run_config, user_config)
    config_path = write_config(run_config, environment=cloud_environment)

    metadata = get_analysis_runner_metadata(
        timestamp=timestamp,
        dataset=dataset,
        user=email,
        access_level=access_level,
        repo=repo,
        commit=commit,
        script=' '.join(script),
        description=params['description'],
        output_prefix=output_prefix,
        hailVersion=hail_version,
        driver_image=image,
        config_path=config_path,
        cwd=cwd,
        environment=cloud_environment,
    )

    user_name = email.split('@')[0]
    batch_name = f'{user_name} {repo}:{commit}/{" ".join(script)}'

    extra_batch_params = {}
    if cloud_environment == 'gcp':
        # BUGFIX: the original referenced the undefined name
        # `environment_config` here, which would raise NameError whenever this
        # branch ran. The dataset config is assumed to carry the GCP project
        # id — TODO confirm against get_server_config(), which is imported but
        # currently unused.
        extra_batch_params['requester_pays_project'] = ds_config['projectId']

    batch = hb.Batch(backend=backend, name=batch_name, **extra_batch_params)
    job = batch.new_job(name='driver')
    prepare_git_job(
        job=job,
        organisation=repo.split('/')[0],
        repo_name=repo.split('/')[1],
        commit=commit,
        is_test=is_test,
    )
    job.image(image)
    if cpu:
        job.cpu(cpu)
    if memory:
        job.memory(memory)
    job._preemptible = preemptible  # pylint: disable=protected-access

    # NOTE: Prefer using config variables instead of environment variables.
    # In case you need to add an environment variable here, make sure to update the
    # cpg_utils.hail_batch.copy_common_env function!
    job.env('CPG_CONFIG_PATH', config_path)

    if environment_variables:
        if not isinstance(environment_variables, dict):
            raise ValueError('Expected environment_variables to be dictionary')

        invalid_env_vars = [
            f'{k}={v}'
            for k, v in environment_variables.items()
            if not isinstance(v, str)
        ]
        if len(invalid_env_vars) > 0:
            raise ValueError(
                'Some environment_variables values were not strings, got '
                + ', '.join(invalid_env_vars)
            )

        for k, v in environment_variables.items():
            job.env(k, v)

    if cwd:
        job.command(f'cd {quote(cwd)}')
    # If the script isn't on PATH, make the checked-out copy executable.
    job.command(f'which {quote(script[0])} || chmod +x {quote(script[0])}')

    # Finally, run the script.
    escaped_script = ' '.join(quote(s) for s in script if s)
    job.command(escaped_script)

    url = run_batch_job_and_print_url(
        batch, wait=params.get('wait', False), environment=cloud_environment
    )

    # Publish the metadata to Pub/Sub.
    metadata['batch_url'] = url
    # TODO GRS publisher.publish(PUBSUB_TOPIC, json.dumps(metadata).encode('utf-8')).result()

    return web.Response(text=f'{url}\n')
@routes.get('/config')
async def config(request):
    """Build the baseline CPG run config and return it as a JSON response."""
    email = get_email_from_request(request)
    # A missing key in the params dict raises KeyError, which the error
    # middleware translates into a Bad Request response.
    params = await request.json()

    output_prefix = validate_output_dir(params['output'])
    dataset = params['dataset']
    # Called for its access checks; the returned dataset config isn't needed here.
    validate_dataset_access(dataset, email, 'repo')
    image = params.get('image') or DRIVER_IMAGE
    access_level = params['accessLevel']

    # Only test-level runs may use arbitrary images.
    if access_level != 'test' and not validate_image(image):
        raise web.HTTPBadRequest(reason=f'Invalid image "{image}"')

    # Assemble the baseline configuration for this dataset/access level.
    run_config = get_baseline_run_config(
        environment='azure',
        gcp_project_id='projectId',
        dataset=dataset,
        access_level=access_level,
        output_prefix=output_prefix,
        driver=image,
    )
    user_config = params.get('config')
    if user_config:
        # Overlay any user-specified settings on top of the baseline.
        update_dict(run_config, user_config)

    return web.Response(
        status=200,
        body=json.dumps(run_config).encode('utf-8'),
        content_type='application/json',
    )
# Register the Cromwell workflow-submission endpoints on the same route table.
add_cromwell_routes(routes)
def prepare_exception_json_response(status_code: int, message: str) -> web.Response:
    """Build a JSON error response with the given HTTP status and message."""
    payload = {'message': message, 'success': False}
    return web.Response(
        status=status_code,
        body=json.dumps(payload).encode('utf-8'),
        content_type='application/json',
    )
def prepare_response_from_exception(ex: Exception):
    """Map an exception to the appropriate JSON error response.

    HTTP exceptions keep their status; KeyError / ValueError become 400s;
    anything else becomes a 500.
    """
    logging.error(f'Request failed with exception: {repr(ex)}')

    if isinstance(ex, web.HTTPException):
        return prepare_exception_json_response(
            status_code=ex.status_code, message=ex.reason
        )

    if isinstance(ex, KeyError):
        # Missing request parameters surface as KeyError from params[...].
        keys = ', '.join(ex.args)
        return prepare_exception_json_response(
            400, f'Missing request parameter: {keys}'
        )

    if isinstance(ex, ValueError):
        return prepare_exception_json_response(400, ', '.join(ex.args))

    # Fall back to a 500 with whatever message the exception carries.
    m = ex.message if hasattr(ex, 'message') else str(ex)
    return prepare_exception_json_response(500, message=m)
async def error_middleware(_, handler):
    """
    Constructs middleware handler
    First argument is app, but unused in this context
    """

    async def middleware_handler(request):
        """
        Run handler and catch exceptions and response errors
        """
        try:
            response = await handler(request)
            # aiohttp handlers may return an HTTPException instead of raising
            # it; normalize both paths to the same JSON error shape.
            if isinstance(response, web.HTTPException):
                return prepare_response_from_exception(response)
            return response
        # pylint: disable=broad-except
        except Exception as e:
            # Broad catch is deliberate: this is the top-level boundary that
            # turns any failure into a JSON error response.
            return prepare_response_from_exception(e)

    return middleware_handler
async def init_func():
    """Initializes the app.

    Returns an aiohttp Application with the error middleware installed and
    all routes (including the Cromwell routes) registered.
    """
    # BUGFIX: the original immediately re-assigned `app = web.Application()`
    # after this line, replacing the configured app and silently dropping
    # error_middleware. Keep only the app that carries the middleware.
    app = web.Application(middlewares=[error_middleware])
    app.add_routes(routes)
    return app
if __name__ == '__main__':
    # run_app accepts a coroutine and drives it on its own event loop.
    web.run_app(init_func())