class BentoConfig:
    """
    Data-loader / file-loader configuration read from a YAML file.

    Every configuration attribute is initialized to None up front, so all
    attributes exist even when config_file is None or when optional YAML
    sections ('sqs', 'indexd', 'neo4j') are absent.  Previously some
    attributes (e.g. slack_url, neo4j_*) were only assigned when their
    section was present, which could raise AttributeError later.
    """

    def __init__(self, config_file):
        """
        :param config_file: path to a YAML file with a top-level 'Config' key,
                            or None for an empty (all-None) configuration
        :raises Exception: when config_file is given but is not a readable file
        """
        self.log = get_logger('Bento Config')
        # Name of the environment variable that may hold the Neo4j password
        self.PSWD_ENV = 'NEO_PASSWORD'

        # File-Loader related defaults
        self.temp_folder = None
        self.queue_long_pull_time = None
        self.visibility_timeout = None
        self.indexd_guid_prefix = None
        self.indexd_manifest_ext = None
        self.rel_prop_delimiter = None
        self.slack_url = None

        # Data-Loader related defaults
        self.backup_folder = None
        self.neo4j_uri = None
        self.neo4j_user = None
        self.neo4j_password = None
        self.schema_files = None
        self.prop_file = None
        self.cheat_mode = None
        self.dry_run = None
        self.wipe_db = None
        self.no_backup = None
        self.yes = None
        self.max_violations = None
        self.s3_bucket = None
        self.s3_folder = None
        self.loading_mode = None
        self.dataset = None
        self.no_parents = None
        self.split_transactions = None

        if config_file is None:
            # Caller will populate attributes (e.g. from CLI arguments)
            return

        if not os.path.isfile(config_file):
            msg = f'Can NOT open configuration file "{config_file}"!'
            self.log.error(msg)
            raise Exception(msg)

        with open(config_file) as c_file:
            config = yaml.safe_load(c_file)['Config']

        #################################
        # Folders
        self.temp_folder = config.get('temp_folder')
        if self.temp_folder:
            self._create_folder(self.temp_folder)

        self.backup_folder = config.get('backup_folder')
        if self.backup_folder:
            self._create_folder(self.backup_folder)

        #################################
        # File-loader related
        if 'sqs' in config:
            sqs = config['sqs']
            self.queue_long_pull_time = sqs.get('long_pull_time')
            self.visibility_timeout = sqs.get('visibility_timeout')

        if 'indexd' in config:
            indexd = config['indexd']
            self.indexd_guid_prefix = indexd.get('GUID_prefix')
            self.indexd_manifest_ext = indexd.get('ext')
            # Normalize extension to always carry a leading dot
            if self.indexd_manifest_ext and not self.indexd_manifest_ext.startswith('.'):
                self.indexd_manifest_ext = '.' + self.indexd_manifest_ext
        self.slack_url = config.get('url')

        #################################
        # Data-loader related
        self.rel_prop_delimiter = config.get('rel_prop_delimiter')
        if 'neo4j' in config:
            neo4j = config['neo4j']
            self.neo4j_uri = neo4j.get('uri')
            self.neo4j_user = neo4j.get('user')
            self.neo4j_password = neo4j.get('password')

        self.schema_files = config.get('schema')
        self.prop_file = config.get('prop_file')
        self.cheat_mode = config.get('cheat_mode')
        self.dry_run = config.get('dry_run')
        self.wipe_db = config.get('wipe_db')
        self.no_backup = config.get('no_backup')
        self.yes = config.get('no_confirmation')
        self.max_violations = config.get('max_violations', 10)
        self.s3_bucket = config.get('s3_bucket')
        self.s3_folder = config.get('s3_folder')
        self.loading_mode = config.get('loading_mode', 'UPSERT_MODE')
        self.dataset = config.get('dataset')
        self.no_parents = config.get('no_parents')
        self.split_transactions = config.get('split_transactions')

    def _create_folder(self, folder):
        """Create folder (and parents) if needed; raise if path isn't a directory."""
        os.makedirs(folder, exist_ok=True)
        if not os.path.isdir(folder):
            msg = f'{folder} is not a folder!'
            self.log.error(msg)
            raise Exception(msg)
class BentoConfig:
    """
    Configuration backed by a YAML file, overlaid with command line arguments.

    YAML values are read from the file's top-level 'Config' key; any set
    (non-None, or True boolean) attribute of args overrides the YAML entry
    of the same name.
    """

    def __init__(self, config_file, args, config_file_arg='config_file'):
        """
        :param config_file: path to the YAML configuration file
        :param args: argparse.Namespace whose set values override the YAML
        :param config_file_arg: name of the args attribute that holds the
                                config file path (skipped during override)
        :raises ValueError: when config_file is empty or not a file
        """
        self.log = get_logger('Bento Config')
        if not config_file:
            # Fixed: was an f-string with no placeholder
            raise ValueError('Empty config file name')
        if not os.path.isfile(config_file):
            raise ValueError(f'"{config_file}" is not a file!')

        self.config_file_arg = config_file_arg

        with open(config_file) as c_file:
            self.data = yaml.safe_load(c_file)['Config']
        # An empty YAML 'Config' section loads as None; normalize to a dict
        if self.data is None:
            self.data = {}

        self._override(args)

    def _override(self, args):
        """Overlay command line arguments onto the YAML-loaded data."""
        for key, value in vars(args).items():
            # Ignore config file argument
            if key == self.config_file_arg:
                continue
            if isinstance(value, bool):
                # store_true flags default to False; only an explicit True
                # should override the config file value
                if value:
                    self.data[key] = value
            elif value is not None:
                self.data[key] = value

    def create_folder(self, folder):
        """
        Create given folder if not already exists

        :param folder: folder path
        :raises Exception: when the path exists but is not a directory
        """
        os.makedirs(folder, exist_ok=True)
        if not os.path.isdir(folder):
            msg = f'{folder} is not a folder!'
            self.log.error(msg)
            raise Exception(msg)
# Keys / constants used throughout the loader
NODE_TYPE = 'type'
PROP_TYPE = 'Type'
PARENT_TYPE = 'parent_type'
PARENT_ID_FIELD = 'parent_id_field'
PARENT_ID = 'parent_id'
excluded_fields = {NODE_TYPE}
CASE_NODE = 'case'
CASE_ID = 'case_id'
CREATED = 'created'
UPDATED = 'updated'
RELATIONSHIPS = 'relationships'
INT_NODE_CREATED = 'int_node_created'
PROVIDED_PARENTS = 'provided_parents'
RELATIONSHIP_PROPS = 'relationship_properties'
# Rows per transaction when split-transactions mode is enabled
BATCH_SIZE = 1000


def get_indexes(session):
    """
    Queries the database to get all existing indexes

    :param session: the current neo4j transaction session
    :return: A set of tuples representing all existing indexes in the database
    """
    command = "call db.indexes()"
    result = session.run(command)
    indexes = set()
    for r in result:
        indexes.add(format_as_tuple(r["tokenNames"][0], r["properties"]))
    return indexes


def format_as_tuple(node_name, properties):
    """
    Format index info as a tuple

    :param node_name: The name of the node type for the index
    :param properties: The list of node properties being used by the index
                       (a bare string is treated as a single property)
    :return: A tuple containing the index node_name followed by the index
             properties in alphabetical order
    """
    if isinstance(properties, str):
        properties = [properties]
    lst = [node_name] + sorted(properties)
    return tuple(lst)


def backup_neo4j(backup_dir, name, address, log):
    """
    Back up the Neo4j database via neo4j-admin, locally or over ssh.

    :param backup_dir: directory to store the backup in
    :param name: backup name (typically a timestamp)
    :param address: DB host; 'localhost'/'127.0.0.1' means run locally
    :param log: logger for progress/error messages
    :return: human-readable restore instructions on success, False on failure
    """
    try:
        # Fixed: message previously had an unbalanced '(' before "to remove"
        restore_cmd = 'To restore DB from backup (to remove any changes caused by current data loading), run following commands:\n'
        restore_cmd += '#' * 160 + '\n'
        neo4j_cmd = 'neo4j-admin restore --from={}/{} --force'.format(backup_dir, name)
        mkdir_cmd = [
            'mkdir',
            '-p',
            backup_dir
        ]
        is_shell = False
        # settings for Windows platforms
        if platform.system() == "Windows":
            mkdir_cmd[2] = os.path.abspath(backup_dir)
            is_shell = True
        cmds = [
            mkdir_cmd,
            [
                'neo4j-admin',
                'backup',
                '--backup-dir={}'.format(backup_dir),
                '--name={}'.format(name),
            ]
        ]
        if address in ['localhost', '127.0.0.1']:
            # On Windows, the Neo4j service cannot be accessed through the command line without an absolute path
            # or a custom installation location
            if platform.system() == "Windows":
                restore_cmd += '\tManually stop the Neo4j service\n\t$ {}\n\tManually start the Neo4j service\n'.format(
                    neo4j_cmd)
            else:
                restore_cmd += '\t$ neo4j stop && {} && neo4j start\n'.format(neo4j_cmd)
            for cmd in cmds:
                log.info(cmd)
                subprocess.call(cmd, shell=is_shell)
        else:
            # Remote host: run the backup over ssh, restore happens as the neo4j user
            second_cmd = 'sudo systemctl stop neo4j && {} && sudo systemctl start neo4j && exit'.format(neo4j_cmd)
            restore_cmd += '\t$ echo "{}" | ssh -t {} sudo su - neo4j\n'.format(second_cmd, address)
            for cmd in cmds:
                remote_cmd = ['ssh', address, '-o', 'StrictHostKeyChecking=no'] + cmd
                log.info(' '.join(remote_cmd))
                subprocess.call(remote_cmd)
        restore_cmd += '#' * 160
        return restore_cmd
    except Exception as e:
        # Best-effort: a failed backup is reported as False so the caller can abort
        log.exception(e)
        return False


class DataLoader:
    def __init__(self, driver, schema, intermediate_node_creator=None):
        """
        :param driver: a neo4j.Driver instance
        :param schema: an ICDC_Schema instance (required)
        :param intermediate_node_creator: optional object that can create
               intermediate nodes; must expose create_intermediate_node,
               nodes_stat and relationships_stat
        """
        if not schema or not isinstance(schema, ICDC_Schema):
            raise Exception('Invalid ICDC_Schema object')
        self.log = get_logger('Data Loader')
        self.driver = driver
        self.schema = schema
        # Delimiter separating relationship name from property name in column headers
        self.rel_prop_delimiter = self.schema.rel_prop_delimiter

        if intermediate_node_creator:
            # Duck-type check: the creator must expose these three members
            if not hasattr(intermediate_node_creator, 'create_intermediate_node'):
                raise ValueError('Invalide Intermediate node creator')
            if not hasattr(intermediate_node_creator, 'nodes_stat'):
                raise ValueError('Invalide Intermediate node creator')
            if not hasattr(intermediate_node_creator, 'relationships_stat'):
                raise ValueError('Invalide Intermediate node creator')
        self.int_node_creator = intermediate_node_creator

    def check_files(self, file_list):
        """Return True when file_list is non-empty and every entry is an existing file."""
        if not file_list:
            self.log.error('Invalid file list')
            return False
        elif file_list:
            for data_file in file_list:
                if not os.path.isfile(data_file):
                    self.log.error('File "{}" doesn\'t exist'.format(data_file))
                    return False
        return True

    def validate_files(self, cheat_mode, file_list, max_violations):
        """
        Validate every file in file_list unless cheat_mode is on.

        :return: True when all files validate (or cheat_mode is enabled)
        """
        if not cheat_mode:
            validation_failed = False
            for txt in file_list:
                if not self.validate_file(txt, max_violations):
                    self.log.error('Validating file "{}" failed!'.format(txt))
                    validation_failed = True
            return not validation_failed
        else:
            self.log.info('Cheat mode enabled, all validations skipped!')
            return True

    def load(self, file_list, cheat_mode, dry_run, loading_mode, wipe_db, max_violations, no_parents,
             split=False, no_backup=True, backup_folder="/", neo4j_uri=None):
        """
        Main entry point: validate, optionally back up, then load all files.

        :param file_list: list of TSV/TXT data files
        :param cheat_mode: skip validation when True
        :param dry_run: validate only, load nothing
        :param loading_mode: UPSERT_MODE, NEW_MODE or DELETE_MODE
        :param wipe_db: wipe the whole DB before loading (note: shadows the
                        wipe_db() method name on this class)
        :param max_violations: stop validating a file after this many errors
        :param no_parents: skip parent-pointer handling
        :param split: commit in BATCH_SIZE-row transactions instead of one
        :param no_backup: skip the neo4j-admin backup step
        :param backup_folder: where to store the backup
        :param neo4j_uri: DB URI, required when backing up
        :return: dict of created/deleted counters on success, False on failure
        """
        if not self.check_files(file_list):
            return False
        start = timer()
        if not self.validate_files(cheat_mode, file_list, max_violations):
            return False
        if not no_backup and not dry_run:
            if not neo4j_uri:
                self.log.error('No Neo4j URI specified for backup, abort loading!')
                sys.exit(1)
            backup_name = datetime.datetime.today().strftime(DATETIME_FORMAT)
            host = get_host(neo4j_uri)
            restore_cmd = backup_neo4j(backup_folder, backup_name, host, self.log)
            if not restore_cmd:
                self.log.error('Backup Neo4j failed, abort loading!')
                sys.exit(1)
        if dry_run:
            end = timer()
            self.log.info('Dry run mode, no nodes or relationships loaded.')
            self.log.info('Running time: {:.2f} seconds'.format(end - start))  # Time in seconds, e.g. 5.38091952400282
            return {NODES_CREATED: 0, RELATIONSHIP_CREATED: 0}

        # Reset per-run statistics
        self.nodes_created = 0
        self.relationships_created = 0
        self.indexes_created = 0
        self.nodes_deleted = 0
        self.relationships_deleted = 0
        self.nodes_stat = {}
        self.relationships_stat = {}
        self.nodes_deleted_stat = {}
        self.relationships_deleted_stat = {}
        if not self.driver or not isinstance(self.driver, Driver):
            self.log.error('Invalid Neo4j Python Driver!')
            return False
        # Data updates and schema related updates cannot be performed in the same session so multiple will be created
        # Create new session for schema related updates (index creation)
        with self.driver.session() as session:
            tx = session.begin_transaction()
            try:
                self.create_indexes(tx)
                tx.commit()
            except Exception as e:
                tx.rollback()
                self.log.exception(e)
                return False
        # Create new session for data related updates
        with self.driver.session() as session:
            # Split Transactions enabled: _load_all manages its own transactions
            if split:
                self._load_all(session, file_list, loading_mode, no_parents, split, wipe_db)

            # Split Transactions Disabled: everything in one transaction
            else:
                # Data updates transaction
                tx = session.begin_transaction()
                try:
                    self._load_all(tx, file_list, loading_mode, no_parents, split, wipe_db)
                    tx.commit()
                except Exception as e:
                    tx.rollback()
                    self.log.exception(e)
                    return False

        # End the timer
        end = timer()

        # Print statistics (fold in intermediate-node counters when present)
        if self.int_node_creator:
            combined_dict_counters(self.nodes_stat, self.int_node_creator.nodes_stat)
            combined_dict_counters(self.relationships_stat, self.int_node_creator.relationships_stat)
            self.nodes_created += self.int_node_creator.nodes_created
            self.relationships_created += self.int_node_creator.relationships_created
        for node in sorted(self.nodes_stat.keys()):
            count = self.nodes_stat[node]
            self.log.info('Node: (:{}) loaded: {}'.format(node, count))
        for rel in sorted(self.relationships_stat.keys()):
            count = self.relationships_stat[rel]
            self.log.info('Relationship: [:{}] loaded: {}'.format(rel, count))
        self.log.info('{} new indexes created!'.format(self.indexes_created))
        self.log.info('{} nodes and {} relationships loaded!'.format(self.nodes_created, self.relationships_created))
        self.log.info('{} nodes and {} relationships deleted!'.format(self.nodes_deleted, self.relationships_deleted))
        self.log.info('Loading time: {:.2f} seconds'.format(end - start))  # Time in seconds, e.g. 5.38091952400282
        return {NODES_CREATED: self.nodes_created, RELATIONSHIP_CREATED: self.relationships_created,
                NODES_DELETED: self.nodes_deleted, RELATIONSHIP_DELETED: self.relationships_deleted}

    def _load_all(self, tx, file_list, loading_mode, no_parents, split, wipe_db):
        """Wipe (optionally), load all nodes, then all relationships, from every file."""
        if wipe_db:
            self.wipe_db(tx, split)
        for txt in file_list:
            self.load_nodes(tx, txt, loading_mode, no_parents, split)
        if loading_mode != DELETE_MODE:
            for txt in file_list:
                self.load_relationships(tx, txt, loading_mode, split)

    # Remove extra spaces at begining and end of the keys and values
    @staticmethod
    def cleanup_node(node):
        obj = {}
        for key, value in node.items():
            obj[key.strip()] = value.strip()
        return obj

    # Cleanup values for Boolean, Int and Float types
    # Add uuid to nodes if one not exists
    # Add parent id(s)
    # Add extra properties for "value with unit" properties
    def prepare_node(self, node, no_parents):
        obj = self.cleanup_node(node)

        node_type = obj.get(NODE_TYPE, None)
        # Cleanup values for Boolean, Int and Float types
        if node_type:
            for key, value in obj.items():
                search_node_type = node_type
                search_key = key
                # Parent pointers / relationship properties are typed on the other node
                if self.schema.is_parent_pointer(key):
                    search_node_type, search_key = key.split('.')
                elif self.schema.is_relationship_property(key):
                    search_node_type, search_key = key.split(self.rel_prop_delimiter)

                key_type = self.schema.get_prop_type(search_node_type, search_key)
                if key_type == 'Boolean':
                    cleaned_value = None
                    # NOTE(review): substring match — a value like "not true"
                    # would also map to True; confirm intended semantics
                    if isinstance(value, str):
                        if re.search(r'yes|true', value, re.IGNORECASE):
                            cleaned_value = True
                        elif re.search(r'no|false', value, re.IGNORECASE):
                            cleaned_value = False
                        else:
                            self.log.debug('Unsupported Boolean value: "{}"'.format(value))
                            cleaned_value = None
                    obj[key] = cleaned_value
                elif key_type == 'Int':
                    try:
                        if value is None:
                            cleaned_value = None
                        else:
                            cleaned_value = int(value)
                    except Exception:
                        cleaned_value = None
                    obj[key] = cleaned_value
                elif key_type == 'Float':
                    try:
                        if value is None:
                            cleaned_value = None
                        else:
                            cleaned_value = float(value)
                    except Exception:
                        cleaned_value = None
                    obj[key] = cleaned_value
                elif key_type == 'Array':
                    items = self.schema.get_list_values(value)
                    # todo: need to transform items if item type is not string
                    obj[key] = json.dumps(items)

        # Add uuid to nodes if one not exists
        if UUID not in obj:
            id_field = self.schema.get_id_field(obj)
            id_value = self.schema.get_id(obj)
            node_type = obj.get(NODE_TYPE)
            if node_type:
                if not id_value:
                    # No natural id: derive a uuid from the node's full signature
                    obj[UUID] = self.schema.get_uuid_for_node(node_type, self.get_signature(obj))
                elif id_field != UUID:
                    obj[UUID] = self.schema.get_uuid_for_node(node_type, id_value)
            else:
                raise Exception('No "type" property in node')

        obj2 = {}
        for key, value in obj.items():
            obj2[key] = value
            # Add parent id field(s) into node
            if self.schema.is_parent_pointer(key) and not no_parents:
                header = key.split('.')
                if len(header) > 2:
                    self.log.warning('Column header "{}" has multiple periods!'.format(key))
                field_name = header[1]
                parent = header[0]
                combined = '{}_{}'.format(parent, field_name)
                if field_name in obj:
                    self.log.debug(
                        '"{}" field is in both current node and parent "{}", use {} instead !'.format(key, parent,
                                                                                                      combined))
                    field_name = combined
                # Add an value for parent id
                obj2[field_name] = value
            # Add extra properties if any
            for extra_prop_name, extra_value in self.schema.get_extra_props(node_type, key, value).items():
                obj2[extra_prop_name] = extra_value

        return obj2

    @staticmethod
    def get_signature(node):
        """Render a node's key/value pairs as '{ k: v, ... }' (used for uuid derivation)."""
        result = []
        for key, value in node.items():
            result.append('{}: {}'.format(key, value))
        return '{{ {} }}'.format(', '.join(result))

    # Validate all cases exist in a data (TSV/TXT) file
    def validate_cases_exist_in_file(self, file_name, max_violations, no_parents):
        if not self.driver or not isinstance(self.driver, Driver):
            self.log.error('Invalid Neo4j Python Driver!')
            return False
        with self.driver.session() as session:
            file_encoding = self.check_encoding(file_name)
            with open(file_name, encoding=file_encoding) as in_file:
                self.log.info('Validating relationships in file "{}" ...'.format(file_name))
                reader = csv.DictReader(in_file, delimiter='\t')
                line_num = 1
                validation_failed = False
                violations = 0
                for org_obj in reader:
                    obj = self.prepare_node(org_obj, no_parents)
                    line_num += 1
                    # Validate parent exist
                    if CASE_ID in obj:
                        case_id = obj[CASE_ID]
                        if not self.node_exists(session, CASE_NODE, CASE_ID, case_id):
                            self.log.error(
                                'Invalid data at line {}: Parent (:{} {{ {}: "{}" }}) doesn\'t exist!'.format(
                                    line_num, CASE_NODE, CASE_ID, case_id))
                            validation_failed = True
                            violations += 1
                            if violations >= max_violations:
                                return False
                return not validation_failed

    # Validate all parents exist in a data (TSV/TXT) file
    def validate_parents_exist_in_file(self, file_name, max_violations, no_parents):
        # NOTE(review): dead assignment — re-initialized to False below
        validation_failed = True
        if not self.driver or not isinstance(self.driver, Driver):
            self.log.error('Invalid Neo4j Python Driver!')
            return False
        with self.driver.session() as session:
            file_encoding = self.check_encoding(file_name)
            with open(file_name, encoding=file_encoding) as in_file:
                self.log.info('Validating relationships in file "{}" ...'.format(file_name))
                reader = csv.DictReader(in_file, delimiter='\t')
                line_num = 1
                validation_failed = False
                violations = 0
                for org_obj in reader:
                    line_num += 1
                    obj = self.prepare_node(org_obj, no_parents)
                    results = self.collect_relationships(obj, session, False, line_num)
                    relationships = results[RELATIONSHIPS]
                    provided_parents = results[PROVIDED_PARENTS]
                    if provided_parents > 0:
                        if len(relationships) == 0:
                            self.log.error('Invalid data at line {}: No parents found!'.format(line_num))
                            validation_failed = True
                            violations += 1
                            if violations >= max_violations:
                                return False
                    else:
                        self.log.info('Line: {} - No parents found'.format(line_num))

        return not validation_failed

    def get_node_properties(self, obj):
        '''
        Generate a node with only node properties from input data

        :param obj: input data object (dict), may contain parent pointers, relationship properties etc.
        :return: an object (dict) that only contains properties on this node
        '''
        node = {}

        for key, value in obj.items():
            if self.schema.is_parent_pointer(key):
                continue
            elif self.schema.is_relationship_property(key):
                continue
            else:
                node[key] = value

        return node

    # Check encoding: utf-8 when the whole file decodes, else windows-1252
    def check_encoding(self, file_name):
        utf8 = 'utf-8'
        windows1252 = 'windows-1252'
        try:
            with open(file_name, encoding=utf8) as file:
                for line in file.readlines():
                    pass
            return utf8
        except UnicodeDecodeError:
            return windows1252

    # Validate file: duplicate-id detection plus per-node schema validation
    def validate_file(self, file_name, max_violations):
        file_encoding = self.check_encoding(file_name)
        with open(file_name, encoding=file_encoding) as in_file:
            self.log.info('Validating file "{}" ...'.format(file_name))
            reader = csv.DictReader(in_file, delimiter='\t')
            line_num = 1
            validation_failed = False
            violations = 0
            IDs = {}
            for org_obj in reader:
                obj = self.cleanup_node(org_obj)
                props = self.get_node_properties(obj)
                line_num += 1
                id_field = self.schema.get_id_field(obj)
                node_id = self.schema.get_id(obj)
                if node_id:
                    if node_id in IDs:
                        if props != IDs[node_id]['props']:
                            validation_failed = True
                            self.log.error(
                                f'Invalid data at line {line_num}: duplicate {id_field}: {node_id}, found in line: {", ".join(IDs[node_id]["lines"])}')
                            IDs[node_id]['lines'].append(str(line_num))
                        else:
                            # Same ID exists in same file, but properties are also same, probably it's pointing same object to multiple parents
                            self.log.debug(
                                f'Duplicated data at line {line_num}: duplicate {id_field}: {node_id}, found in line: {", ".join(IDs[node_id]["lines"])}')
                    else:
                        IDs[node_id] = {'props': props, 'lines': [str(line_num)]}

                validate_result = self.schema.validate_node(obj[NODE_TYPE], obj)
                if not validate_result['result']:
                    for msg in validate_result['messages']:
                        self.log.error('Invalid data at line {}: "{}"!'.format(line_num, msg))
                    validation_failed = True
                    violations += 1
                    if violations >= max_violations:
                        return False
            return not validation_failed

    def get_new_statement(self, node_type, obj):
        """Build a CREATE statement (driver-1.x {param} placeholders) for obj."""
        # statement is used to create current node
        prop_stmts = []

        for key in obj.keys():
            if key in excluded_fields:
                continue
            elif self.schema.is_parent_pointer(key):
                continue
            elif self.schema.is_relationship_property(key):
                continue

            prop_stmts.append('{0}: {{{0}}}'.format(key))

        statement = 'CREATE (:{0} {{ {1} }})'.format(node_type, ' ,'.join(prop_stmts))
        return statement

    def get_upsert_statement(self, node_type, id_field, obj):
        """Build a MERGE statement that stamps created/updated timestamps."""
        # statement is used to create current node
        statement = ''
        prop_stmts = []

        for key in obj.keys():
            if key in excluded_fields:
                continue
            elif key == id_field:
                continue
            elif self.schema.is_parent_pointer(key):
                continue
            elif self.schema.is_relationship_property(key):
                continue

            prop_stmts.append('n.{0} = {{{0}}}'.format(key))

        statement += 'MERGE (n:{0} {{ {1}: {{{1}}} }})'.format(node_type, id_field)
        statement += ' ON CREATE ' + 'SET n.{} = datetime(), '.format(CREATED) + ' ,'.join(prop_stmts)
        statement += ' ON MATCH ' + 'SET n.{} = datetime(), '.format(UPDATED) + ' ,'.join(prop_stmts)
        return statement
datetime(), '.format(UPDATED) + ' ,'.join(prop_stmts) + return statement + + # Delete a node and children with no other parents recursively + def delete_node(self, session, node): + delete_queue = deque([node]) + node_deleted = 0 + relationship_deleted = 0 + while len(delete_queue) > 0: + root = delete_queue.popleft() + delete_queue.extend(self.get_children_with_single_parent(session, root)) + n_deleted, r_deleted = self.delete_single_node(session, root) + node_deleted += n_deleted + relationship_deleted += r_deleted + return (node_deleted, relationship_deleted) + + # Return children of node without other parents + def get_children_with_single_parent(self, session, node): + node_type = node[NODE_TYPE] + statement = 'MATCH (n:{0} {{ {1}: {{{1}}} }})<--(m)'.format(node_type, self.schema.get_id_field(node)) + statement += ' WHERE NOT (n)<--(m)-->() RETURN m' + result = session.run(statement, node) + children = [] + for obj in result: + children.append(self.get_node_from_result(obj, 'm')) + return children + + @staticmethod + def get_node_from_result(record, name): + node = record.data()[name] + result = dict(node.items()) + for label in node.labels: + result[NODE_TYPE] = label + break + return result + + # Simple delete given node, and it's relationships + def delete_single_node(self, session, node): + node_type = node[NODE_TYPE] + statement = 'MATCH (n:{0} {{ {1}: {{{1}}} }}) detach delete n'.format(node_type, self.schema.get_id_field(node)) + result = session.run(statement, node) + nodes_deleted = result.summary().counters.nodes_deleted + self.nodes_deleted += nodes_deleted + self.nodes_deleted_stat[node_type] = self.nodes_deleted_stat.get(node_type, 0) + nodes_deleted + relationship_deleted = result.summary().counters.relationships_deleted + self.relationships_deleted += relationship_deleted + return (nodes_deleted, relationship_deleted) + + # load file + def load_nodes(self, session, file_name, loading_mode, no_parents, split=False): + if loading_mode == NEW_MODE: 
+ action_word = 'Loading new' + elif loading_mode == UPSERT_MODE: + action_word = 'Loading' + elif loading_mode == DELETE_MODE: + action_word = 'Deleting' + else: + raise Exception('Wrong loading_mode: {}'.format(loading_mode)) + self.log.info('{} nodes from file: {}'.format(action_word, file_name)) + + file_encoding = self.check_encoding(file_name) + with open(file_name, encoding=file_encoding) as in_file: + reader = csv.DictReader(in_file, delimiter='\t') + nodes_created = 0 + nodes_deleted = 0 + node_type = 'UNKNOWN' + relationship_deleted = 0 + line_num = 1 + transaction_counter = 0 + + # Use session in one transaction mode + tx = session + # Use transactions in split-transactions mode + if split: + tx = session.begin_transaction() + + for org_obj in reader: + line_num += 1 + transaction_counter += 1 + obj = self.prepare_node(org_obj, no_parents) + node_type = obj[NODE_TYPE] + node_id = self.schema.get_id(obj) + if not node_id: + raise Exception('Line:{}: No ids found!'.format(line_num)) + id_field = self.schema.get_id_field(obj) + if loading_mode == UPSERT_MODE: + statement = self.get_upsert_statement(node_type, id_field, obj) + elif loading_mode == NEW_MODE: + if self.node_exists(tx, node_type, id_field, node_id): + raise Exception( + 'Line: {}: Node (:{} {{ {}: {} }}) exists! 
Abort loading!'.format(line_num, node_type, + id_field, node_id)) + else: + statement = self.get_new_statement(node_type, obj) + elif loading_mode == DELETE_MODE: + n_deleted, r_deleted = self.delete_node(tx, obj) + nodes_deleted += n_deleted + relationship_deleted += r_deleted + else: + raise Exception('Wrong loading_mode: {}'.format(loading_mode)) + + if loading_mode != DELETE_MODE: + result = tx.run(statement, obj) + count = result.summary().counters.nodes_created + self.nodes_created += count + nodes_created += count + self.nodes_stat[node_type] = self.nodes_stat.get(node_type, 0) + count + # commit and restart a transaction when batch size reached + if split and transaction_counter >= BATCH_SIZE: + tx.commit() + tx = session.begin_transaction() + self.log.info(f'{line_num -1} rows loaded ...') + transaction_counter = 0 + # commit last transaction + if split: + tx.commit() + + + if loading_mode == DELETE_MODE: + self.log.info('{} node(s) deleted'.format(nodes_deleted)) + self.log.info('{} relationship(s) deleted'.format(relationship_deleted)) + else: + self.log.info('{} (:{}) node(s) loaded'.format(nodes_created, node_type)) + + def node_exists(self, session, label, prop, value): + statement = 'MATCH (m:{0} {{ {1}: {{{1}}} }}) return m'.format(label, prop) + result = session.run(statement, {prop: value}) + count = result.detach() + if count > 1: + self.log.warning('More than one nodes found! 
') + return count >= 1 + + def collect_relationships(self, obj, session, create_intermediate_node, line_num): + node_type = obj[NODE_TYPE] + relationships = [] + int_node_created = 0 + provided_parents = 0 + relationship_properties = {} + for key, value in obj.items(): + if self.schema.is_parent_pointer(key): + provided_parents += 1 + other_node, other_id = key.split('.') + relationship = self.schema.get_relationship(node_type, other_node) + if not isinstance(relationship, dict): + self.log.error('Line: {}: Relationship not found!'.format(line_num)) + raise Exception('Undefined relationship, abort loading!') + relationship_name = relationship[RELATIONSHIP_TYPE] + multiplier = relationship[MULTIPLIER] + if not relationship_name: + self.log.error('Line: {}: Relationship not found!'.format(line_num)) + raise Exception('Undefined relationship, abort loading!') + if not self.node_exists(session, other_node, other_id, value): + if create_intermediate_node and self.int_node_creator and self.int_node_creator.is_valid_int_node( + other_node): + if self.int_node_creator.create_intermediate_node(session, line_num, other_node, value, obj): + int_node_created += 1 + relationships.append({PARENT_TYPE: other_node, PARENT_ID_FIELD: other_id, PARENT_ID: value, + RELATIONSHIP_TYPE: relationship_name, MULTIPLIER: multiplier}) + else: + self.log.error( + 'Line: {}: Couldn\'t create {} node automatically!'.format(line_num, other_node)) + else: + self.log.warning( + 'Line: {}: Parent node (:{} {{{}: "{}"}} not found in DB!'.format(line_num, other_node, + other_id, + value)) + else: + if multiplier == ONE_TO_ONE and self.parent_already_has_child(session, node_type, obj, + relationship_name, other_node, + other_id, value): + self.log.error( + 'Line: {}: one_to_one relationship failed, parent already has a child!'.format(line_num)) + else: + relationships.append({PARENT_TYPE: other_node, PARENT_ID_FIELD: other_id, PARENT_ID: value, + RELATIONSHIP_TYPE: relationship_name, MULTIPLIER: 
multiplier}) + elif self.schema.is_relationship_property(key): + rel_name, prop_name = key.split(self.rel_prop_delimiter) + if rel_name not in relationship_properties: + relationship_properties[rel_name] = {} + relationship_properties[rel_name][prop_name] = value + return {RELATIONSHIPS: relationships, INT_NODE_CREATED: int_node_created, PROVIDED_PARENTS: provided_parents, + RELATIONSHIP_PROPS: relationship_properties} + + def parent_already_has_child(self, session, node_type, node, relationship_name, parent_type, parent_id_field, + parent_id): + statement = 'MATCH (n:{})-[r:{}]->(m:{} {{ {}: {{parent_id}} }}) return n'.format(node_type, relationship_name, + parent_type, parent_id_field) + result = session.run(statement, {"parent_id": parent_id}) + if result: + child = result.single() + if child: + find_current_node_statement = 'MATCH (n:{0} {{ {1}: {{{1}}} }}) return n'.format(node_type, + self.schema.get_id_field( + node)) + current_node_result = session.run(find_current_node_statement, node) + if current_node_result: + current_node = current_node_result.single() + return child[0].id != current_node[0].id + else: + self.log.error('Could NOT find current node!') + + return False + + # Check if a relationship of same type exists, if so, return a statement which can delete it, otherwise return False + def has_existing_relationship(self, session, node_type, node, relationship, count_same_parent=False): + relationship_name = relationship[RELATIONSHIP_TYPE] + parent_type = relationship[PARENT_TYPE] + parent_id_field = relationship[PARENT_ID_FIELD] + + base_statement = 'MATCH (n:{0} {{ {1}: {{{1}}} }})-[r:{2}]->(m:{3})'.format(node_type, + self.schema.get_id_field(node), + relationship_name, parent_type) + statement = base_statement + ' return m.{} AS {}'.format(parent_id_field, PARENT_ID) + result = session.run(statement, node) + if result: + old_parent = result.single() + if old_parent: + if count_same_parent: + del_statement = base_statement + ' delete r' + return 
del_statement + else: + old_parent_id = old_parent[PARENT_ID] + if old_parent_id != relationship[PARENT_ID]: + self.log.warning('Old parent is different from new parent, delete relationship to old parent:' + + ' (:{} {{ {}: "{}" }})!'.format(parent_type, parent_id_field, old_parent_id)) + del_statement = base_statement + ' delete r' + return del_statement + else: + self.log.error('Remove old relationship failed: Query old relationship failed!') + + return False + + def remove_old_relationship(self, session, node_type, node, relationship): + del_statement = self.has_existing_relationship(session, node_type, node, relationship) + if del_statement: + del_result = session.run(del_statement, node) + if not del_result: + self.log.error('Delete old relationship failed!') + + def load_relationships(self, session, file_name, loading_mode, split=False): + if loading_mode == NEW_MODE: + action_word = 'Loading new' + elif loading_mode == UPSERT_MODE: + action_word = 'Loading' + else: + raise Exception('Wrong loading_mode: {}'.format(loading_mode)) + self.log.info('{} relationships from file: {}'.format(action_word, file_name)) + + file_encoding = self.check_encoding(file_name) + with open(file_name, encoding=file_encoding) as in_file: + reader = csv.DictReader(in_file, delimiter='\t') + relationships_created = {} + int_nodes_created = 0 + line_num = 1 + transaction_counter = 0 + + # Use session in one transaction mode + tx = session + # Use transactions in split-transactions mode + if split: + tx = session.begin_transaction() + + for org_obj in reader: + line_num += 1 + transaction_counter += 1 + obj = self.prepare_node(org_obj, False) + node_type = obj[NODE_TYPE] + results = self.collect_relationships(obj, tx, True, line_num) + relationships = results[RELATIONSHIPS] + int_nodes_created += results[INT_NODE_CREATED] + provided_parents = results[PROVIDED_PARENTS] + relationship_props = results[RELATIONSHIP_PROPS] + if provided_parents > 0: + if len(relationships) == 0: + raise 
Exception('Line: {}: No parents found, abort loading!'.format(line_num)) + + for relationship in relationships: + relationship_name = relationship[RELATIONSHIP_TYPE] + multiplier = relationship[MULTIPLIER] + parent_node = relationship[PARENT_TYPE] + parent_id_field = relationship[PARENT_ID_FIELD] + properties = relationship_props.get(relationship_name, {}) + if multiplier in [DEFAULT_MULTIPLIER, ONE_TO_ONE]: + if loading_mode == UPSERT_MODE: + self.remove_old_relationship(tx, node_type, obj, relationship) + elif loading_mode == NEW_MODE: + if self.has_existing_relationship(tx, node_type, obj, relationship, True): + raise Exception( + 'Line: {}: Relationship already exists, abort loading!'.format(line_num)) + else: + raise Exception('Wrong loading_mode: {}'.format(loading_mode)) + else: + self.log.debug('Multiplier: {}, no action needed!'.format(multiplier)) + prop_statement = ', '.join(self.get_relationship_prop_statements(properties)) + statement = 'MATCH (m:{0} {{ {1}: {{{1}}} }})'.format(parent_node, parent_id_field) + statement += ' MATCH (n:{0} {{ {1}: {{{1}}} }})'.format(node_type, + self.schema.get_id_field(obj)) + statement += ' MERGE (n)-[r:{}]->(m)'.format(relationship_name) + statement += ' ON CREATE SET r.{} = datetime()'.format(CREATED) + statement += ', {}'.format(prop_statement) if prop_statement else '' + statement += ' ON MATCH SET r.{} = datetime()'.format(UPDATED) + statement += ', {}'.format(prop_statement) if prop_statement else '' + + result = tx.run(statement, {**obj, **properties}) + count = result.summary().counters.relationships_created + self.relationships_created += count + relationship_pattern = '(:{})->[:{}]->(:{})'.format(node_type, relationship_name, parent_node) + relationships_created[relationship_pattern] = relationships_created.get(relationship_pattern, + 0) + count + self.relationships_stat[relationship_name] = self.relationships_stat.get(relationship_name, + 0) + count + # commit and restart a transaction when batch size 
reached + if split and transaction_counter >= BATCH_SIZE: + tx.commit() + tx = session.begin_transaction() + self.log.info(f'{line_num -1} rows loaded ...') + transaction_counter = 0 + # commit last transaction + if split: + tx.commit() + + for rel, count in relationships_created.items(): + self.log.info('{} {} relationship(s) loaded'.format(count, rel)) + if int_nodes_created > 0: + self.log.info('{} intermediate node(s) loaded'.format(int_nodes_created)) + + return True + + @staticmethod + def get_relationship_prop_statements(props): + prop_stmts = [] + + for key in props: + prop_stmts.append('r.{0} = {{{0}}}'.format(key)) + return prop_stmts + + def wipe_db(self, session, split=False): + if split: + return self.wipe_db_split(session) + else: + cleanup_db = 'MATCH (n) DETACH DELETE n' + result = session.run(cleanup_db).summary() + self.nodes_deleted = result.counters.nodes_deleted + self.relationships_deleted = result.counters.relationships_deleted + self.log.info('{} nodes deleted!'.format(self.nodes_deleted)) + self.log.info('{} relationships deleted!'.format(self.relationships_deleted)) + + def wipe_db_split(self, session): + while True: + tx = session.begin_transaction() + try: + cleanup_db = f'MATCH (n) WITH n LIMIT {BATCH_SIZE} DETACH DELETE n' + result = session.run(cleanup_db).summary() + tx.commit() + deleted_nodes = result.counters.nodes_deleted + self.nodes_deleted += deleted_nodes + deleted_relationships = result.counters.relationships_deleted + self.relationships_deleted += deleted_relationships + self.log.info(f'{deleted_nodes} nodes deleted...') + self.log.info(f'{deleted_relationships} relationships deleted...') + if deleted_nodes == 0 and deleted_relationships == 0: + break + except Exception as e: + tx.rollback() + self.log.exception(e) + raise e + self.log.info('{} nodes deleted!'.format(self.nodes_deleted)) + self.log.info('{} relationships deleted!'.format(self.relationships_deleted)) + + def create_indexes(self, session): + """ + Creates 
indexes, if they do not already exist, for all entries in the "id_fields" and "indexes" sections of the + properties file + + :param session: the current neo4j transaction session + """ + existing = get_indexes(session) + # Create indexes from "id_fields" section of the properties file + ids = self.schema.props.id_fields + for node_name in ids: + self.create_index(node_name, ids[node_name], existing, session) + # Create indexes from "indexes" section of the properties file + indexes = self.schema.props.indexes + # each index is a dictionary, indexes is a list of these dictionaries + # for each dictionary in list + for node_dict in indexes: + node_name = list(node_dict.keys())[0] + self.create_index(node_name, node_dict[node_name], existing, session) + + def create_index(self, node_name, node_property, existing, session): + index_tuple = format_as_tuple(node_name, node_property) + # If node_property is a list of properties, convert to a comma delimited string + if isinstance(node_property, list): + node_property = ",".join(node_property) + if index_tuple not in existing: + command = "CREATE INDEX ON :{}({});".format(node_name, node_property) + session.run(command) + self.indexes_created += 1 + self.log.info("Index created for \"{}\" on property \"{}\"".format(node_name, node_property)) + diff --git a/file_copier_config.py b/file_copier_config.py index 2ef7aa05..4ec6cda2 100644 --- a/file_copier_config.py +++ b/file_copier_config.py @@ -1,7 +1,7 @@ import argparse import os -from bento.common.config_base import BentoConfig +from config_base import BentoConfig MASTER_MODE = 'master' SLAVE_MODE = 'slave' diff --git a/file_loader.py b/file_loader.py index f84fcefb..026724cf 100755 --- a/file_loader.py +++ b/file_loader.py @@ -22,11 +22,11 @@ from bento.common.utils import UUID, NODES_CREATED, RELATIONSHIP_CREATED, removeTrailingSlash,\ get_logger, UPSERT_MODE, send_slack_message -from bento.common.config import BentoConfig -from bento.common.props import Props +from 
config import BentoConfig +from props import Props from bento.common.sqs import Queue, VisibilityExtender -from bento.common.data_loader import DataLoader -from bento.common.icdc_schema import ICDC_Schema +from data_loader import DataLoader +from icdc_schema import ICDC_Schema RAW_PREFIX = 'RAW' FINAL_PREFIX = 'Final' diff --git a/icdc_schema.py b/icdc_schema.py new file mode 100644 index 00000000..a642b12c --- /dev/null +++ b/icdc_schema.py @@ -0,0 +1,583 @@ +from datetime import datetime +import os +import re +import sys + +import yaml + +from bento.common.utils import get_logger, MULTIPLIER, DEFAULT_MULTIPLIER, RELATIONSHIP_TYPE, DATE_FORMAT, get_uuid +from props import Props + +NODES = 'Nodes' +RELATIONSHIPS = 'Relationships' +PROPERTIES = 'Props' +PROP_DEFINITIONS = 'PropDefinitions' +DEFAULT_TYPE = 'String' +PROP_TYPE = 'Type' +END_POINTS = 'Ends' +SRC = 'Src' +DEST = 'Dst' +VALUE_TYPE = 'value_type' +ITEM_TYPE = 'item_type' +LIST_DELIMITER = '*' +LABEL_NEXT = 'next' +NEXT_RELATIONSHIP = 'next' +UNITS = 'units' +REQUIRED = 'Req' +PRIVATE = 'Private' +NODE_TYPE = 'type' +ENUM = 'enum' +DEFAULT_VALUE = 'default_value' +HAS_UNIT = 'has_unit' +MIN = 'minimum' +MAX = 'maximum' +EX_MIN = 'exclusiveMinimum' +EX_MAX = 'exclusiveMaximum' + + + +class ICDC_Schema: + def __init__(self, yaml_files, props): + assert isinstance(props, Props) + self.props = props + self.rel_prop_delimiter = props.rel_prop_delimiter + + if not yaml_files: + raise Exception('File list is empty, couldn\'t initialize ICDC_Schema object!') + sys.exit(1) + else: + for data_file in yaml_files: + if not os.path.isfile(data_file): + raise Exception('File "{}" doesn\'t exist'.format(data_file)) + self.log = get_logger('ICDC Schema') + self.org_schema = {} + for aFile in yaml_files: + try: + self.log.info('Reading schema file: {} ...'.format(aFile)) + if os.path.isfile(aFile): + with open(aFile) as schema_file: + schema = yaml.safe_load(schema_file) + if schema: + self.org_schema.update(schema) + 
except Exception as e: + self.log.exception(e) + + self.nodes = {} + self.relationships = {} + self.relationship_props = {} + self.num_relationship = 0 + + self.log.debug("-------------processing nodes-----------------") + if NODES not in self.org_schema: + self.log.error('Can\'t load any nodes!') + sys.exit(1) + + elif PROP_DEFINITIONS not in self.org_schema: + self.log.error('Can\'t load any properties!') + sys.exit(1) + + for key, value in self.org_schema[NODES].items(): + # Assume all keys start with '_' are not regular nodes + if not key.startswith('_'): + self.process_node(key, value) + self.log.debug("-------------processing edges-----------------") + if RELATIONSHIPS in self.org_schema: + for key, value in self.org_schema[RELATIONSHIPS].items(): + # Assume all keys start with '_' are not regular nodes + if not key.startswith('_'): + self.process_node(key, value, True) + self.num_relationship += self.process_edges(key, value) + + def get_uuid_for_node(self, node_type, signature): + """Generate V5 UUID for a node + Arguments: + node_type - a string represents type of a node, e.g. case, study, file etc. + signature - a string that can uniquely identify a node within it's type, e.g. case_id, clinical_study_designation etc. 
+ or a long string with all properties and values concat together if no id available + + """ + return get_uuid(self.props.domain, node_type, signature) + + def _process_properties(self, desc): + ''' + Gather properties from description + + :param desc: description of properties + :return: a dict with properties, required property list and private property list + ''' + props = {} + required = set() + private = set() + if PROPERTIES in desc and desc[PROPERTIES] is not None: + for prop in desc[PROPERTIES]: + prop_type = self.get_type(prop) + props[prop] = prop_type + value_unit_props = self.process_value_unit_type(prop, prop_type) + if value_unit_props: + props.update(value_unit_props) + if self.is_required_prop(prop): + required.add(prop) + if self.is_private_prop(prop): + private.add(prop) + + return {PROPERTIES: props, REQUIRED: required, PRIVATE: private} + + def process_node(self, name, desc, isRelationship=False): + ''' + Process input node/relationship properties and save it in self.nodes + + :param name: node/relationship name + :param desc: + :param isRelationship: if input is a relationship + :return: + ''' + properties = self._process_properties(desc) + + + # All nodes and relationships that has properties will be save to self.nodes + # Relationship without properties will be ignored + if properties[PROPERTIES] or not isRelationship: + self.nodes[name] = properties + + def process_edges(self, name, desc): + count = 0 + if MULTIPLIER in desc: + multiplier = desc[MULTIPLIER] + else: + multiplier = DEFAULT_MULTIPLIER + + properties = self._process_properties(desc) + self.relationship_props[name] = properties + + if END_POINTS in desc: + for end_points in desc[END_POINTS]: + src = end_points[SRC] + dest = end_points[DEST] + if MULTIPLIER in end_points: + actual_multiplier = end_points[MULTIPLIER] + self.log.debug('End point multiplier: "{}" overriding relationship multiplier: "{}"'.format(actual_multiplier, multiplier)) + else: + actual_multiplier = multiplier 
+ if src not in self.relationships: + self.relationships[src] = {} + self.relationships[src][dest] = { RELATIONSHIP_TYPE: name, MULTIPLIER: actual_multiplier } + + count += 1 + if src in self.nodes: + self.add_relationship_to_node(src, actual_multiplier, name, dest) + # nodes[src][self.plural(dest)] = '[{}] @relation(name:"{}")'.format(dest, name) + else: + self.log.error('Source node "{}" not found!'.format(src)) + if dest in self.nodes: + self.add_relationship_to_node(dest, actual_multiplier, name, src, True) + # nodes[dest][self.plural(src)] = '[{}] @relation(name:"{}", direction:IN)'.format(src, name) + else: + self.log.error('Destination node "{}" not found!'.format(dest)) + return count + + # Process singular/plural array/single value based on relationship multipliers like many-to-many, many-to-one etc. + # Return a relationship property to add into a node + def add_relationship_to_node(self, name, multiplier, relationship, otherNode, dest=False): + node = self.nodes[name] + if multiplier == 'many_to_one': + if dest: + node[PROPERTIES][self.plural(otherNode)] = { PROP_TYPE: '[{}] @relation(name:"{}", direction:IN)'.format(otherNode, relationship) } + else: + node[PROPERTIES][otherNode] = {PROP_TYPE: '{} @relation(name:"{}", direction:OUT)'.format(otherNode, relationship) } + elif multiplier == 'one_to_one': + if relationship == NEXT_RELATIONSHIP: + if dest: + node[PROPERTIES]['prior_' + otherNode] = {PROP_TYPE: '{} @relation(name:"{}", direction:IN)'.format(otherNode, relationship) } + else: + node[PROPERTIES]['next_' + otherNode] = {PROP_TYPE: '{} @relation(name:"{}", direction:OUT)'.format(otherNode, relationship) } + else: + if dest: + node[PROPERTIES][otherNode] = {PROP_TYPE: '{} @relation(name:"{}", direction:IN)'.format(otherNode, relationship) } + else: + node[PROPERTIES][otherNode] = {PROP_TYPE: '{} @relation(name:"{}", direction:OUT)'.format(otherNode, relationship) } + elif multiplier == 'many_to_many': + if dest: + 
node[PROPERTIES][self.plural(otherNode)] = {PROP_TYPE: '[{}] @relation(name:"{}", direction:IN)'.format(otherNode, relationship) } + else: + node[PROPERTIES][self.plural(otherNode)] = {PROP_TYPE: '[{}] @relation(name:"{}", direction:OUT)'.format(otherNode, relationship) } + else: + self.log.warning('Unsupported relationship multiplier: "{}"'.format(multiplier)) + + def is_required_prop(self, name): + result = False + if name in self.org_schema[PROP_DEFINITIONS]: + prop = self.org_schema[PROP_DEFINITIONS][name] + result = prop.get(REQUIRED, False) + return result + + def is_private_prop(self, name): + result = False + if name in self.org_schema[PROP_DEFINITIONS]: + prop = self.org_schema[PROP_DEFINITIONS][name] + result = prop.get(PRIVATE, False) + return result + + def get_prop_type(self, node_type, prop): + if node_type in self.nodes: + node = self.nodes[node_type] + if prop in node[PROPERTIES]: + return node[PROPERTIES][prop][PROP_TYPE] + return DEFAULT_TYPE + + def get_type(self, name): + result = { PROP_TYPE: DEFAULT_TYPE } + if name in self.org_schema[PROP_DEFINITIONS]: + prop = self.org_schema[PROP_DEFINITIONS][name] + if PROP_TYPE in prop: + prop_desc = prop[PROP_TYPE] + if isinstance(prop_desc, str): + result[PROP_TYPE] = self.map_type(prop_desc) + elif isinstance(prop_desc, dict): + if VALUE_TYPE in prop_desc: + result[PROP_TYPE] = self.map_type(prop_desc[VALUE_TYPE]) + if ITEM_TYPE in prop_desc: + item_type = self._get_item_type(prop_desc[ITEM_TYPE]) + result[ITEM_TYPE] = item_type + if UNITS in prop_desc: + result[HAS_UNIT] = True + elif isinstance(prop_desc, list): + enum = set() + for t in prop_desc: + if not re.search(r'://', t): + enum.add(t) + if len(enum) > 0: + result[ENUM] = enum + else: + self.log.debug('Property type: "{}" not supported, use default type: "{}"'.format(prop_desc, DEFAULT_TYPE)) + + # Add value boundary support + if MIN in prop: + result[MIN] = float(prop[MIN]) + if MAX in prop: + result[MAX] = float(prop[MAX]) + if EX_MIN in 
prop: + result[EX_MIN] = float(prop[EX_MIN]) + if EX_MAX in prop: + result[EX_MAX] = float(prop[EX_MAX]) + + return result + + def _get_item_type(self, item_type): + if isinstance(item_type, str): + return {PROP_TYPE: self.map_type(item_type)} + elif isinstance(item_type, list): + enum = set() + for t in item_type: + if not re.search(r'://', t): + enum.add(t) + if len(enum) > 0: + return {PROP_TYPE: DEFAULT_TYPE, ENUM: enum} + else: + return None + else: + self.log.error(f"{item_type} is not a scala or Enum!") + return None + + def get_prop(self, node_name, name): + if node_name in self.nodes: + node = self.nodes[node_name] + if name in node[PROPERTIES]: + return node[PROPERTIES][name] + return None + + def get_default_value(self, node_name, name): + prop = self.get_prop(node_name, name) + if prop: + return prop.get(DEFAULT_VALUE, None) + + def get_default_unit(self, node_name, name): + unit_prop_name = self.get_unit_property_name(name) + return self.get_default_value(node_name, unit_prop_name) + + + def get_valid_values(self, node_name, name): + prop = self.get_prop(node_name, name) + if prop: + return prop.get(ENUM, None) + + def get_valid_units(self, node_name, name): + unit_prop_name = self.get_unit_property_name(name) + return self.get_valid_values(node_name, unit_prop_name) + + def get_extra_props(self, node_name, name, value): + results = {} + prop = self.get_prop(node_name, name) + if prop and HAS_UNIT in prop and prop[HAS_UNIT]: + # For MVP use default unit for all values + results[self.get_unit_property_name(name)] = self.get_default_unit(node_name, name) + org_prop_name = self.get_original_value_property_name(name) + # For MVP use value is same as original value + results[org_prop_name] = value + results[self.get_unit_property_name(org_prop_name)] = self.get_default_unit(node_name, name) + return results + + def process_value_unit_type(self, name, prop_type): + results = {} + if name in self.org_schema[PROP_DEFINITIONS]: + prop = 
self.org_schema[PROP_DEFINITIONS][name] + if PROP_TYPE in prop: + prop_desc = prop[PROP_TYPE] + if isinstance(prop_desc, dict): + if UNITS in prop_desc: + units = prop_desc[UNITS] + if units: + enum = set(units) + unit_prop_name = self.get_unit_property_name(name) + results[unit_prop_name] = {PROP_TYPE: DEFAULT_TYPE, ENUM: enum, DEFAULT_VALUE: units[0]} + org_prop_name = self.get_original_value_property_name(name) + org_unit_prop_name = self.get_unit_property_name(org_prop_name) + results[org_prop_name] = prop_type + results[org_unit_prop_name] = {PROP_TYPE: DEFAULT_TYPE, ENUM: enum, DEFAULT_VALUE: units[0]} + return results + + @staticmethod + def get_unit_property_name(name): + return name + '_unit' + + @staticmethod + def get_original_value_property_name(name): + return name + '_original' + + def validate_node(self, model_type, obj): + if not model_type or model_type not in self.nodes: + return {'result': False, 'messages': ['Node type: "{}" doesn\'t exist!'.format(model_type)]} + if not obj: + return {'result': False, 'messages': ['Node is empty!']} + + if not isinstance(obj, dict): + return {'result': False, 'messages': ['Node is not a dict!']} + + # Make sure all required properties exist, and are not empty + result = {'result': True, 'messages': []} + for prop in self.nodes[model_type].get(REQUIRED, set()): + if prop not in obj: + result['result'] = False + result['messages'].append('Missing required property: "{}"!'.format(prop)) + elif not obj[prop]: + result['result'] = False + result['messages'].append('Required property: "{}" is empty!'.format(prop)) + + properties = self.nodes[model_type][PROPERTIES] + # Validate all properties in given object + for key, value in obj.items(): + if key == NODE_TYPE: + continue + elif self.is_parent_pointer(key): + continue + elif self.is_relationship_property(key): + rel_type, rel_prop = key.split(self.rel_prop_delimiter) + if rel_type not in self.relationship_props: + result['result'] = False + 
result['messages'].append(f'Relationship "{rel_type}" does NOT exist in data model!') + continue + elif rel_prop not in self.relationship_props[rel_type][PROPERTIES]: + result['result'] = False + result['messages'].append(f'Property "{rel_prop}" does NOT exist in relationship "{rel_type}"!') + continue + + prop_type = self.relationship_props[rel_type][PROPERTIES][rel_prop] + if not self._validate_type(prop_type, value): + result['result'] = False + result['messages'].append( + 'Property: "{}":"{}" is not a valid "{}" type!'.format(rel_prop, value, prop_type)) + + elif key not in properties: + self.log.warn('Property "{}" is not in data model!'.format(key)) + else: + prop_type = properties[key] + if not self._validate_type(prop_type, value): + result['result'] = False + result['messages'].append('Property: "{}":"{}" is not a valid "{}" type!'.format(key, value, prop_type)) + + return result + + @staticmethod + def _validate_value_range(model_type, value): + ''' + Validate an int of float value, return whether value is in range + + :param model_type: dict specify value type and boundary/range + :param value: value to be validated + :return: boolean + ''' + + if MIN in model_type: + if value < model_type[MIN]: + return False + if MAX in model_type: + if value > model_type[MAX]: + return False + if EX_MIN in model_type: + if value <= model_type[EX_MIN]: + return False + if EX_MAX in model_type: + if value >= model_type[EX_MAX]: + return False + return True + + def _validate_type(self, model_type, str_value): + if model_type[PROP_TYPE] == 'Float': + try: + if str_value: + value = float(str_value) + if not self._validate_value_range(model_type, value): + return False + except ValueError: + return False + elif model_type[PROP_TYPE] == 'Int': + try: + if str_value: + value = int(str_value) + if not self._validate_value_range(model_type, value): + return False + except ValueError: + return False + elif model_type[PROP_TYPE] == 'Boolean': + if (str_value and not 
re.match(r'\byes\b|\btrue\b', str_value, re.IGNORECASE) + and not re.match(r'\bno\b|\bfalse\b', str_value, re.IGNORECASE) + and not re.match(r'\bltf\b', str_value, re.IGNORECASE)): + return False + elif model_type[PROP_TYPE] == 'Array': + for item in self.get_list_values(str_value): + if not self._validate_type(model_type[ITEM_TYPE], item): + return False + + elif model_type[PROP_TYPE] == 'Object': + if not isinstance(str_value, dict): + return False + elif model_type[PROP_TYPE] == 'String': + if ENUM in model_type: + if not isinstance(str_value, str): + return False + if str_value != '' and str_value not in model_type[ENUM]: + return False + elif model_type[PROP_TYPE] == 'Date': + if not isinstance(str_value, str): + return False + try: + if str_value.strip() != '': + datetime.strptime(str_value, DATE_FORMAT) + except ValueError: + return False + elif model_type[PROP_TYPE] == 'DateTime': + if not isinstance(str_value, str): + return False + try: + if str_value.strip() != '': + datetime.strptime(str_value, DATE_FORMAT) + except ValueError: + return False + return True + + def get_list_values(self, list_str): + return [item.strip() for item in list_str.split(LIST_DELIMITER)] + + # Find relationship type from src to dest + def get_relationship(self, src, dest): + if src in self.relationships: + relationships = self.relationships[src] + if relationships and dest in relationships: + return relationships[dest] + else: + self.log.error('No relationships found for "{}"-->"{}"'.format(src, dest)) + return None + else: + self.log.debug('No relationships start from "{}"'.format(src)) + return None + + # Find destination node name from (:src)-[:name]->(:dest) + def get_dest_node_for_relationship(self, src, name): + if src in self.relationships: + relationships = self.relationships[src] + if relationships: + for dest, rel in relationships.items(): + if rel[RELATIONSHIP_TYPE] == name: + return dest + else: + self.log.error('Couldn\'t find any relationship from 
(:{})'.format(src)) + return None + + + # Get type info from description + def map_type(self, type_name): + mapping = self.props.type_mapping + result = DEFAULT_TYPE + + if type_name in mapping: + result = mapping[type_name] + else: + self.log.debug('Type: "{}" has no mapping, use default type: "{}"'.format(type_name, DEFAULT_TYPE)) + + return result + + def plural(self, word): + plurals = self.props.plurals + if word in plurals: + return plurals[word] + else: + self.log.warning('Plural for "{}" not found!'.format(word)) + return 'NONE' + + # Get all node names, sorted + def get_node_names(self): + return sorted(self.nodes.keys()) + + def node_count(self): + return len(self.nodes) + + def relationship_count(self): + return self.num_relationship + + # Get all properties of a node (name) + def get_props_for_node(self, node_name): + if node_name in self.nodes: + return self.nodes[node_name][PROPERTIES] + else: + return None + + # Get all properties of a node (name) + def get_public_props_for_node(self, node_name): + if node_name in self.nodes: + props = self.nodes[node_name][PROPERTIES].copy() + for private_prop in self.nodes[node_name].get(PRIVATE, []): + del(props[private_prop]) + self.log.info('Delete private property: "{}"'.format(private_prop)) + return props + else: + return None + + # Get node's id field, such as case_id for case node, or clinical_study_designation for study node + def get_id_field(self, obj): + if NODE_TYPE not in obj: + self.log.error('get_id_field: there is no "{}" field in node, can\'t retrieve id!'.format(NODE_TYPE)) + return None + node_type = obj[NODE_TYPE] + id_fields = self.props.id_fields + if node_type: + return id_fields.get(node_type, 'uuid') + else: + self.log.error('get_id_field: "{}" field is empty'.format(NODE_TYPE)) + return None + + # Find node's id + def get_id(self, obj): + id_field = self.get_id_field(obj) + if not id_field: + return None + if id_field not in obj: + return None + else: + return obj[id_field] + + def 
is_relationship_property(self, key): + return re.match(r'^.+\{}.+$'.format(self.rel_prop_delimiter), key) + + + def is_parent_pointer(self, field_name): + return re.fullmatch(r'\w+\.\w+', field_name) is not None + diff --git a/loader.py b/loader.py index 39cc8ddb..97fdf9e1 100755 --- a/loader.py +++ b/loader.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 import argparse -import datetime import glob import os import sys @@ -8,19 +7,18 @@ from neo4j import GraphDatabase, ServiceUnavailable from neobolt.exceptions import AuthError -from bento.common.icdc_schema import ICDC_Schema -from bento.common.props import Props -from bento.common.utils import get_logger, removeTrailingSlash, check_schema_files, DATETIME_FORMAT, get_host, \ - UPSERT_MODE, NEW_MODE, DELETE_MODE, get_log_file, LOG_PREFIX, APP_NAME -from bento.common.visit_creator import VisitCreator +from icdc_schema import ICDC_Schema +from props import Props +from bento.common.utils import get_logger, removeTrailingSlash, check_schema_files, UPSERT_MODE, NEW_MODE, DELETE_MODE, get_log_file, LOG_PREFIX, APP_NAME +from visit_creator import VisitCreator if LOG_PREFIX not in os.environ: os.environ[LOG_PREFIX] = 'Data_Loader' os.environ[APP_NAME] = 'Data_Loader' -from bento.common.config import BentoConfig -from bento.common.data_loader import DataLoader +from config import BentoConfig +from data_loader import DataLoader from bento.common.s3 import S3Bucket diff --git a/model-converter.py b/model-converter.py index b4a7cf75..2b792a83 100755 --- a/model-converter.py +++ b/model-converter.py @@ -5,8 +5,8 @@ import argparse import os import sys -from bento.common.icdc_schema import ICDC_Schema, PROP_TYPE -from bento.common.props import Props +from icdc_schema import ICDC_Schema, PROP_TYPE +from props import Props from bento.common.utils import check_schema_files, get_logger diff --git a/props.py b/props.py new file mode 100644 index 00000000..04b2921d --- /dev/null +++ b/props.py @@ -0,0 +1,25 @@ +import os +import yaml +from 
bento.common.utils import get_logger + +class Props: + def __init__(self, file_name): + self.log = get_logger('Props') + if file_name and os.path.isfile(file_name): + with open(file_name) as prop_file: + props = yaml.safe_load(prop_file)['Properties'] + if not props: + msg = 'Can\'t read property file!' + self.log.error(msg) + raise Exception(msg) + self.plurals = props.get('plurals', {}) + self.type_mapping = props.get('type_mapping', {}) + self.id_fields = props.get('id_fields', {}) + self.visit_date_in_nodes = props.get('visit_date_in_nodes', {}) + self.domain = props.get('domain', 'Unknown.domain.nci.nih.gov') + self.rel_prop_delimiter = props.get('rel_prop_delimiter', '$') + self.indexes = props.get('indexes', []) + else: + msg = f'Can NOT open file: "{file_name}"' + self.log.error(msg) + raise Exception(msg) diff --git a/tests/test_file_loader.py b/tests/test_file_loader.py index 0afb82bd..21b262e6 100644 --- a/tests/test_file_loader.py +++ b/tests/test_file_loader.py @@ -3,10 +3,10 @@ import os from neo4j import GraphDatabase from file_loader import FileLoader -from bento.common.icdc_schema import ICDC_Schema -from bento.common.props import Props -from bento.common.config import BentoConfig -from bento.common.data_loader import DataLoader +from icdc_schema import ICDC_Schema +from props import Props +from config import BentoConfig +from data_loader import DataLoader class TestLambda(unittest.TestCase): @@ -51,4 +51,4 @@ def test_lambda(self): load_result = self.loader.load(self.file_list, True, False, 'upsert', False, 1) self.assertIsInstance(load_result, dict, msg='Load data failed!') - self.assertTrue(self.processor.handler(self.event)) \ No newline at end of file + self.assertTrue(self.processor.handler(self.event)) diff --git a/tests/test_loader.py b/tests/test_loader.py index f231b36c..9a566e1b 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,9 +1,9 @@ import unittest import os from bento.common.utils import get_logger, 
removeTrailingSlash, UUID -from bento.common.data_loader import DataLoader -from bento.common.icdc_schema import ICDC_Schema -from bento.common.props import Props +from data_loader import DataLoader +from icdc_schema import ICDC_Schema +from props import Props from neo4j import GraphDatabase @@ -115,4 +115,4 @@ def test_cleanup_node(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_reloading_data.py b/tests/test_reloading_data.py index 680a5fc5..5685fbbe 100644 --- a/tests/test_reloading_data.py +++ b/tests/test_reloading_data.py @@ -1,8 +1,8 @@ import unittest from bento.common.utils import get_logger, NODES_CREATED, RELATIONSHIP_CREATED, NODES_DELETED, RELATIONSHIP_DELETED -from bento.common.data_loader import DataLoader -from bento.common.icdc_schema import ICDC_Schema -from bento.common.props import Props +from data_loader import DataLoader +from icdc_schema import ICDC_Schema +from props import Props import os from neo4j import GraphDatabase @@ -136,4 +136,4 @@ def test_reload_upsert(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_schema.py b/tests/test_schema.py index dd978d4e..30baf718 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,6 +1,6 @@ import unittest -from bento.common.icdc_schema import ICDC_Schema -from bento.common.props import Props +from icdc_schema import ICDC_Schema +from props import Props class TestSchema(unittest.TestCase): @@ -50,4 +50,4 @@ def test_get_id_field(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/visit_creator.py b/visit_creator.py new file mode 100644 index 00000000..2c1a29a5 --- /dev/null +++ b/visit_creator.py @@ -0,0 +1,188 @@ +from datetime import datetime, timedelta + +from neo4j import Session, Transaction + +from icdc_schema import ICDC_Schema +from bento.common.utils import get_logger, UUID, 
VISIT_NODE = 'visit'
VISIT_ID = 'visit_id'
VISIT_DATE = 'visit_date'
OF_CYCLE = 'of_cycle'
CYCLE_NODE = 'cycle'
INFERRED = 'inferred'
START_DATE = 'date_of_cycle_start'
END_DATE = 'date_of_cycle_end'
CYCLE_ID = 'cycle_id'

# A visit dated up to PREDATE days before a case's first cycle is still
# attached to that first cycle.
PREDATE = 7
# Sentinel end date used for cycles that have no recorded end date.
FOREVER = '9999-12-31'

# duplicated declaration from data_loader.py
NODE_TYPE = 'type'
CREATED = 'created'
UPDATED = 'updated'
CASE_ID = 'case_id'
CASE_NODE = 'case'


# Intermediate node creator
class VisitCreator:
    """Creates inferred intermediate (:visit) nodes in Neo4j and connects them
    to the (:cycle) they fall into, or directly to their (:case) as a fallback.

    Counters for created nodes/relationships are accumulated on the instance
    (nodes_created, relationships_created, nodes_stat, relationships_stat).
    """

    def __init__(self, schema):
        """schema must be an ICDC_Schema instance; anything else raises."""
        if not schema or not isinstance(schema, ICDC_Schema):
            raise Exception('Invalid ICDC_Schema object')
        self.schema = schema
        self.log = get_logger('VisitCreator')
        self.nodes_created = 0
        self.relationships_created = 0
        self.nodes_stat = {}
        self.relationships_stat = {}
        # Dictionary to cache case IDs and their associated cycles in order to
        # prevent redundant querying
        self.cycle_map = {}

    def is_valid_int_node(self, node_type):
        """Return True if node_type is an intermediate type this creator handles."""
        return node_type == VISIT_NODE

    def create_intermediate_node(self, session, line_num, node_type, node_id, src):
        """MERGE an inferred (:visit) node derived from source record src.

        session:   neo4j Session or Transaction
        line_num:  input file line number, used only in log messages
        node_type: must be VISIT_NODE, otherwise nothing is created
        node_id:   visit_id value for the merged node
        src:       dict-like source record; must carry a 'type' field and the
                   visit-date field configured for that type in the schema props

        Returns True when the MERGE statement ran (whether or not a new node
        was actually created), False on any validation failure.
        """
        if node_type != VISIT_NODE:
            # BUG FIX: original format string had two placeholders but three
            # arguments (line_num, VISIT_NODE, node_type), so it always logged
            # 'visit' instead of the actual rejected type.
            self.log.debug("Line: {}: Won't create node for type: '{}'".format(line_num, node_type))
            return False
        if not node_id:
            self.log.error("Line: {}: Can't create (:{}) node for id: '{}'".format(line_num, VISIT_NODE, node_id))
            return False
        if not src:
            self.log.error("Line: {}: Can't create (:{}) node for empty object".format(line_num, VISIT_NODE))
            return False
        if not session or (not isinstance(session, Session) and not isinstance(session, Transaction)):
            self.log.error("Neo4j session is not valid!")
            return False
        if NODE_TYPE not in src:
            self.log.error('Line: {}: Given object doesn\'t have a "{}" field!'.format(line_num, NODE_TYPE))
            return False
        source_type = src[NODE_TYPE]
        # Maps each source node type to the name of its visit-date property
        date_map = self.schema.props.visit_date_in_nodes
        if source_type not in date_map:
            # Guard added: the original raised an uncaught KeyError for source
            # types with no configured visit-date field.
            self.log.error('Line: {}: No visit date field configured for type "{}"!'.format(line_num, source_type))
            return False
        # .get() instead of []: a record missing the date field now falls into
        # the "Visit date is empty" error path instead of raising KeyError.
        date = src.get(date_map[source_type])
        if not date:
            self.log.error('Line: {}: Visit date is empty!'.format(line_num))
            return False
        # (A second, duplicated `NODE_TYPE not in src` check was removed here —
        # it was dead code, running after src[NODE_TYPE] had been read.)

        # MERGE keyed on visit_id/visit_date/inferred/uuid so re-loading is
        # idempotent; {param} placeholders are Cypher parameters (driver 1.x).
        statement = 'MERGE (v:{} {{ {}: {{node_id}}, {}: {{date}}, {}: true, {}: {{{}}} }})'.format(
            VISIT_NODE, VISIT_ID, VISIT_DATE, INFERRED, UUID, UUID)
        statement += ' ON CREATE SET v.{} = datetime()'.format(CREATED)
        statement += ' ON MATCH SET v.{} = datetime()'.format(UPDATED)

        result = session.run(statement, {"node_id": node_id, "date": date,
                                         UUID: self.schema.get_uuid_for_node(VISIT_NODE, node_id)})
        if not result:
            return False
        count = result.summary().counters.nodes_created
        self.nodes_created += count
        self.nodes_stat[VISIT_NODE] = self.nodes_stat.get(VISIT_NODE, 0) + count
        if count > 0:
            # A freshly created visit must be attached to one of its case's
            # cycles (or to the case itself as a fallback).
            case_id = src.get(CASE_ID)
            if not self.connect_visit_to_cycle(session, line_num, node_id, case_id, date):
                self.log.error('Line: {}: Visit: "{}" does NOT belong to a cycle!'.format(line_num, node_id))
        return True

    def _get_cycles(self, session, case_id):
        """Return the cached list of cycle-info dicts for case_id.

        Each dict holds START_DATE (datetime), END_DATE (datetime or None) and
        CYCLE_ID (the cycle node's internal Neo4j id). The Neo4j query runs at
        most once per case; results are cached in self.cycle_map.
        """
        if case_id not in self.cycle_map:
            cycles = []
            find_cycles_stmt = 'MATCH (c:cycle) WHERE c.case_id = {case_id} RETURN c ORDER BY c.date_of_cycle_start'
            result = session.run(find_cycles_stmt, {'case_id': case_id})
            if result:
                for record in result.records():
                    # Retrieves the cycle node from the record
                    cycle = record.data()['c']
                    start_date = datetime.strptime(cycle[START_DATE], DATE_FORMAT)
                    end_date = datetime.strptime(cycle[END_DATE], DATE_FORMAT) if cycle[END_DATE] else None
                    cycles.append({
                        START_DATE: start_date,
                        END_DATE: end_date,
                        # NOTE(review): relies on record.data() preserving Node
                        # objects (driver 1.x behavior) so .id is available
                        CYCLE_ID: cycle.id,
                    })
            self.cycle_map[case_id] = cycles
        return self.cycle_map[case_id]

    def connect_visit_to_cycle(self, session, line_num, visit_id, case_id, visit_date):
        """Connect an inferred visit to the cycle whose date range contains it.

        A visit belongs to a cycle when start <= visit_date <= end (cycles with
        no end date are treated as ending FOREVER). A visit dated up to PREDATE
        days before the case's first cycle is attached to that first cycle.
        Visits matching no cycle are connected directly to the case instead.
        Returns True on success, False otherwise.
        """
        cycles = self._get_cycles(session, case_id)
        if not cycles:
            self.log.error('Line: {}: No cycles found for case: {}'.format(line_num, case_id))
            return False
        relationship_name = self.schema.get_relationship(VISIT_NODE, CYCLE_NODE)[RELATIONSHIP_TYPE]
        if not relationship_name:
            return False
        # Loop invariants hoisted: the original re-parsed visit_date on every
        # iteration. Cycles are ordered by start date, so cycles[0] is first.
        date = datetime.strptime(visit_date, DATE_FORMAT)
        first_date = cycles[0][START_DATE]
        pre_date = first_date - timedelta(days=PREDATE)
        for cycle_data in cycles:
            start_date = cycle_data[START_DATE]
            end_date = cycle_data[END_DATE]
            if not end_date:
                self.log.warning('Line: {}: No end dates for cycle started on {} for {}'.format(
                    line_num, start_date.strftime(DATE_FORMAT), case_id))
                end_date = datetime.strptime(FOREVER, DATE_FORMAT)
            if (start_date <= date <= end_date) or (first_date > date >= pre_date):
                if first_date > date >= pre_date:
                    self.log.info(
                        'Line: {}: Date: {} is before first cycle, but within {}'.format(line_num, visit_date,
                                                                                         PREDATE)
                        + ' days before first cycle started: {}, connected to first cycle'.format(
                            first_date.strftime(DATE_FORMAT)))
                connect_stmt = 'MATCH (v:{} {{ {}: {{visit_id}} }}) '.format(VISIT_NODE, VISIT_ID)
                connect_stmt += 'MATCH (c:{}) WHERE id(c) = {{cycle_id}} '.format(CYCLE_NODE)
                connect_stmt += 'MERGE (v)-[r:{} {{ {}: true }}]->(c)'.format(relationship_name, INFERRED)
                connect_stmt += ' ON CREATE SET r.{} = datetime()'.format(CREATED)
                connect_stmt += ' ON MATCH SET r.{} = datetime()'.format(UPDATED)

                cnt_result = session.run(connect_stmt, {'visit_id': visit_id, 'cycle_id': cycle_data[CYCLE_ID]})
                relationship_created = cnt_result.summary().counters.relationships_created
                if relationship_created > 0:
                    self.relationships_created += relationship_created
                    self.relationships_stat[relationship_name] = \
                        self.relationships_stat.get(relationship_name, 0) + relationship_created
                    return True
                else:
                    self.log.error(
                        'Line: {}: Create (:visit)-[:of_cycle]->(:cycle) relationship failed!'.format(line_num))
                    return False
        # No cycle contains this date: fall back to connecting the case itself
        self.log.warning('Line: {}: Date: {} does not belong to any cycles, connected to case {} directly!'.format(
            line_num, visit_date, case_id))
        return self.connect_visit_to_case(session, line_num, visit_id, case_id)

    def connect_visit_to_case(self, session, line_num, visit_id, case_id):
        """Connect a visit directly to its case.

        Returns True when the relationship was created, False when the schema
        has no visit->case relationship or the MERGE created nothing.
        """
        relationship_name = self.schema.get_relationship(VISIT_NODE, CASE_NODE)[RELATIONSHIP_TYPE]
        if not relationship_name:
            return False
        cnt_statement = 'MATCH (c:case {{ case_id: {{case_id}} }}) MATCH (v:visit {{ {}: {{visit_id}} }}) '.format(
            VISIT_ID)
        cnt_statement += 'MERGE (c)<-[r:{} {{ {}: true }}]-(v)'.format(relationship_name, INFERRED)
        cnt_statement += ' ON CREATE SET r.{} = datetime()'.format(CREATED)
        cnt_statement += ' ON MATCH SET r.{} = datetime()'.format(UPDATED)

        result = session.run(cnt_statement, {'case_id': case_id, 'visit_id': visit_id})
        relationship_created = result.summary().counters.relationships_created
        if relationship_created > 0:
            self.relationships_created += relationship_created
            self.relationships_stat[relationship_name] = \
                self.relationships_stat.get(relationship_name, 0) + relationship_created
            return True
        else:
            self.log.error('Line: {}: Create (:{})-[:{}]->(:{}) relationship failed!'.format(
                line_num, VISIT_NODE, relationship_name, CASE_NODE))
            return False