etl: allow running individual tasks to update single tables
This commit adds a new command line argument (--task) which can be
specified multiple times and/or given a comma-separated list of
task names.

Each task name is simply the name of the output table that task
generates (e.g. symptom for the symptom Athena table, or patient,
condition, etc).

To support this, ID anonymization must stay consistent between
separate task runs, because the tables cross-reference each other's IDs.
So we now keep the hashing salt in the codebook.
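
As a quick illustration of the flag's behavior, here is a standalone sketch of the same
argparse pattern (a repeatable flag whose values may themselves be comma-separated). This
is not the exact etl.py code, just the parsing idea:

import argparse
import itertools

parser = argparse.ArgumentParser()
# Repeatable flag; each occurrence may itself hold a comma-separated list
parser.add_argument('--task', action='append')
args = parser.parse_args(['--task', 'patient,condition', '--task', 'symptom'])

# Flatten ['patient,condition', 'symptom'] into ['patient', 'condition', 'symptom']
tasks = args.task and list(itertools.chain.from_iterable(t.split(',') for t in args.task))
print(tasks)  # ['patient', 'condition', 'symptom']
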
mikix committed Dec 19, 2022
1 parent 3b622ef commit 33f00b4
Showing 13 changed files with 190 additions and 102 deletions.
4 changes: 4 additions & 0 deletions cumulus/config.py
@@ -3,6 +3,7 @@
import datetime
import os
from socket import gethostname
from typing import List

from cumulus import common, loaders, store

@@ -19,6 +20,7 @@ def __init__(
timestamp: datetime.datetime = None,
comment: str = None,
batch_size: int = 1, # this default is never really used - overridden by command line args
tasks: List[str] = None,
):
"""
:param loader: describes how input files were loaded (e.g. i2b2 or ndjson)
@@ -34,6 +36,7 @@ def __init__(
self.hostname = gethostname()
self.comment = comment or ''
self.batch_size = batch_size
self.tasks = tasks or []

def path_codebook(self) -> str:
return self.dir_phi.joinpath('codebook.json')
@@ -57,6 +60,7 @@ def as_json(self):
'output_format': type(self.format).__name__,
'comment': self.comment,
'batch_size': self.batch_size,
'tasks': ','.join(self.tasks),
}


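
For illustration, a tiny stand-in for JobConfig showing how the selected tasks end up
comma-joined in the saved job configuration (the class below is simplified and hypothetical;
only the field names mirror the diff):

class MiniJobConfig:
    """Simplified stand-in for cumulus.config.JobConfig, just to show the tasks field."""

    def __init__(self, comment: str = '', batch_size: int = 1, tasks=None):
        self.comment = comment
        self.batch_size = batch_size
        self.tasks = tasks or []

    def as_json(self):
        return {
            'comment': self.comment,
            'batch_size': self.batch_size,
            'tasks': ','.join(self.tasks),
        }

config = MiniJobConfig(tasks=['patient', 'condition'])
print(config.as_json()['tasks'])  # prints: patient,condition
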
38 changes: 28 additions & 10 deletions cumulus/deid/codebook.py
@@ -1,5 +1,6 @@
"""Codebook that stores the mappings between real and fake IDs"""

import binascii
import hmac
import logging
import secrets
@@ -25,14 +26,6 @@ def __init__(self, saved: str = None):
except (FileNotFoundError, PermissionError):
self.db = CodebookDB()

# Create a salt, used when hashing resource IDs.
# Some prior art is Microsoft's anonymizer tool which uses a UUID4 salt (with 122 bits of entropy).
# Since this is an important salt, it seems reasonable to do a bit more.
# Python's docs for the secrets module recommend 256 bits, as of 2015.
# The sha256 algorithm is sitting on top of this salt, and a key size equal to the output size is also
# recommended, so 256 bits seem good (which is 32 bytes).
self.salt = secrets.token_bytes(32)

def fake_id(self, resource_type: str, real_id: str) -> str:
"""
Returns a new fake ID in place of the provided real ID
@@ -55,8 +48,7 @@ def fake_id(self, resource_type: str, real_id: str) -> str:
elif resource_type == 'Encounter':
return self.db.encounter(real_id)
else:
# This will be exactly 64 characters long, the maximum FHIR id length
return hmac.new(self.salt, digestmod='sha256', msg=real_id.encode('utf8')).hexdigest()
return self.db.resource_hash(real_id)


###############################################################################
@@ -126,6 +118,32 @@ def _fake_id(self, resource_type: str, real_id: str) -> str:

return fake_id

def resource_hash(self, real_id: str) -> str:
"""
Get a fake ID for an arbitrary FHIR resource ID
:param real_id: resource ID
:return: hashed ID, using the saved salt
"""
# This will be exactly 64 characters long, the maximum FHIR id length
return hmac.new(self._id_salt(), digestmod='sha256', msg=real_id.encode('utf8')).hexdigest()

def _id_salt(self) -> bytes:
"""Returns the saved salt or creates and saves one if needed"""
salt = self.mapping.get('id_salt')

if salt is None:
# Create a salt, used when hashing resource IDs.
# Some prior art is Microsoft's anonymizer tool which uses a UUID4 salt (with 122 bits of entropy).
# Since this is an important salt, it seems reasonable to do a bit more.
# Python's docs for the secrets module recommend 256 bits, as of 2015.
# The sha256 algorithm is sitting on top of this salt, and a key size equal to the output size is also
# recommended, so 256 bits seem good (which is 32 bytes).
salt = secrets.token_hex(32)
self.mapping['id_salt'] = salt

return binascii.unhexlify(salt)  # convert the 64-char hex string representation back to the raw 32 bytes

def _load_saved(self, saved: dict) -> None:
"""
:param saved: dictionary containing structure
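
To see why persisting the salt matters, here is a self-contained sketch (not the cumulus
API itself) showing that two separate runs which reload the same codebook file produce the
same hashed ID, so cross-references between tables keep lining up:

import binascii
import hmac
import json
import secrets
import tempfile

def resource_hash(salt_hex: str, real_id: str) -> str:
    # Mirrors CodebookDB.resource_hash: HMAC-SHA256 keyed with the saved salt,
    # yielding a 64-char hex string (the maximum FHIR id length)
    return hmac.new(binascii.unhexlify(salt_hex), digestmod='sha256', msg=real_id.encode('utf8')).hexdigest()

with tempfile.NamedTemporaryFile(mode='w+', suffix='.json') as codebook:
    # First run: no salt yet, so create one and persist its hex form in the codebook
    json.dump({'id_salt': secrets.token_hex(32)}, codebook)
    codebook.flush()

    # Two later "runs" reload the same codebook and therefore agree on the fake ID
    codebook.seek(0)
    first = resource_hash(json.load(codebook)['id_salt'], 'DocumentReference/123')
    codebook.seek(0)
    second = resource_hash(json.load(codebook)['id_salt'], 'DocumentReference/123')
    assert first == second
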
1 change: 1 addition & 0 deletions cumulus/errors.py
@@ -9,3 +9,4 @@
CTAKES_MISSING = 14
SMART_CREDENTIALS_MISSING = 15
BULK_EXPORT_FAILED = 16
TASK_UNKNOWN = 17
69 changes: 49 additions & 20 deletions cumulus/etl.py
@@ -1,6 +1,7 @@
"""Load, transform, and write out input data to deidentified FHIR"""

import argparse
import collections
import itertools
import json
import logging
@@ -229,40 +230,63 @@ def etl_notes_text2fhir_symptoms(config: JobConfig, scrubber: deid.Scrubber) ->
#
###############################################################################

def load_and_deidentify(loader: loaders.Loader) -> tempfile.TemporaryDirectory:
TaskDescription = collections.namedtuple(
'TaskDescription',
['func', 'resources'],
)


def get_task_descriptions(tasks: List[str] = None) -> Iterable[TaskDescription]:
# Tasks are named after the table they generate on the output side
available_tasks = {
'patient': TaskDescription(etl_patient, ['Patient']),
'encounter': TaskDescription(etl_encounter, ['Encounter']),
'observation': TaskDescription(etl_lab, ['Observation']),
'documentreference': TaskDescription(etl_notes_meta, ['DocumentReference']),
'symptom': TaskDescription(etl_notes_text2fhir_symptoms, ['DocumentReference']),
'condition': TaskDescription(etl_condition, ['Condition']),
}

if tasks is None:
return available_tasks.values()

try:
return [available_tasks[k] for k in tasks]
except KeyError as exc:
names = '\n'.join(sorted(f' {key}' for key in available_tasks))
print(f'Unknown task name given. Valid task names:\n{names}', file=sys.stderr)
raise SystemExit(errors.TASK_UNKNOWN) from exc


def load_and_deidentify(loader: loaders.Loader, tasks: List[str] = None) -> tempfile.TemporaryDirectory:
"""
Loads the input directory and does a first-pass de-identification
Code outside this method should never see the original input files.
:returns: a temporary directory holding the de-identified files in FHIR ndjson format
"""
# Grab a list of all required resource types for the tasks we are running
required_resources = set(itertools.chain.from_iterable(t.resources for t in get_task_descriptions(tasks)))

# First step is loading all the data into a local ndjson format
loaded_dir = loader.load_all()
loaded_dir = loader.load_all(list(required_resources))

# Second step is de-identifying that data (at a bulk level)
return deid.Scrubber.scrub_bulk_data(loaded_dir.name)


def etl_job(config: JobConfig) -> List[JobSummary]:
def etl_job(config: JobConfig, tasks: Iterable[str] = None) -> List[JobSummary]:
"""
:param config:
:return:
:param config: job config
:param tasks: if specified, only the listed tasks are run
:return: a list of job summaries
"""
summary_list = []

task_list = [
etl_patient,
etl_encounter,
etl_lab,
etl_notes_meta,
etl_notes_text2fhir_symptoms,
etl_condition,
]

scrubber = deid.Scrubber(config.path_codebook())
for task in task_list:
summary = task(config, scrubber)
for task_desc in get_task_descriptions(tasks):
summary = task_desc.func(config, scrubber)
summary_list.append(summary)

scrubber.save()
@@ -351,6 +375,7 @@ def main(args: List[str]):
parser.add_argument('--smart-client-id', metavar='CLIENT_ID', help='Client ID registered with SMART FHIR server '
'(can be a filename with ID inside it')
parser.add_argument('--smart-jwks', metavar='/path/to/jwks', help='JWKS file registered with SMART FHIR server')
parser.add_argument('--task', action='append', help='Only update the given output tables (comma separated)')
parser.add_argument('--skip-init-checks', action='store_true', help=argparse.SUPPRESS)
args = parser.parse_args(args)

@@ -383,17 +408,21 @@ def main(args: List[str]):
else:
config_store = formats.NdjsonFormat(root_output)

deid_dir = load_and_deidentify(config_loader)
# Check which tasks are being run, allowing comma-separated values
tasks = args.task and list(itertools.chain.from_iterable(t.split(',') for t in args.task))

# Pull down resources and run the MS tool on them
deid_dir = load_and_deidentify(config_loader, tasks=tasks)

# Prepare config for jobs
config = JobConfig(config_loader, deid_dir.name, config_store, root_phi, comment=args.comment,
batch_size=args.batch_size, timestamp=job_datetime)
batch_size=args.batch_size, timestamp=job_datetime, tasks=tasks)
common.write_json(config.path_config(), config.as_json(), indent=4)
common.print_header('Configuration:')
print(json.dumps(config.as_json(), indent=4))

# Finally, actually run the meat of the pipeline!
summaries = etl_job(config)
# Finally, actually run the meat of the pipeline! (Filtered down to requested tasks)
summaries = etl_job(config, tasks=tasks)

# Print results to the console
common.print_header('Results:')
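
A self-contained sketch of the task-selection pattern above: mapping task names to
descriptions, rejecting unknown names, and deriving the set of FHIR resources the loader
actually needs (simplified; the etl_* functions are replaced by placeholders):

import collections
import itertools
from typing import Iterable, List

TaskDescription = collections.namedtuple('TaskDescription', ['func', 'resources'])

# Placeholder task functions stand in for etl_patient, etl_notes_text2fhir_symptoms, etc.
AVAILABLE_TASKS = {
    'patient': TaskDescription(lambda: 'patient table', ['Patient']),
    'symptom': TaskDescription(lambda: 'symptom table', ['DocumentReference']),
    'condition': TaskDescription(lambda: 'condition table', ['Condition']),
}

def get_task_descriptions(tasks: List[str] = None) -> Iterable[TaskDescription]:
    if tasks is None:
        return AVAILABLE_TASKS.values()  # no --task given: run everything
    try:
        return [AVAILABLE_TASKS[name] for name in tasks]
    except KeyError as exc:
        raise SystemExit(f'Unknown task name: {exc}') from exc

# Only the resources needed by the selected tasks get loaded and de-identified
selected = get_task_descriptions(['symptom', 'condition'])
required_resources = set(itertools.chain.from_iterable(t.resources for t in selected))
print(required_resources)  # {'DocumentReference', 'Condition'}
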
6 changes: 4 additions & 2 deletions cumulus/loaders/base.py
@@ -2,6 +2,7 @@

import abc
import tempfile
from typing import List

from cumulus.store import Root

@@ -22,9 +23,10 @@ def __init__(self, root: Root):
self.root = root

@abc.abstractmethod
def load_all(self) -> tempfile.TemporaryDirectory:
def load_all(self, resources: List[str]) -> tempfile.TemporaryDirectory:
"""
Loads all remote resources and places them into a local folder as FHIR ndjson
Loads the listed remote resources and places them into a local folder as FHIR ndjson
:param resources: a list of resources to ingest
:returns: an object holding the name of a local ndjson folder path (e.g. a TemporaryDirectory)
"""
15 changes: 4 additions & 11 deletions cumulus/loaders/fhir/fhir_ndjson.py
@@ -2,6 +2,7 @@

import sys
import tempfile
from typing import List

from cumulus import common, errors, store
from cumulus.loaders import base
@@ -29,10 +30,10 @@ def __init__(self, root: store.Root, client_id: str = None, jwks: str = None):
self.client_id = client_id
self.jwks = common.read_json(jwks) if jwks else None

def load_all(self) -> tempfile.TemporaryDirectory:
def load_all(self, resources: List[str]) -> tempfile.TemporaryDirectory:
# Are we doing a bulk FHIR export from a server?
if self.root.protocol in ['http', 'https']:
return self._load_from_bulk_export()
return self._load_from_bulk_export(resources)

# Are we reading from a local directory?
if self.root.protocol == 'file':
@@ -46,7 +47,7 @@ class Dir:
self.root.get(self.root.joinpath('*.ndjson'), f'{tmpdir.name}/')
return tmpdir

def _load_from_bulk_export(self) -> tempfile.TemporaryDirectory:
def _load_from_bulk_export(self, resources: List[str]) -> tempfile.TemporaryDirectory:
# First, check that the extra arguments we need were provided
error_list = []
if not self.client_id:
@@ -59,14 +60,6 @@ def _load_from_bulk_export(self) -> tempfile.TemporaryDirectory:

tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with

resources = [
'Condition',
'DocumentReference',
'Encounter',
'Observation',
'Patient',
]

try:
server = BackendServiceServer(self.root.path, self.client_id, self.jwks, resources)
bulk_exporter = BulkExporter(server, resources, tmpdir.name)