Merge pull request #100 from smart-on-fhir/mikix/separate-tasks
etl: allow running individual tasks to update single tables
mikix authored Dec 19, 2022
2 parents 3b622ef + 33f00b4 commit 05a2472
Showing 13 changed files with 190 additions and 102 deletions.
4 changes: 4 additions & 0 deletions cumulus/config.py
@@ -3,6 +3,7 @@
import datetime
import os
from socket import gethostname
from typing import List

from cumulus import common, loaders, store

@@ -19,6 +20,7 @@ def __init__(
timestamp: datetime.datetime = None,
comment: str = None,
batch_size: int = 1, # this default is never really used - overridden by command line args
tasks: List[str] = None,
):
"""
:param loader: describes how input files were loaded (e.g. i2b2 or ndjson)
@@ -34,6 +36,7 @@ def __init__(
self.hostname = gethostname()
self.comment = comment or ''
self.batch_size = batch_size
self.tasks = tasks or []

def path_codebook(self) -> str:
return self.dir_phi.joinpath('codebook.json')
@@ -57,6 +60,7 @@ def as_json(self):
'output_format': type(self.format).__name__,
'comment': self.comment,
'batch_size': self.batch_size,
'tasks': ','.join(self.tasks),
}


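As a minimal sketch of what the new config field looks like (standalone code, not the real JobConfig class), the selected task names are simply joined with commas in the job config JSON, or left as an empty string when no tasks were requested:

    # Mirrors the ','.join(self.tasks) line above; JobConfig itself is not used here.
    tasks = ['condition', 'patient']
    print({'tasks': ','.join(tasks or [])})   # {'tasks': 'condition,patient'}
    print({'tasks': ','.join(None or [])})    # {'tasks': ''}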
38 changes: 28 additions & 10 deletions cumulus/deid/codebook.py
@@ -1,5 +1,6 @@
"""Codebook that stores the mappings between real and fake IDs"""

import binascii
import hmac
import logging
import secrets
@@ -25,14 +26,6 @@ def __init__(self, saved: str = None):
except (FileNotFoundError, PermissionError):
self.db = CodebookDB()

# Create a salt, used when hashing resource IDs.
# Some prior art is Microsoft's anonymizer tool which uses a UUID4 salt (with 122 bits of entropy).
# Since this is an important salt, it seems reasonable to do a bit more.
# Python's docs for the secrets module recommend 256 bits, as of 2015.
# The sha256 algorithm is sitting on top of this salt, and a key size equal to the output size is also
# recommended, so 256 bits seem good (which is 32 bytes).
self.salt = secrets.token_bytes(32)

def fake_id(self, resource_type: str, real_id: str) -> str:
"""
Returns a new fake ID in place of the provided real ID
@@ -55,8 +48,7 @@ def fake_id(self, resource_type: str, real_id: str) -> str:
elif resource_type == 'Encounter':
return self.db.encounter(real_id)
else:
# This will be exactly 64 characters long, the maximum FHIR id length
return hmac.new(self.salt, digestmod='sha256', msg=real_id.encode('utf8')).hexdigest()
return self.db.resource_hash(real_id)


###############################################################################
@@ -126,6 +118,32 @@ def _fake_id(self, resource_type: str, real_id: str) -> str:

return fake_id

def resource_hash(self, real_id: str) -> str:
"""
Get a fake ID for an arbitrary FHIR resource ID
:param real_id: resource ID
:return: hashed ID, using the saved salt
"""
# This will be exactly 64 characters long, the maximum FHIR id length
return hmac.new(self._id_salt(), digestmod='sha256', msg=real_id.encode('utf8')).hexdigest()

def _id_salt(self) -> bytes:
"""Returns the saved salt or creates and saves one if needed"""
salt = self.mapping.get('id_salt')

if salt is None:
# Create a salt, used when hashing resource IDs.
# Some prior art is Microsoft's anonymizer tool which uses a UUID4 salt (with 122 bits of entropy).
# Since this is an important salt, it seems reasonable to do a bit more.
# Python's docs for the secrets module recommend 256 bits, as of 2015.
# The sha256 algorithm is sitting on top of this salt, and a key size equal to the output size is also
# recommended, so 256 bits seem good (which is 32 bytes).
salt = secrets.token_hex(32)
self.mapping['id_salt'] = salt

return binascii.unhexlify(salt) # revert from doubled hex 64-char string representation back to just 32 bytes

def _load_saved(self, saved: dict) -> None:
"""
:param saved: dictionary containing structure
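A rough standalone sketch of the relocated hashing logic (standard library only, not the real CodebookDB class): the salt is persisted as a 64-character hex string and unhexlified back to 32 bytes before being used as the HMAC-SHA256 key.

    import binascii
    import hmac
    import secrets

    mapping = {}  # stand-in for the codebook's persisted mapping dict

    def id_salt() -> bytes:
        """Return the saved salt, creating and saving a new 256-bit one if needed."""
        salt = mapping.get('id_salt')
        if salt is None:
            salt = secrets.token_hex(32)  # 32 random bytes, stored as 64 hex characters
            mapping['id_salt'] = salt
        return binascii.unhexlify(salt)

    def resource_hash(real_id: str) -> str:
        """HMAC-SHA256 of the real ID, keyed with the salt -- always 64 hex characters."""
        return hmac.new(id_salt(), digestmod='sha256', msg=real_id.encode('utf8')).hexdigest()

    print(len(resource_hash('example-id')))  # 64, the maximum FHIR id length

Because the salt now lives in the codebook mapping rather than being regenerated per run, the hashed IDs stay stable across ETL runs against the same PHI directory.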
1 change: 1 addition & 0 deletions cumulus/errors.py
@@ -9,3 +9,4 @@
CTAKES_MISSING = 14
SMART_CREDENTIALS_MISSING = 15
BULK_EXPORT_FAILED = 16
TASK_UNKNOWN = 17
69 changes: 49 additions & 20 deletions cumulus/etl.py
@@ -1,6 +1,7 @@
"""Load, transform, and write out input data to deidentified FHIR"""

import argparse
import collections
import itertools
import json
import logging
@@ -229,40 +230,63 @@ def etl_notes_text2fhir_symptoms(config: JobConfig, scrubber: deid.Scrubber) ->
#
###############################################################################

def load_and_deidentify(loader: loaders.Loader) -> tempfile.TemporaryDirectory:
TaskDescription = collections.namedtuple(
'TaskDescription',
['func', 'resources'],
)


def get_task_descriptions(tasks: List[str] = None) -> Iterable[TaskDescription]:
# Tasks are named after the table they generate on the output side
available_tasks = {
'patient': TaskDescription(etl_patient, ['Patient']),
'encounter': TaskDescription(etl_encounter, ['Encounter']),
'observation': TaskDescription(etl_lab, ['Observation']),
'documentreference': TaskDescription(etl_notes_meta, ['DocumentReference']),
'symptom': TaskDescription(etl_notes_text2fhir_symptoms, ['DocumentReference']),
'condition': TaskDescription(etl_condition, ['Condition']),
}

if tasks is None:
return available_tasks.values()

try:
return [available_tasks[k] for k in tasks]
except KeyError as exc:
names = '\n'.join(sorted(f' {key}' for key in available_tasks))
print(f'Unknown task name given. Valid task names:\n{names}', file=sys.stderr)
raise SystemExit(errors.TASK_UNKNOWN) from exc


def load_and_deidentify(loader: loaders.Loader, tasks: List[str] = None) -> tempfile.TemporaryDirectory:
"""
Loads the input directory and does a first-pass de-identification
Code outside this method should never see the original input files.
:returns: a temporary directory holding the de-identified files in FHIR ndjson format
"""
# Grab a list of all required resource types for the tasks we are running
required_resources = set(itertools.chain.from_iterable(t.resources for t in get_task_descriptions(tasks)))

# First step is loading all the data into a local ndjson format
loaded_dir = loader.load_all()
loaded_dir = loader.load_all(list(required_resources))

# Second step is de-identifying that data (at a bulk level)
return deid.Scrubber.scrub_bulk_data(loaded_dir.name)


def etl_job(config: JobConfig) -> List[JobSummary]:
def etl_job(config: JobConfig, tasks: Iterable[str] = None) -> List[JobSummary]:
"""
:param config:
:return:
:param config: job config
:param tasks: if specified, only the listed tasks are run
:return: a list of job summaries
"""
summary_list = []

task_list = [
etl_patient,
etl_encounter,
etl_lab,
etl_notes_meta,
etl_notes_text2fhir_symptoms,
etl_condition,
]

scrubber = deid.Scrubber(config.path_codebook())
for task in task_list:
summary = task(config, scrubber)
for task_desc in get_task_descriptions(tasks):
summary = task_desc.func(config, scrubber)
summary_list.append(summary)

scrubber.save()
@@ -351,6 +375,7 @@ def main(args: List[str]):
parser.add_argument('--smart-client-id', metavar='CLIENT_ID', help='Client ID registered with SMART FHIR server '
'(can be a filename with ID inside it)')
parser.add_argument('--smart-jwks', metavar='/path/to/jwks', help='JWKS file registered with SMART FHIR server')
parser.add_argument('--task', action='append', help='Only update the given output tables (comma separated)')
parser.add_argument('--skip-init-checks', action='store_true', help=argparse.SUPPRESS)
args = parser.parse_args(args)

@@ -383,17 +408,21 @@ def main(args: List[str]):
else:
config_store = formats.NdjsonFormat(root_output)

deid_dir = load_and_deidentify(config_loader)
# Check which tasks are being run, allowing comma-separated values
tasks = args.task and list(itertools.chain.from_iterable(t.split(',') for t in args.task))

# Pull down resources and run the MS tool on them
deid_dir = load_and_deidentify(config_loader, tasks=tasks)

# Prepare config for jobs
config = JobConfig(config_loader, deid_dir.name, config_store, root_phi, comment=args.comment,
batch_size=args.batch_size, timestamp=job_datetime)
batch_size=args.batch_size, timestamp=job_datetime, tasks=tasks)
common.write_json(config.path_config(), config.as_json(), indent=4)
common.print_header('Configuration:')
print(json.dumps(config.as_json(), indent=4))

# Finally, actually run the meat of the pipeline!
summaries = etl_job(config)
# Finally, actually run the meat of the pipeline! (Filtered down to requested tasks)
summaries = etl_job(config, tasks=tasks)

# Print results to the console
common.print_header('Results:')
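A self-contained sketch of the new --task handling (only the flag is defined here, not the full CLI): because the flag uses action='append' and the values are then split on commas, repeated flags and comma-separated lists collapse into one flat task list.

    import argparse
    import itertools

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', action='append',
                        help='Only update the given output tables (comma separated)')

    args = parser.parse_args(['--task', 'condition', '--task', 'patient,encounter'])
    # args.task == ['condition', 'patient,encounter']

    tasks = args.task and list(itertools.chain.from_iterable(t.split(',') for t in args.task))
    print(tasks)  # ['condition', 'patient', 'encounter']

    # With no --task flags at all, args.task is None, so tasks stays None,
    # which get_task_descriptions() treats as "run every task".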
6 changes: 4 additions & 2 deletions cumulus/loaders/base.py
@@ -2,6 +2,7 @@

import abc
import tempfile
from typing import List

from cumulus.store import Root

@@ -22,9 +23,10 @@ def __init__(self, root: Root):
self.root = root

@abc.abstractmethod
def load_all(self) -> tempfile.TemporaryDirectory:
def load_all(self, resources: List[str]) -> tempfile.TemporaryDirectory:
"""
Loads all remote resources and places them into a local folder as FHIR ndjson
Loads the listed remote resources and places them into a local folder as FHIR ndjson
:param resources: a list of resources to ingest
:returns: an object holding the name of a local ndjson folder path (e.g. a TemporaryDirectory)
"""
15 changes: 4 additions & 11 deletions cumulus/loaders/fhir/fhir_ndjson.py
@@ -2,6 +2,7 @@

import sys
import tempfile
from typing import List

from cumulus import common, errors, store
from cumulus.loaders import base
@@ -29,10 +30,10 @@ def __init__(self, root: store.Root, client_id: str = None, jwks: str = None):
self.client_id = client_id
self.jwks = common.read_json(jwks) if jwks else None

def load_all(self) -> tempfile.TemporaryDirectory:
def load_all(self, resources: List[str]) -> tempfile.TemporaryDirectory:
# Are we doing a bulk FHIR export from a server?
if self.root.protocol in ['http', 'https']:
return self._load_from_bulk_export()
return self._load_from_bulk_export(resources)

# Are we reading from a local directory?
if self.root.protocol == 'file':
@@ -46,7 +47,7 @@ class Dir:
self.root.get(self.root.joinpath('*.ndjson'), f'{tmpdir.name}/')
return tmpdir

def _load_from_bulk_export(self) -> tempfile.TemporaryDirectory:
def _load_from_bulk_export(self, resources: List[str]) -> tempfile.TemporaryDirectory:
# First, check that the extra arguments we need were provided
error_list = []
if not self.client_id:
@@ -59,14 +60,6 @@ def _load_from_bulk_export(self) -> tempfile.TemporaryDirectory:

tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with

resources = [
'Condition',
'DocumentReference',
'Encounter',
'Observation',
'Patient',
]

try:
server = BackendServiceServer(self.root.path, self.client_id, self.jwks, resources)
bulk_exporter = BulkExporter(server, resources, tmpdir.name)
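The hardcoded resource list above could be dropped because load_and_deidentify() now derives it from the selected tasks. A rough sketch of that derivation, using the same task-name-to-resource table as get_task_descriptions() in etl.py (task functions stubbed out here, not the real module):

    import collections
    import itertools

    TaskDescription = collections.namedtuple('TaskDescription', ['func', 'resources'])

    available_tasks = {
        'patient': TaskDescription(None, ['Patient']),
        'encounter': TaskDescription(None, ['Encounter']),
        'observation': TaskDescription(None, ['Observation']),
        'documentreference': TaskDescription(None, ['DocumentReference']),
        'symptom': TaskDescription(None, ['DocumentReference']),
        'condition': TaskDescription(None, ['Condition']),
    }

    selected = [available_tasks[name] for name in ['symptom', 'condition']]
    required = set(itertools.chain.from_iterable(t.resources for t in selected))
    print(sorted(required))  # ['Condition', 'DocumentReference']

Only those resource types are then requested from the bulk export server or copied from the local ndjson directory.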
(Diffs for the remaining changed files were not loaded on this page.)
