From 06c08e6a05e88cc084f277398c5a8872c47d5337 Mon Sep 17 00:00:00 2001 From: Maxime Mulder Date: Thu, 10 Oct 2024 15:34:47 -0400 Subject: [PATCH] Improve subject configuration structure (#1150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refacor subject config * remove suspicious phantom code * change naming * fix rebase lints * cleaner database url * Cecile CandID path join bug fix Co-authored-by: Cécile Madjar --------- Co-authored-by: Cécile Madjar --- dicom-archive/database_config_template.py | 68 +++++----- pyproject.toml | 3 +- python/lib/config_file.py | 80 ++++++++++++ python/lib/database.py | 22 ++-- python/lib/database_lib/candidate_db.py | 2 +- python/lib/db/connect.py | 32 +++-- .../base_pipeline.py | 120 ++++++------------ .../dicom_validation_pipeline.py | 2 +- .../nifti_insertion_pipeline.py | 38 +++--- .../push_imaging_files_to_s3_pipeline.py | 5 +- python/lib/imaging.py | 16 +-- python/lib/lorisgetopt.py | 13 +- python/lib/validate_subject_ids.py | 73 ----------- python/lib/validate_subject_info.py | 40 ++++++ 14 files changed, 258 insertions(+), 256 deletions(-) create mode 100644 python/lib/config_file.py delete mode 100644 python/lib/validate_subject_ids.py create mode 100644 python/lib/validate_subject_info.py diff --git a/dicom-archive/database_config_template.py b/dicom-archive/database_config_template.py index 9715d7dee..9f65d28a9 100644 --- a/dicom-archive/database_config_template.py +++ b/dicom-archive/database_config_template.py @@ -1,43 +1,53 @@ #!/usr/bin/env python import re +from lib.database import Database from lib.imaging import Imaging +from lib.config_file import CreateVisitInfo, DatabaseConfig, S3Config, SubjectInfo -mysql = { - 'host' : 'DBHOST', - 'username': 'DBUSER', - 'passwd' : 'DBPASS', - 'database': 'DBNAME', - 'port' : '' -} -s3 = { - 'aws_access_key_id' : 'AWS_ACCESS_KEY_ID', - 'aws_secret_access_key': 'AWS_SECRET_ACCESS_KEY', - 'aws_s3_endpoint_url' : 'AWS_S3_ENDPOINT', - 'aws_s3_bucket_name' : 'AWS_S3_BUCKET_NAME', -} +mysql: DatabaseConfig = DatabaseConfig( + host = 'DBHOST', + username = 'DBUSER', + password = 'DBPASS', + database = 'DBNAME', + port = 3306, +) +# This statement can be omitted if the project does not use AWS S3. +s3: S3Config = S3Config( + aws_access_key_id = 'AWS_ACCESS_KEY_ID', + aws_secret_access_key = 'AWS_SECRET_ACCESS_KEY', + aws_s3_endpoint_url = 'AWS_S3_ENDPOINT', + aws_s3_bucket_name = 'AWS_S3_BUCKET_NAME', +) -def get_subject_ids(db, dicom_value=None, scanner_id=None): - - subject_id_dict = {} +def get_subject_info(db: Database, subject_name: str, scanner_id: int | None = None) -> SubjectInfo | None: imaging = Imaging(db, False) - phantom_match = re.search(r'(pha)|(test)', dicom_value, re.IGNORECASE) - candidate_match = re.search(r'([^_]+)_(\d+)_([^_]+)', dicom_value, re.IGNORECASE) + phantom_match = re.search(r'(pha)|(test)', subject_name, re.IGNORECASE) + candidate_match = re.search(r'([^_]+)_(\d+)_([^_]+)', subject_name, re.IGNORECASE) if phantom_match: - subject_id_dict['isPhantom'] = True - subject_id_dict['CandID'] = imaging.get_scanner_candid(scanner_id) - subject_id_dict['visitLabel'] = dicom_value.strip() - subject_id_dict['createVisitLabel'] = 1 + return SubjectInfo.from_phantom( + name = subject_name, + # Pass the scanner candidate CandID. If the scanner candidate does not exist in the + # database yet, create it in this function. + cand_id = imaging.get_scanner_candid(scanner_id), + visit_label = subject_name.strip(), + create_visit = CreateVisitInfo( + project_id = 1, # Change to relevant project ID + cohort_id = 1, # Change to relevant cohort ID + ), + ) elif candidate_match: - subject_id_dict['isPhantom'] = False - subject_id_dict['PSCID'] = candidate_match.group(1) - subject_id_dict['CandID'] = candidate_match.group(2) - subject_id_dict['visitLabel'] = candidate_match.group(3) - subject_id_dict['createVisitLabel'] = 0 - - return subject_id_dict + return SubjectInfo.from_candidate( + name = subject_name, + psc_id = candidate_match.group(1), + cand_id = int(candidate_match.group(2)), + visit_label = candidate_match.group(3), + create_visit = None, + ) + + return None diff --git a/pyproject.toml b/pyproject.toml index 3ea22bc20..5a21d8238 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,8 @@ include = [ "python/tests", "python/lib/db", "python/lib/exception", - "python/lib/validate_subject_ids.py", + "python/lib/config_file.py", + "python/lib/validate_subject_info.py", ] typeCheckingMode = "strict" reportMissingTypeStubs = "none" diff --git a/python/lib/config_file.py b/python/lib/config_file.py new file mode 100644 index 000000000..f625d14b7 --- /dev/null +++ b/python/lib/config_file.py @@ -0,0 +1,80 @@ +""" +This module stores the classes used in the Python configuration file of LORIS-MRI. +""" + +from dataclasses import dataclass + + +@dataclass +class DatabaseConfig: + """ + Class wrapping the MariaDB / MySQL database access configuration. + """ + + host: str + username: str + password: str + database: str + port: int = 3306 # Default database port. + + +@dataclass +class S3Config: + """ + Class wrapping AWS S3 access configuration. + """ + + aws_access_key_id: str + aws_secret_access_key: str + aws_s3_endpoint_url: str | None = None # Can also be obtained from the database. + aws_s3_bucket_name: str | None = None # Can also be obtained from the database. + + +@dataclass +class CreateVisitInfo: + """ + Class wrapping the parameters for automated visit creation (in the `Visit_Windows` table). + """ + + project_id: int + cohort_id: int + + +@dataclass +class SubjectInfo: + """ + Dataclass wrapping information about a subject configuration, including information about the + candidate, the visit label, and the automated visit creation (or not). + """ + + # The name of the subject may be either the DICOM's PatientName or PatientID depending on the + # LORIS configuration. + name: str + is_phantom: bool + # For a phantom scan, the PSCID is 'scanner'. + psc_id: str + # For a phantom scan, the CandID is that of the scanner. + cand_id: int + visit_label: str + # `CreateVisitInfo` means that a visit can be created automatically using the parameters + # provided, `None` means that the visit needs to already exist in the database. + create_visit: CreateVisitInfo | None + + @staticmethod + def from_candidate( + name: str, + psc_id: str, + cand_id: int, + visit_label: str, + create_visit: CreateVisitInfo | None, + ): + return SubjectInfo(name, False, psc_id, cand_id, visit_label, create_visit) + + @staticmethod + def from_phantom( + name: str, + cand_id: int, + visit_label: str, + create_visit: CreateVisitInfo | None, + ): + return SubjectInfo(name, True, 'scanner', cand_id, visit_label, create_visit) diff --git a/python/lib/database.py b/python/lib/database.py index ecdda8f25..ad0b3d91a 100644 --- a/python/lib/database.py +++ b/python/lib/database.py @@ -5,6 +5,7 @@ import MySQLdb import lib.exitcode +from lib.config_file import DatabaseConfig __license__ = "GPLv3" @@ -62,25 +63,22 @@ class Database: db.disconnect() """ - def __init__(self, credentials, verbose): + def __init__(self, config: DatabaseConfig, verbose: bool): """ Constructor method for the Database class. - :param credentials: LORIS database credentials - :type credentials: dict - :param verbose : whether to be verbose or not - :type verbose : bool + :param config: LORIS database credentials + :param verbose: whether to be verbose or not """ self.verbose = verbose # grep database credentials - default_port = 3306 - self.db_name = credentials['database'] - self.user_name = credentials['username'] - self.password = credentials['passwd'] - self.host_name = credentials['host'] - port = credentials['port'] + self.db_name = config.database + self.user_name = config.username + self.password = config.password + self.host_name = config.host + self.port = config.port if not self.user_name: raise Exception("\nUser name cannot be empty string.\n") @@ -89,8 +87,6 @@ def __init__(self, credentials, verbose): if not self.host_name: raise Exception("\nDatabase host cannot be empty string.\n") - self.port = int(port) if port else default_port - def connect(self): """ Attempts to connect to the database using the connection parameters diff --git a/python/lib/database_lib/candidate_db.py b/python/lib/database_lib/candidate_db.py index 766616bf6..e7f97a81b 100644 --- a/python/lib/database_lib/candidate_db.py +++ b/python/lib/database_lib/candidate_db.py @@ -38,7 +38,7 @@ def __init__(self, db, verbose): self.verbose = verbose @deprecated('Use `lib.db.query.candidate.try_get_candidate_with_cand_id` instead') - def get_candidate_psc_id(self, cand_id: str | int) -> str | None: + def get_candidate_psc_id(self, cand_id: int) -> str | None: """ Return a candidate PSCID and based on its CandID, or `None` if no candidate is found in the database. diff --git a/python/lib/db/connect.py b/python/lib/db/connect.py index be0690c92..5593cf75a 100644 --- a/python/lib/db/connect.py +++ b/python/lib/db/connect.py @@ -1,18 +1,24 @@ -from typing import Any -from urllib.parse import quote - -from sqlalchemy import create_engine +from sqlalchemy import URL, create_engine from sqlalchemy.orm import Session -default_port = 3306 +from lib.config_file import DatabaseConfig + + +def connect_to_database(config: DatabaseConfig): + """ + Connect to the database and get an SQLAlchemy session to interract with it using the provided + credentials. + """ + # The SQLAlchemy URL object notably escapes special characters in the configuration attributes + url = URL.create( + drivername = 'mysql+mysqldb', + host = config.host, + port = config.port, + username = config.username, + password = config.password, + database = config.database, + ) -def connect_to_db(credentials: dict[str, Any]): - host = credentials['host'] - port = credentials['port'] - username = quote(credentials['username']) - password = quote(credentials['passwd']) - database = credentials['database'] - port = int(port) if port else default_port - engine = create_engine(f'mysql+mysqldb://{username}:{password}@{host}:{port}/{database}') + engine = create_engine(url) return Session(engine) diff --git a/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py b/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py index 2c37e6c5e..0ca31da9c 100644 --- a/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py +++ b/python/lib/dcm2bids_imaging_pipeline_lib/base_pipeline.py @@ -1,13 +1,13 @@ import os -import re import shutil import sys +from typing import Never import lib.exitcode import lib.utilities from lib.database import Database from lib.database_lib.config import Config -from lib.db.connect import connect_to_db +from lib.db.connect import connect_to_database from lib.dicom_archive import DicomArchive from lib.exception.determine_subject_info_error import DetermineSubjectInfoError from lib.exception.validate_subject_info_error import ValidateSubjectInfoError @@ -15,7 +15,7 @@ from lib.imaging_upload import ImagingUpload from lib.log import Log from lib.session import Session -from lib.validate_subject_ids import validate_subject_ids +from lib.validate_subject_info import validate_subject_info class BasePipeline: @@ -60,7 +60,7 @@ def __init__(self, loris_getopt_obj, script_name): self.db = Database(self.config_file.mysql, self.verbose) self.db.connect() - self.db_orm = connect_to_db(self.config_file.mysql) + self.db_orm = connect_to_database(self.config_file.mysql) # ----------------------------------------------------------------------------------- # Load the Config, Imaging, ImagingUpload, Tarchive, Session database classes @@ -109,7 +109,7 @@ def __init__(self, loris_getopt_obj, script_name): # --------------------------------------------------------------------------------- if self.dicom_archive_obj.tarchive_info_dict.keys(): try: - self.subject_id_dict = self.imaging_obj.determine_subject_ids(self.dicom_archive_obj.tarchive_info_dict) + self.subject_info = self.imaging_obj.determine_subject_info(self.dicom_archive_obj.tarchive_info_dict) except DetermineSubjectInfoError as error: self.log_error_and_exit( error.message, @@ -190,25 +190,20 @@ def determine_study_info(self): :rtype: dict """ - cand_id = self.subject_id_dict['CandID'] - visit_label = self.subject_id_dict['visitLabel'] - patient_name = self.subject_id_dict['PatientName'] - # get the CenterID from the session table if the PSCID and visit label exists # and could be extracted from the database - if cand_id and visit_label: - self.session_obj.create_session_dict(cand_id, visit_label) - session_dict = self.session_obj.session_info_dict - if session_dict: - return {"CenterName": session_dict["MRI_alias"], "CenterID": session_dict["CenterID"]} + self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label) + session_dict = self.session_obj.session_info_dict + if session_dict: + return {"CenterName": session_dict["MRI_alias"], "CenterID": session_dict["CenterID"]} # if could not find center information based on cand_id and visit_label, use the # patient name to match it to the site alias or MRI alias list_of_sites = self.session_obj.get_list_of_sites() for site_dict in list_of_sites: - if site_dict["Alias"] in patient_name: + if site_dict["Alias"] in self.subject_info.name: return {"CenterName": site_dict["Alias"], "CenterID": site_dict["CenterID"]} - elif site_dict["MRI_alias"] in patient_name: + elif site_dict["MRI_alias"] in self.subject_info.name: return {"CenterName": site_dict["MRI_alias"], "CenterID": site_dict["CenterID"]} # if we got here, it means we could not find a center associated to the dataset @@ -235,25 +230,15 @@ def determine_scanner_info(self): self.log_info(message, is_error="N", is_verbose="Y") return scanner_id - def validate_subject_ids(self): + def validate_subject_info(self): """ Ensure that the subject PSCID/CandID corresponds to a single candidate in the candidate table and that the visit label can be found in the Visit_Windows table. If those conditions are not fulfilled. """ - # no further checking if the subject is phantom - if self.subject_id_dict['isPhantom']: - return - try: - validate_subject_ids( - self.db_orm, - self.subject_id_dict['PSCID'], - self.subject_id_dict['CandID'], - self.subject_id_dict['visitLabel'], - bool(self.subject_id_dict['createVisitLabel']), - ) + validate_subject_info(self.db_orm, self.subject_info) self.imaging_upload_obj.update_mri_upload( upload_id=self.upload_id, fields=('IsCandidateInfoValidated',), values=('1',) @@ -264,7 +249,7 @@ def validate_subject_ids(self): upload_id=self.upload_id, fields=('IsCandidateInfoValidated',), values=('0',) ) - def log_error_and_exit(self, message, exit_code, is_error, is_verbose): + def log_error_and_exit(self, message, exit_code, is_error, is_verbose) -> Never: """ Function to commonly executes all logging information when the script needs to be interrupted due to an error. It will log the error in the log file created by the @@ -312,9 +297,7 @@ def get_session_info(self): Creates the session info dictionary based on entries found in the session table. """ - cand_id = self.subject_id_dict["CandID"] - visit_label = self.subject_id_dict["visitLabel"] - self.session_obj.create_session_dict(cand_id, visit_label) + self.session_obj.create_session_dict(self.subject_info.cand_id, self.subject_info.visit_label) if self.session_obj.session_info_dict: message = f"Session ID for the file to insert is {self.session_obj.session_info_dict['ID']}" @@ -325,38 +308,29 @@ def create_session(self): Function that will create a new visit in the session table for the imaging scans after verification that all the information necessary for the creation of the visit are present. """ - cand_id = self.subject_id_dict["CandID"] - visit_label = self.subject_id_dict["visitLabel"] - create_visit_label = self.subject_id_dict["createVisitLabel"] - project_id = self.subject_id_dict["ProjectID"] if "ProjectID" in self.subject_id_dict.keys() else None - cohort_id = self.subject_id_dict["CohortID"] if "CohortID" in self.subject_id_dict.keys() else None + + create_visit = self.subject_info.create_visit # check if whether the visit label should be created - if not create_visit_label: - message = f"Visit {visit_label} for candidate {cand_id} does not exist." + if create_visit is None: + message = f"Visit {self.subject_info.visit_label} for candidate {self.subject_info.cand_id} does not exist." self.log_error_and_exit(message, lib.exitcode.GET_SESSION_ID_FAILURE, is_error="Y", is_verbose="N") - # check if a project ID was provided in the config file for the visit label - if not project_id: - message = "Cannot create visit: profile file does not defined the visit's ProjectID" - self.log_error_and_exit(message, lib.exitcode.CREATE_SESSION_FAILURE, is_error="Y", is_verbose="N") - - # check if a cohort ID was provided in the config file for the visit label - if not cohort_id: - message = "Cannot create visit: profile file does not defined the visit's CohortID" - self.log_error_and_exit(message, lib.exitcode.CREATE_SESSION_FAILURE, is_error="Y", is_verbose="N") - # check that the project ID and cohort ID refers to an existing row in project_cohort_rel table - self.session_obj.create_proj_cohort_rel_info_dict(project_id, cohort_id) + self.session_obj.create_proj_cohort_rel_info_dict(create_visit.project_id, create_visit.cohort_id) if not self.session_obj.proj_cohort_rel_info_dict.keys(): - message = f"Cannot create visit with project ID {project_id} and cohort ID {cohort_id}:" \ - f" no such association in table project_cohort_rel" + message = f"Cannot create visit with project ID {create_visit.project_id} and " \ + f"cohort ID {create_visit.cohort_id}: no such association in table project_cohort_rel" self.log_error_and_exit(message, lib.exitcode.CREATE_SESSION_FAILURE, is_error="Y", is_verbose="N") # determine the visit number and center ID for the next session to be created center_id, visit_nb = self.determine_new_session_site_and_visit_nb() if not center_id: - message = f"No center ID found for candidate {cand_id}, visit {visit_label}" + message = ( + f"No center ID found for candidate {self.subject_info.cand_id}, " + f"visit {self.subject_info.visit_label}" + ) + self.log_error_and_exit(message, is_error="Y", is_verbose="N") else: message = f"Set newVisitNo = {visit_nb} and center ID = {center_id}" @@ -365,15 +339,15 @@ def create_session(self): # create the new visit session_id = self.session_obj.insert_into_session( { - 'CandID': cand_id, - 'Visit_label': visit_label, + 'CandID': self.subject_info.cand_id, + 'Visit_label': self.subject_info.visit_label, 'CenterID': center_id, 'VisitNo': visit_nb, 'Current_stage': 'Not Started', 'Scan_done': 'Y', 'Submitted': 'N', - 'CohortID': cohort_id, - 'ProjectID': project_id + 'CohortID': create_visit.cohort_id, + 'ProjectID': create_visit.project_id } ) if session_id: @@ -385,45 +359,25 @@ def determine_new_session_site_and_visit_nb(self): :returns: The center ID and visit number of the future new session """ - cand_id = self.subject_id_dict["CandID"] - visit_label = self.subject_id_dict["visitLabel"] - is_phantom = self.subject_id_dict["isPhantom"] visit_nb = 0 center_id = 0 - if is_phantom: - center_info_dict = self.determine_phantom_data_site(string_with_site_acronym=visit_label) + if self.subject_info.is_phantom: + center_info_dict = self.session_obj.get_session_center_info( + self.subject_info.psc_id, self.subject_info.visit_label, + ) + if center_info_dict: center_id = center_info_dict["CenterID"] visit_nb = 1 else: - center_info_dict = self.session_obj.get_next_session_site_id_and_visit_number(cand_id) + center_info_dict = self.session_obj.get_next_session_site_id_and_visit_number(self.subject_info.cand_id) if center_info_dict: center_id = center_info_dict["CenterID"] visit_nb = center_info_dict["newVisitNo"] return center_id, visit_nb - def determine_phantom_data_site(self, string_with_site_acronym): - """ - Determine the site of a phantom dataset. - - :param string_with_site_acronym: string to use to look for Alias or MRI_alias in the psc table - :type string_with_site_acronym: str - """ - - pscid = self.subject_id_dict["PSCID"] - visit_label = self.subject_id_dict["visitLabel"] - - # first check whether there is already a session in the database for the phantom scan - if pscid and visit_label: - return self.session_obj.get_session_center_info(pscid, visit_label) - - # if no session found, use a string_with_site_acronym to match it to a site alias or MRI alias - for row in self.site_dict: - if re.search(rf"{row['Alias']}|{row['MRI_alias']}", string_with_site_acronym, re.IGNORECASE): - return row - def check_if_tarchive_validated_in_db(self): """ Checks whether the DICOM archive was previously validated in the database (as per the value present diff --git a/python/lib/dcm2bids_imaging_pipeline_lib/dicom_validation_pipeline.py b/python/lib/dcm2bids_imaging_pipeline_lib/dicom_validation_pipeline.py index c41cfaed0..9d85f6dd8 100644 --- a/python/lib/dcm2bids_imaging_pipeline_lib/dicom_validation_pipeline.py +++ b/python/lib/dcm2bids_imaging_pipeline_lib/dicom_validation_pipeline.py @@ -27,7 +27,7 @@ def __init__(self, loris_getopt_obj, script_name): :type script_name: str """ super().__init__(loris_getopt_obj, script_name) - self.validate_subject_ids() + self.validate_subject_info() self._validate_dicom_archive_md5sum() # --------------------------------------------------------------------------------------------- diff --git a/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py b/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py index aa255231b..b41102054 100644 --- a/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py +++ b/python/lib/dcm2bids_imaging_pipeline_lib/nifti_insertion_pipeline.py @@ -11,7 +11,7 @@ from lib.dcm2bids_imaging_pipeline_lib.base_pipeline import BasePipeline from lib.exception.determine_subject_info_error import DetermineSubjectInfoError from lib.exception.validate_subject_info_error import ValidateSubjectInfoError -from lib.validate_subject_ids import validate_subject_ids +from lib.validate_subject_info import validate_subject_info __license__ = "GPLv3" @@ -84,20 +84,14 @@ def __init__(self, loris_getopt_obj, script_name): # --------------------------------------------------------------------------------------------- if self.dicom_archive_obj.tarchive_info_dict.keys(): self._validate_nifti_patient_name_with_dicom_patient_name() - self.subject_id_dict = self.imaging_obj.determine_subject_ids( + self.subject_info = self.imaging_obj.determine_subject_info( self.dicom_archive_obj.tarchive_info_dict, self.scanner_id ) else: - self._determine_subject_ids_based_on_json_patient_name() + self._determine_subject_info_based_on_json_patient_name() try: - validate_subject_ids( - self.db_orm, - self.subject_id_dict['PSCID'], - self.subject_id_dict['CandID'], - self.subject_id_dict['visitLabel'], - bool(self.subject_id_dict['createVisitLabel']), - ) + validate_subject_info(self.db_orm, self.subject_info) except ValidateSubjectInfoError as error: self.imaging_obj.insert_mri_candidate_errors( self.dicom_archive_obj.tarchive_info_dict['PatientName'], @@ -316,7 +310,7 @@ def _check_if_nifti_file_was_already_inserted(self): if error_msg: self.log_error_and_exit(error_msg, lib.exitcode.FILE_NOT_UNIQUE, is_error="Y", is_verbose="N") - def _determine_subject_ids_based_on_json_patient_name(self): + def _determine_subject_info_based_on_json_patient_name(self): """ Determines the subject IDs information based on the patient name information present in the JSON file. """ @@ -325,7 +319,7 @@ def _determine_subject_ids_based_on_json_patient_name(self): dicom_value = self.json_file_dict[dicom_header] try: - self.subject_id_dict = self.imaging_obj.determine_subject_ids(dicom_value) + self.subject_info = self.imaging_obj.determine_subject_info(dicom_value) except DetermineSubjectInfoError as error: self.log_error_and_exit( error.message, @@ -431,8 +425,8 @@ def _determine_new_nifti_assembly_rel_path(self): # determine file BIDS entity values for the file into a dictionary file_bids_entities_dict = { - 'sub': self.subject_id_dict['CandID'], - 'ses': self.subject_id_dict['visitLabel'], + 'sub': self.subject_info.cand_id, + 'ses': self.subject_info.visit_label, 'run': 1 } if self.bids_categories_dict['BIDSEchoNumber']: @@ -444,8 +438,8 @@ def _determine_new_nifti_assembly_rel_path(self): file_bids_entities_dict[key] = value # determine where the file should go - bids_cand_id = 'sub-' + self.subject_id_dict['CandID'] - bids_visit = 'ses-' + self.subject_id_dict['visitLabel'] + bids_cand_id = 'sub-' + str(self.subject_info.cand_id) + bids_visit = 'ses-' + self.subject_info.visit_label bids_subfolder = self.bids_categories_dict['BIDSCategoryName'] # determine NIfTI file name @@ -596,8 +590,8 @@ def _register_protocol_violated_scan(self): self.imaging_obj.insert_protocol_violated_scan( patient_name, - self.subject_id_dict['CandID'], - self.subject_id_dict['PSCID'], + self.subject_info.cand_id, + self.subject_info.psc_id, self.dicom_archive_obj.tarchive_info_dict['TarchiveID'], self.json_file_dict, self.trashbin_nifti_rel_path, @@ -633,9 +627,9 @@ def _register_violations_log(self, violations_list, file_rel_path): 'SeriesUID': scan_param['SeriesInstanceUID'] if 'SeriesInstanceUID' in scan_param.keys() else None, 'TarchiveID': self.dicom_archive_obj.tarchive_info_dict['TarchiveID'], 'MincFile': file_rel_path, - 'PatientName': self.subject_id_dict['PatientName'], - 'CandID': self.subject_id_dict['CandID'], - 'Visit_label': self.subject_id_dict['visitLabel'], + 'PatientName': self.subject_info.name, + 'CandID': self.subject_info.cand_id, + 'Visit_label': self.subject_info.visit_label, 'Scan_type': self.scan_type_id, 'EchoTime': scan_param['EchoTime'] if 'EchoTime' in scan_param.keys() else None, 'EchoNumber': scan_param['EchoNumber'] if 'EchoNumber' in scan_param.keys() else None, @@ -697,7 +691,7 @@ def _create_pic_image(self): Creates the pic image of the NIfTI file. """ file_info = { - 'cand_id': self.subject_id_dict['CandID'], + 'cand_id': self.subject_info.cand_id, 'data_dir_path': self.data_dir, 'file_rel_path': self.assembly_nifti_rel_path, 'is_4D_dataset': True if self.json_file_dict['time'] else False, diff --git a/python/lib/dcm2bids_imaging_pipeline_lib/push_imaging_files_to_s3_pipeline.py b/python/lib/dcm2bids_imaging_pipeline_lib/push_imaging_files_to_s3_pipeline.py index e8911e7c8..cae44b312 100644 --- a/python/lib/dcm2bids_imaging_pipeline_lib/push_imaging_files_to_s3_pipeline.py +++ b/python/lib/dcm2bids_imaging_pipeline_lib/push_imaging_files_to_s3_pipeline.py @@ -262,8 +262,7 @@ def _clean_up_empty_folders(self): # remove empty folders from file system print("Cleaning up empty folders") - cand_id = self.subject_id_dict["CandID"] - bids_cand_id = f"sub-{cand_id}" + bids_cand_id = f"sub-{self.subject_info.cand_id}" lib.utilities.remove_empty_folders(os.path.join(self.data_dir, "assembly_bids", bids_cand_id)) - lib.utilities.remove_empty_folders(os.path.join(self.data_dir, "pic", cand_id)) + lib.utilities.remove_empty_folders(os.path.join(self.data_dir, "pic", str(self.subject_info.cand_id))) lib.utilities.remove_empty_folders(os.path.join(self.data_dir, "trashbin")) diff --git a/python/lib/imaging.py b/python/lib/imaging.py index 60c9baf2c..aedd7a1ab 100644 --- a/python/lib/imaging.py +++ b/python/lib/imaging.py @@ -5,12 +5,13 @@ import os import re import tarfile -from typing import Any, Optional +from typing import Optional import nibabel as nib from nilearn import image, plotting import lib.utilities as utilities +from lib.config_file import SubjectInfo from lib.database_lib.config import Config from lib.database_lib.files import Files from lib.database_lib.mri_candidate_errors import MriCandidateErrors @@ -512,7 +513,7 @@ def grep_cand_id_from_file_id(self, file_id): # return the result return results[0]['CandID'] if results else None - def determine_subject_ids(self, tarchive_info_dict, scanner_id: Optional[int] = None) -> dict[str, Any]: + def determine_subject_info(self, tarchive_info_dict, scanner_id: Optional[int] = None) -> SubjectInfo: """ Determine subject IDs based on the DICOM header specified by the lookupCenterNameUsing config setting. This function will call a function in the configuration file that can be @@ -533,23 +534,22 @@ def determine_subject_ids(self, tarchive_info_dict, scanner_id: Optional[int] = subject_name = tarchive_info_dict[dicom_header] try: - subject_id_dict = self.config_file.get_subject_ids(self.db, subject_name, scanner_id) + subject_info = self.config_file.get_subject_info(self.db, subject_name, scanner_id) except AttributeError: raise DetermineSubjectInfoError( - 'Config file does not contain a `get_subject_ids` function. Upload will exit now.' + 'Config file does not contain a `get_subject_info` function. Upload will exit now.' ) - if subject_id_dict == {}: + if subject_info is None: raise DetermineSubjectInfoError( f'Cannot get subject IDs for subject \'{subject_name}\'.\n' 'Possible causes:\n' '- The subject name is not correctly formatted (should usually be \'PSCID_CandID_VisitLabel\').\n' - '- The function `get_subject_ids` in the Python configuration file is not properly defined.\n' + '- The function `get_subject_info` in the Python configuration file is not properly defined.\n' '- Other project specific reason.' ) - subject_id_dict['PatientName'] = subject_name - return subject_id_dict + return subject_info def map_bids_param_to_loris_param(self, file_parameters): """ diff --git a/python/lib/lorisgetopt.py b/python/lib/lorisgetopt.py index bed62a308..d36ccbf13 100644 --- a/python/lib/lorisgetopt.py +++ b/python/lib/lorisgetopt.py @@ -104,20 +104,15 @@ def __init__(self, usage, options_dict, script_name): s3_bucket_name = self.config_db_obj.get_config("AWS_S3_Default_Bucket") self.s3_obj = None if hasattr(self.config_file, 's3'): - if not self.config_file.s3["aws_access_key_id"] or not self.config_file.s3["aws_secret_access_key"]: - print( - "\n[ERROR ] missing 'aws_access_key_id' or 'aws_secret_access_key' in config file 's3' object\n" - ) - sys.exit(lib.exitcode.S3_SETTINGS_FAILURE) - s3_endpoint = s3_endpoint if s3_endpoint else self.config_file.s3["aws_s3_endpoint_url"] - s3_bucket_name = s3_bucket_name if s3_bucket_name else self.config_file.s3["aws_s3_bucket_name"] + s3_endpoint = s3_endpoint if s3_endpoint else self.config_file.s3.aws_s3_endpoint_url + s3_bucket_name = s3_bucket_name if s3_bucket_name else self.config_file.s3.aws_s3_bucket_name if not s3_endpoint or not s3_bucket_name: print('\n[ERROR ] missing configuration for S3 endpoint URL or S3 bucket name\n') sys.exit(lib.exitcode.S3_SETTINGS_FAILURE) try: self.s3_obj = AwsS3( - aws_access_key_id=self.config_file.s3["aws_access_key_id"], - aws_secret_access_key=self.config_file.s3["aws_secret_access_key"], + aws_access_key_id=self.config_file.s3.aws_access_key_id, + aws_secret_access_key=self.config_file.s3.aws_secret_access_key, aws_endpoint_url=s3_endpoint, bucket_name=s3_bucket_name ) diff --git a/python/lib/validate_subject_ids.py b/python/lib/validate_subject_ids.py deleted file mode 100644 index e78af3dda..000000000 --- a/python/lib/validate_subject_ids.py +++ /dev/null @@ -1,73 +0,0 @@ -from dataclasses import dataclass -from typing import cast - -from sqlalchemy.orm import Session as Database - -from lib.db.model.candidate import DbCandidate -from lib.db.query.candidate import try_get_candidate_with_cand_id -from lib.db.query.visit import try_get_visit_window_with_visit_label -from lib.exception.validate_subject_info_error import ValidateSubjectInfoError - -# Utility class - - -@dataclass -class Subject: - """ - Wrapper for the properties of a subject. - """ - - psc_id: str - cand_id: str - visit_label: str - - def get_name(self): - return f'{self.psc_id}_{self.cand_id}_{self.visit_label}' - - -# Main validation functions - -def validate_subject_ids( - db: Database, - psc_id: str, - cand_id: str, - visit_label: str, - create_visit: bool -): - """ - Validate a subject's information against the database from its parts (PSCID, CandID, VisitLabel). - Raise an exception if an error is found, or return `None` otherwise. - """ - - subject = Subject(psc_id, cand_id, visit_label) - validate_subject(db, subject, create_visit) - - -def validate_subject(db: Database, subject: Subject, create_visit: bool): - candidate = try_get_candidate_with_cand_id(db, int(subject.cand_id)) - if candidate is None: - validate_subject_error( - subject, - f'Candidate (CandID = \'{subject.cand_id}\') does not exist in the database.' - ) - - # Safe because the previous check raises an exception if the candidate is `None`. - candidate = cast(DbCandidate, candidate) - - if candidate.psc_id != subject.psc_id: - validate_subject_error( - subject, - f'Candidate (CandID = \'{subject.cand_id}\') PSCID does not match the subject PSCID.\n' - f'Candidate PSCID = \'{candidate.psc_id}\', Subject PSCID = \'{subject.psc_id}\'' - ) - - visit_window = try_get_visit_window_with_visit_label(db, subject.visit_label) - if visit_window is None and not create_visit: - validate_subject_error( - subject, - f'Visit label \'{subject.visit_label}\' does not exist in the database (table `Visit_Windows`).' - ) - - -def validate_subject_error(subject: Subject, message: str): - raise ValidateSubjectInfoError(f'Validation error for subject \'{subject.get_name()}\'.\n{message}') diff --git a/python/lib/validate_subject_info.py b/python/lib/validate_subject_info.py new file mode 100644 index 000000000..65b84689b --- /dev/null +++ b/python/lib/validate_subject_info.py @@ -0,0 +1,40 @@ +from typing import Never + +from sqlalchemy.orm import Session as Database + +from lib.config_file import SubjectInfo +from lib.db.query.candidate import try_get_candidate_with_cand_id +from lib.db.query.visit import try_get_visit_window_with_visit_label +from lib.exception.validate_subject_info_error import ValidateSubjectInfoError + + +def validate_subject_info(db: Database, subject_info: SubjectInfo): + """ + Validate a subject's information against the database from its parts (PSCID, CandID, VisitLabel). + Raise an exception if an error is found, or return `None` otherwise. + """ + + candidate = try_get_candidate_with_cand_id(db, subject_info.cand_id) + if candidate is None: + validate_subject_error( + subject_info, + f'Candidate (CandID = \'{subject_info.cand_id}\') does not exist in the database.' + ) + + if candidate.psc_id != subject_info.psc_id: + validate_subject_error( + subject_info, + f'Candidate (CandID = \'{subject_info.cand_id}\') PSCID does not match the subject PSCID.\n' + f'Candidate PSCID = \'{candidate.psc_id}\', Subject PSCID = \'{subject_info.psc_id}\'' + ) + + visit_window = try_get_visit_window_with_visit_label(db, subject_info.visit_label) + if visit_window is None and subject_info.create_visit is not None: + validate_subject_error( + subject_info, + f'Visit label \'{subject_info.visit_label}\' does not exist in the database (table `Visit_Windows`).' + ) + + +def validate_subject_error(subject_info: SubjectInfo, message: str) -> Never: + raise ValidateSubjectInfoError(f'Validation error for subject \'{subject_info.name}\'.\n{message}')