diff --git a/scripts/credential_management.py b/scripts/credential_management.py index 9a17003..4ce4fed 100755 --- a/scripts/credential_management.py +++ b/scripts/credential_management.py @@ -26,10 +26,10 @@ def _put_s3_data(name: str, bucket_name: str, client, data: dict, path: str = "a client.upload_fileobj(Bucket=bucket_name, Key=f"{path}/{name}", Fileobj=b_data) -def create_auth(client, user: str, auth: str, site: str) -> str: +def create_auth(user: str, auth: str, site: str) -> str: """Adds a new entry to the auth dict used to issue pre-signed URLs""" site_id = _basic_auth_str(user, auth).split(" ")[1] - return f'"{site_id}"": {{"user": {user}, "site":{site}}}' + return f'"{site_id}": {{"user": "{user}", "site":"{site}"}}' def create_meta(client, bucket_name: str, site: str, folder: str) -> None: diff --git a/src/shared/awswrangler_functions.py b/src/shared/awswrangler_functions.py index 5ed716a..dcd0378 100644 --- a/src/shared/awswrangler_functions.py +++ b/src/shared/awswrangler_functions.py @@ -1,6 +1,9 @@ """functions specifically requiring AWSWranger, which requires a lambda layer""" +import csv + import awswrangler +import numpy from .enums import BucketPath @@ -35,3 +38,25 @@ def get_s3_study_meta_list( ), suffix=extension, ) + + +def generate_csv_from_parquet( + bucket_name: str, bucket_root: str, subbucket_path: str, to_path: str | None = None +): + """Convenience function for generating csvs for dashboard upload + + TODO: Remove on dashboard parquet/API support""" + if to_path is None: + to_path = f"s3://{bucket_name}/{bucket_root}/{subbucket_path}".replace(".parquet", ".csv") + last_valid_df = awswrangler.s3.read_parquet( + f"s3://{bucket_name}/{bucket_root}" f"/{subbucket_path}" + ) + last_valid_df = last_valid_df.apply(lambda x: x.strip() if isinstance(x, str) else x).replace( + '""', numpy.nan + ) + awswrangler.s3.to_csv( + last_valid_df, + to_path, + index=False, + quoting=csv.QUOTE_MINIMAL, + ) diff --git a/src/shared/enums.py b/src/shared/enums.py index aad6069..557ecbe 100644 --- a/src/shared/enums.py +++ b/src/shared/enums.py @@ -11,7 +11,9 @@ class BucketPath(enum.Enum): ARCHIVE = "archive" CACHE = "cache" CSVAGGREGATE = "csv_aggregates" + CSVFLAT = "csv_flat" ERROR = "error" + FLAT = "flat" LAST_VALID = "last_valid" LATEST = "latest" META = "metadata" @@ -33,9 +35,19 @@ class JsonFilename(enum.Enum): COLUMN_TYPES = "column_types" TRANSACTIONS = "transactions" DATA_PACKAGES = "data_packages" + FLAT_PACKAGES = "flat_packages" STUDY_PERIODS = "study_periods" +class StudyPeriodMetadataKeys(enum.Enum): + """stores names of expected keys in the study period metadata dictionary""" + + STUDY_PERIOD_FORMAT_VERSION = "study_period_format_version" + EARLIEST_DATE = "earliest_date" + LATEST_DATE = "latest_date" + LAST_DATA_UPDATE = "last_data_update" + + class TransactionKeys(enum.Enum): """stores names of expected keys in the transaction dictionary""" @@ -47,10 +59,12 @@ class TransactionKeys(enum.Enum): DELETED = "deleted" -class StudyPeriodMetadataKeys(enum.Enum): - """stores names of expected keys in the study period metadata dictionary""" +class UploadTypes(enum.Enum): + """stores names of different expected upload formats""" - STUDY_PERIOD_FORMAT_VERSION = "study_period_format_version" - EARLIEST_DATE = "earliest_date" - LATEST_DATE = "latest_date" - LAST_DATA_UPDATE = "last_data_update" + # archive is not expected to be uploaded, but is one of the generated file types + # in the library + ARCHIVE = "archive" + CUBE = "cube" + FLAT = "flat" + META = "meta" diff --git 
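A quick check of the create_auth fix above: the old f-string emitted doubled quotes around the key and left the values unquoted, so the resulting fragment was not valid JSON. The sketch below uses made-up user/site values, and site_id mirrors what requests' _basic_auth_str yields once the "Basic " prefix is stripped; it is an illustration, not part of the change.

```python
import base64
import json

user, auth, site = "ppth_user", "s3cr3t", "ppth"  # hypothetical credentials
site_id = base64.b64encode(f"{user}:{auth}".encode()).decode()
entry = f'"{site_id}": {{"user": "{user}", "site":"{site}"}}'
json.loads("{" + entry + "}")  # parses cleanly with the corrected quoting
```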
a/src/shared/functions.py b/src/shared/functions.py index 68b40cb..86508b2 100644 --- a/src/shared/functions.py +++ b/src/shared/functions.py @@ -140,7 +140,7 @@ def update_metadata( # Should only be hit if you add a new JSON dict and forget to add it # to this function case _: - raise OSError(f"{meta_type} does not have a handler for updates.") + raise ValueError(f"{meta_type} does not have a handler for updates.") data_version_metadata.update(extra_items) return metadata @@ -158,7 +158,7 @@ def write_metadata( s3_client.put_object( Bucket=s3_bucket_name, Key=f"{enums.BucketPath.META.value}/{meta_type}.json", - Body=json.dumps(metadata, default=str), + Body=json.dumps(metadata, default=str, indent=2), ) @@ -182,6 +182,30 @@ def move_s3_file(s3_client, s3_bucket_name: str, old_key: str, new_key: str) -> raise S3UploadError +def get_s3_keys( + s3_client, + s3_bucket_name: str, + prefix: str, + token: str | None = None, + max_keys: int | None = None, +) -> list[str]: + """Gets the list of all keys in S3 starting with the prefix""" + if max_keys is None: + max_keys = 1000 + if token: + res = s3_client.list_objects_v2( + Bucket=s3_bucket_name, Prefix=prefix, ContinuationToken=token, MaxKeys=max_keys + ) + else: + res = s3_client.list_objects_v2(Bucket=s3_bucket_name, Prefix=prefix, MaxKeys=max_keys) + if "Contents" not in res: + return [] + contents = [record["Key"] for record in res["Contents"]] + if res["IsTruncated"]: + contents += get_s3_keys(s3_client, s3_bucket_name, prefix, res["NextContinuationToken"]) + return contents + + def get_s3_site_filename_suffix(s3_path: str): """Extracts site/filename data from s3 path""" # The expected s3 path for site data packages looks like: @@ -209,14 +233,15 @@ def get_latest_data_package_version(bucket, prefix): prefix = prefix + "/" s3_res = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix) highest_ver = None - for item in s3_res["Contents"]: - ver_str = item["Key"].replace(prefix, "").split("/")[0] - if ver_str.isdigit(): - if highest_ver is None: - highest_ver = ver_str - else: - if int(highest_ver) < int(ver_str): + if "Contents" in s3_res: + for item in s3_res["Contents"]: + ver_str = item["Key"].replace(prefix, "").split("/")[1].split("__")[2] + if ver_str.isdigit(): + if highest_ver is None: highest_ver = ver_str - if highest_ver is None: + else: + if int(highest_ver) < int(ver_str): + highest_ver = ver_str + if "Contents" not in s3_res or highest_ver is None: logging.error("No data package versions found for %s", prefix) return highest_ver diff --git a/src/shared/s3_manager.py b/src/shared/s3_manager.py new file mode 100644 index 0000000..c3a9a6b --- /dev/null +++ b/src/shared/s3_manager.py @@ -0,0 +1,235 @@ +import csv +import logging +import os +import traceback + +import awswrangler +import boto3 +import numpy +import pandas + +from shared import ( + awswrangler_functions, + enums, + functions, +) + +log_level = os.environ.get("LAMBDA_LOG_LEVEL", "INFO") +logger = logging.getLogger() +logger.setLevel(log_level) + + +class S3Manager: + """Class for managing S3 paramaters/access from data in an AWS SNS event. + + This is generally intended as a one stop shop for the data processing phase + of the aggregator pipeline, providing commmon file paths/sns event parsing helpers/ + stripped down write methods. Consider adding utility functions here instead of using + raw awswrangler/shared functions to try and make those processes simpler. 
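For orientation, here is roughly how the manager unpacks the SNS message key into its identifiers; the actual parsing lives in __init__ just below, and the study/package/site/version values in this sketch are made up.

```python
# Sample key following the /<study>/<study>__<package>/<site>/<study>__<package>__<version>/<file> layout
key = "/covid/covid__encounter/general_hospital/covid__encounter__099/covid__encounter.parquet"
parts = key.split("/")
study = parts[1]                        # "covid"
data_package = parts[2].split("__")[1]  # "encounter"
site = parts[3]                         # "general_hospital"
version = parts[4].split("__")[-1]      # "099"
```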
+ """ + + def __init__(self, event): + self.s3_bucket_name = os.environ.get("BUCKET_NAME") + self.s3_client = boto3.client("s3") + self.sns_client = boto3.client("sns", region_name=self.s3_client.meta.region_name) + self.event_source = event["Records"][0]["Sns"]["TopicArn"] + self.s3_key = event["Records"][0]["Sns"]["Message"] + s3_key_array = self.s3_key.split("/") + self.study = s3_key_array[1] + self.data_package = s3_key_array[2].split("__")[1] + self.site = s3_key_array[3] + self.version = s3_key_array[4].split("__")[-1] + self.metadata = functions.read_metadata( + self.s3_client, self.s3_bucket_name, meta_type=enums.JsonFilename.TRANSACTIONS.value + ) + self.types_metadata = functions.read_metadata( + self.s3_client, + self.s3_bucket_name, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + self.parquet_aggregate_path = ( + f"s3://{self.s3_bucket_name}/{enums.BucketPath.AGGREGATE.value}/" + f"{self.study}/{self.study}__{self.data_package}/" + f"{self.study}__{self.data_package}__{self.version}/" + f"{self.study}__{self.data_package}__aggregate.parquet" + ) + self.csv_aggregate_path = ( + f"s3://{self.s3_bucket_name}/{enums.BucketPath.CSVAGGREGATE.value}/" + f"{self.study}/{self.study}__{self.data_package}/" + f"{self.version}/" + f"{self.study}__{self.data_package}__aggregate.csv" + ) + # TODO: Taking out a folder layer to match the depth of non-site aggregates + # Revisit when targeted crawling is implemented + self.parquet_flat_key = ( + f"{enums.BucketPath.FLAT.value}/" + f"{self.study}/{self.site}/" # {self.study}__{self.data_package}/" + f"{self.study}__{self.data_package}__{self.version}/" + f"{self.study}__{self.data_package}_{self.site}__flat.parquet" + ) + self.csv_flat_key = ( + f"{enums.BucketPath.CSVFLAT.value}/" + f"{self.study}/{self.site}/" # {self.study}__{self.data_package}/" + f"{self.study}__{self.data_package}__{self.version}/" + f"{self.study}__{self.data_package}_{self.site}__flat.parquet" + ) + + def error_handler( + self, + s3_path: str, + subbucket_path: str, + error: Exception, + ) -> None: + """Logs errors and moves files to the error folder + + :param s3_path: the path of the file generating the S3 error + :param subbucket_path: the path to write the file to inside the root error folder + """ + logger.error("Error processing file %s: %s", s3_path, str(error)) + logger.error(traceback.print_exc()) + self.move_file( + s3_path.replace(f"s3://{self.s3_bucket_name}/", ""), + f"{enums.BucketPath.ERROR.value}/{subbucket_path}", + ) + self.update_local_metadata(enums.TransactionKeys.LAST_ERROR.value) + + # S3 Filesystem operations + def copy_file(self, from_path_or_key: str, to_path_or_key: str) -> None: + """Copies a file from one location to another in S3. + + This function is agnostic to being provided an S3 path versus an S3 key. + + :param from_path_or_key: the data source + :param to_path_or_key: the data destination. + """ + if from_path_or_key.startswith("s3"): + from_path_or_key = from_path_or_key.split("/", 3)[-1] + if to_path_or_key.startswith("s3"): + to_path_or_key = to_path_or_key.split("/", 3)[-1] + source = { + "Bucket": self.s3_bucket_name, + "Key": from_path_or_key, + } + self.s3_client.copy_object( + CopySource=source, + Bucket=self.s3_bucket_name, + Key=to_path_or_key, + ) + + def get_data_package_list(self, bucket_root) -> list: + """Gets a list of data packages associated with the study from the SNS event payload. 
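A small sketch of the path-versus-key normalization that copy_file (above) and move_file (further down) rely on; the bucket and key names here are placeholders.

```python
# Splitting an s3:// URI on the first three slashes leaves just the object key.
path = "s3://example-bucket/aggregates/study/study__encounter/file.parquet"  # placeholder path
key = path.split("/", 3)[-1]
assert key == "aggregates/study/study__encounter/file.parquet"
```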
+ + :param bucket_root: the top level directory name in the root of the S3 bucket + :returns: a list of full s3 file paths + """ + return awswrangler_functions.get_s3_data_package_list( + bucket_root, self.s3_bucket_name, self.study, self.data_package + ) + + def move_file(self, from_path_or_key: str, to_path_or_key: str) -> None: + """moves file from one location to another in s3 + + This function is agnostic to being provided an S3 path versus an S3 key. + + :param from_path_or_key: the data source + :param to_path_or_key: the data destination. + + """ + if from_path_or_key.startswith("s3"): + from_path_or_key = from_path_or_key.split("/", 3)[-1] + if to_path_or_key.startswith("s3"): + to_path_or_key = to_path_or_key.split("/", 3)[-1] + functions.move_s3_file( + self.s3_client, self.s3_bucket_name, from_path_or_key, to_path_or_key + ) + + # parquet/csv output creation + def cache_api(self): + """Sends an SNS cache event""" + topic_sns_arn = os.environ.get("TOPIC_CACHE_API_ARN") + self.sns_client.publish( + TopicArn=topic_sns_arn, Message="data_packages", Subject="data_packages" + ) + + def write_csv(self, df: pandas.DataFrame, path=None) -> None: + """writes dataframe as csv to s3 + + :param df: pandas dataframe + :param path: an S3 path to write to (default: aggregate csv path)""" + if path is None: + path = self.csv_aggregate_path + + df = df.apply(lambda x: x.strip() if isinstance(x, str) else x).replace('""', numpy.nan) + df = df.replace(to_replace=r",", value="", regex=True) + awswrangler.s3.to_csv(df, path, index=False, quoting=csv.QUOTE_NONE) + + def write_parquet(self, df: pandas.DataFrame, is_new_data_package: bool, path=None) -> None: + """Writes a dataframe as parquet to s3 and sends an SNS cache event if new + + :param df: pandas dataframe + :param is_new_data_package: if true, will dispatch a cache SNS event after copy is completed + :param path: an S3 path to write to (default: aggregate path)""" + if path is None: + path = self.parquet_aggregate_path + awswrangler.s3.to_parquet(df, path, index=False) + if is_new_data_package: + self.cache_api() + + # metadata + def update_local_metadata( + self, + key, + *, + site=None, + value=None, + metadata: dict | None = None, + meta_type: str | None = enums.JsonFilename.TRANSACTIONS.value, + extra_items: dict | None = None, + ): + """Updates the local cache of a json metadata dictionary + + :param key: the key of the parameter to update + :keyword site: If provided, the site to update + :keyword value: If provided, a specific value to assign to the key parameter + (only used by ColumnTypes) + :keyword metadata: the specific metadata type to update. default: Transactions + :keyword meta_type: The enum representing the name of the metadata type. 
+ Default: Transactions + :keyword extra_items: A dictionary of items to append to the metadata + + """ + # We are excluding COLUMN_TYPES explicitly from this first check because, + # by design, it should never have a site field in it - the column types + # are tied to the study version, not a specific site's data + if extra_items is None: + extra_items = {} + if site is None and meta_type != enums.JsonFilename.COLUMN_TYPES.value: + site = self.site + if metadata is None: + metadata = self.metadata + functions.update_metadata( + metadata=metadata, + site=site, + study=self.study, + data_package=self.data_package, + version=self.version, + target=key, + value=value, + meta_type=meta_type, + extra_items=extra_items, + ) + + def write_local_metadata(self, metadata: dict | None = None, meta_type: str | None = None): + """Writes a cache of the local metadata back to S3 + + :param metadata: the specific dictionary to write. Default: transactions + :param meta_type: The enum representing the name of the metadata type. Default: Transactions + """ + metadata = metadata or self.metadata + meta_type = meta_type or enums.JsonFilename.TRANSACTIONS.value + functions.write_metadata( + s3_client=self.s3_client, + s3_bucket_name=self.s3_bucket_name, + metadata=metadata, + meta_type=meta_type, + ) diff --git a/src/site_upload/cache_api/cache_api.py b/src/site_upload/cache_api/cache_api.py index fa24a62..e1de6fb 100644 --- a/src/site_upload/cache_api/cache_api.py +++ b/src/site_upload/cache_api/cache_api.py @@ -28,28 +28,30 @@ def cache_api_data(s3_client, s3_bucket_name: str, db: str, target: str) -> None f"{enums.BucketPath.META.value}/{enums.JsonFilename.COLUMN_TYPES.value}.json", ) dp_details = [] + files = functions.get_s3_keys(s3_client, s3_bucket_name, enums.BucketPath.AGGREGATE.value) + files += functions.get_s3_keys(s3_client, s3_bucket_name, enums.BucketPath.FLAT.value) for dp in list(data_packages): + if not any([dp in x for x in files]): + continue dp_detail = { - "study": dp.split("__", 1)[0], - "name": dp.split("__", 1)[1], + "study": dp.split("__")[0], + "name": dp.split("__")[1], } - try: - versions = column_types[dp_detail["study"]][dp_detail["name"]] - for version in versions: - dp_details.append( - { - **dp_detail, - **versions[version], - "version": version, - "id": dp + "__" + version, - } - ) - except KeyError: - continue + versions = column_types[dp_detail["study"]][dp_detail["name"]] + for version in versions: + dp_dict = { + **dp_detail, + **versions[version], + "version": version, + "id": f"{dp_detail['study']}__{dp_detail['name']}__{version}", + } + if "__flat" in dp: + dp_dict["type"] = "flat" + dp_details.append(dp_dict) s3_client.put_object( Bucket=s3_bucket_name, Key=f"{enums.BucketPath.CACHE.value}/{enums.JsonFilename.DATA_PACKAGES.value}.json", - Body=json.dumps(dp_details), + Body=json.dumps(dp_details, indent=2), ) diff --git a/src/site_upload/powerset_merge/powerset_merge.py b/src/site_upload/powerset_merge/powerset_merge.py index 563f0a6..35b654d 100644 --- a/src/site_upload/powerset_merge/powerset_merge.py +++ b/src/site_upload/powerset_merge/powerset_merge.py @@ -1,23 +1,13 @@ """Lambda for performing joins of site count data""" -import csv import datetime import logging import os -import traceback import awswrangler -import boto3 -import numpy import pandas from pandas.core.indexes.range import RangeIndex -from shared import ( - awswrangler_functions, - decorators, - enums, - functions, - pandas_functions, -) +from shared import awswrangler_functions, decorators, 
enums, functions, pandas_functions, s3_manager log_level = os.environ.get("LAMBDA_LOG_LEVEL", "INFO") logger = logging.getLogger() @@ -30,144 +20,12 @@ def __init__(self, message, filename): self.filename = filename -class S3Manager: - """Convenience class for managing S3 Access""" - - def __init__(self, event): - self.s3_bucket_name = os.environ.get("BUCKET_NAME") - self.s3_client = boto3.client("s3") - self.sns_client = boto3.client("sns", region_name=self.s3_client.meta.region_name) - - self.s3_key = event["Records"][0]["Sns"]["Message"] - s3_key_array = self.s3_key.split("/") - self.study = s3_key_array[1] - self.data_package = s3_key_array[2].split("__")[1] - self.site = s3_key_array[3] - self.version = s3_key_array[4].split("__")[-1] - self.metadata = functions.read_metadata(self.s3_client, self.s3_bucket_name) - self.types_metadata = functions.read_metadata( - self.s3_client, - self.s3_bucket_name, - meta_type=enums.JsonFilename.COLUMN_TYPES.value, - ) - self.csv_aggerate_path = ( - f"s3://{self.s3_bucket_name}/{enums.BucketPath.CSVAGGREGATE.value}/" - f"{self.study}/{self.study}__{self.data_package}/" - f"{self.version}/" - f"{self.study}__{self.data_package}__aggregate.csv" - ) - - # S3 Filesystem operations - def get_data_package_list(self, path) -> list: - """convenience wrapper for get_s3_data_package_list""" - return awswrangler_functions.get_s3_data_package_list( - path, self.s3_bucket_name, self.study, self.data_package - ) - - def move_file(self, from_path: str, to_path: str) -> None: - """convenience wrapper for move_s3_file""" - functions.move_s3_file(self.s3_client, self.s3_bucket_name, from_path, to_path) - - def copy_file(self, from_path: str, to_path: str) -> None: - """convenience wrapper for copy_s3_file""" - source = { - "Bucket": self.s3_bucket_name, - "Key": from_path, - } - self.s3_client.copy_object( - CopySource=source, - Bucket=self.s3_bucket_name, - Key=to_path, - ) - - # parquet/csv output creation - def write_parquet(self, df: pandas.DataFrame, is_new_data_package: bool) -> None: - """writes dataframe as parquet to s3 and sends an SNS notification if new""" - parquet_aggregate_path = ( - f"s3://{self.s3_bucket_name}/{enums.BucketPath.AGGREGATE.value}/" - f"{self.study}/{self.study}__{self.data_package}/" - f"{self.study}__{self.data_package}__{self.version}/" - f"{self.study}__{self.data_package}__aggregate.parquet" - ) - awswrangler.s3.to_parquet(df, parquet_aggregate_path, index=False) - if is_new_data_package: - topic_sns_arn = os.environ.get("TOPIC_CACHE_API_ARN") - self.sns_client.publish( - TopicArn=topic_sns_arn, Message="data_packages", Subject="data_packages" - ) - - def write_csv(self, df: pandas.DataFrame) -> None: - """writes dataframe as csv to s3""" - df = df.apply(lambda x: x.strip() if isinstance(x, str) else x).replace('""', numpy.nan) - df = df.replace(to_replace=r",", value="", regex=True) - awswrangler.s3.to_csv(df, self.csv_aggerate_path, index=False, quoting=csv.QUOTE_NONE) - - # metadata - def update_local_metadata( - self, - key, - *, - site=None, - value=None, - metadata: dict | None = None, - meta_type: str | None = enums.JsonFilename.TRANSACTIONS.value, - extra_items: dict | None = None, - ): - """convenience wrapper for update_metadata""" - # We are excluding COLUMN_TYPES explicitly from this first check because, - # by design, it should never have a site field in it - the column types - # are tied to the study version, not a specific site's data - if extra_items is None: - extra_items = {} - if site is None and meta_type != 
enums.JsonFilename.COLUMN_TYPES.value: - site = self.site - if metadata is None: - metadata = self.metadata - metadata = functions.update_metadata( - metadata=metadata, - site=site, - study=self.study, - data_package=self.data_package, - version=self.version, - target=key, - value=value, - meta_type=meta_type, - extra_items=extra_items, - ) - - def write_local_metadata(self, metadata: dict | None = None, meta_type: str | None = None): - """convenience wrapper for write_metadata""" - metadata = metadata or self.metadata - meta_type = meta_type or enums.JsonFilename.TRANSACTIONS.value - functions.write_metadata( - s3_client=self.s3_client, - s3_bucket_name=self.s3_bucket_name, - metadata=metadata, - meta_type=meta_type, - ) - - def merge_error_handler( - self, - s3_path: str, - subbucket_path: str, - error: Exception, - ) -> None: - """Helper for logging errors and moving files""" - logger.error("File %s failed to aggregate: %s", s3_path, str(error)) - logger.error(traceback.print_exc()) - self.move_file( - s3_path.replace(f"s3://{self.s3_bucket_name}/", ""), - f"{enums.BucketPath.ERROR.value}/{subbucket_path}", - ) - self.update_local_metadata(enums.TransactionKeys.LAST_ERROR.value) - - def get_static_string_series(static_str: str, index: RangeIndex) -> pandas.Series: """Helper for the verbose way of defining a pandas string series""" return pandas.Series([static_str] * len(index)).astype("string") -def expand_and_concat_sets( +def expand_and_concat_powersets( df: pandas.DataFrame, file_path: str, site_name: str ) -> pandas.DataFrame: """Processes and joins dataframes containing powersets. @@ -223,31 +81,13 @@ def expand_and_concat_sets( return agg_df -def generate_csv_from_parquet(bucket_name: str, bucket_root: str, subbucket_path: str): - """Convenience function for generating csvs for dashboard upload - - TODO: Remove on dashboard parquet/API support""" - last_valid_df = awswrangler.s3.read_parquet( - f"s3://{bucket_name}/{bucket_root}" f"/{subbucket_path}" - ) - last_valid_df = last_valid_df.apply(lambda x: x.strip() if isinstance(x, str) else x).replace( - '""', numpy.nan - ) - awswrangler.s3.to_csv( - last_valid_df, - (f"s3://{bucket_name}/{bucket_root}/{subbucket_path}".replace(".parquet", ".csv")), - index=False, - quoting=csv.QUOTE_MINIMAL, - ) - - -def merge_powersets(manager: S3Manager) -> None: +def merge_powersets(manager: s3_manager.S3Manager) -> None: """Creates an aggregate powerset from all files with a given s3 prefix""" # TODO: this should be memory profiled for large datasets. We can use # chunking to lower memory usage during merges. 
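One way the chunking mentioned in the TODO above could look, assuming awswrangler's chunked read mode; the path is a placeholder and this is only a sketch, not part of the change.

```python
import awswrangler
import pandas

agg_df = pandas.DataFrame()
# chunked=True reads the parquet in pieces instead of loading the whole file at once
for chunk in awswrangler.s3.read_parquet(
    "s3://example-bucket/latest/study/file.parquet", chunked=True
):
    agg_df = pandas.concat([agg_df, chunk], ignore_index=True)
```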
- # initializing this early in case an empty file causes us to never set it logger.info(f"Proccessing data package at {manager.s3_key}") + # initializing this early in case an empty file causes us to never set it is_new_data_package = False df = pandas.DataFrame() latest_file_list = manager.get_data_package_list(enums.BucketPath.LATEST.value) @@ -262,14 +102,14 @@ def merge_powersets(manager: S3Manager) -> None: # one instead try: if not any(x.endswith(site_specific_name) for x in latest_file_list): - df = expand_and_concat_sets(df, last_valid_path, last_valid_site) + df = expand_and_concat_powersets(df, last_valid_path, last_valid_site) manager.update_local_metadata( enums.TransactionKeys.LAST_AGGREGATION.value, site=last_valid_site ) except MergeError as e: - # This is expected to trigger if there's an issue in expand_and_concat_sets; + # This is expected to trigger if there's an issue in expand_and_concat_powersets; # this usually means there's a data problem. - manager.merge_error_handler( + manager.error_handler( e.filename, subbucket_path, e, @@ -298,7 +138,7 @@ def merge_powersets(manager: S3Manager) -> None: # we'll generate a new list of valid tables for the dashboard else: is_new_data_package = True - df = expand_and_concat_sets(df, latest_path, manager.site) + df = expand_and_concat_powersets(df, latest_path, manager.site) manager.move_file( f"{enums.BucketPath.LATEST.value}/{subbucket_path}", f"{enums.BucketPath.LAST_VALID.value}/{subbucket_path}", @@ -309,7 +149,7 @@ def merge_powersets(manager: S3Manager) -> None: # This is used for uploading to the dashboard. # TODO: remove as soon as we support either parquet upload or # the API is supported by the dashboard - generate_csv_from_parquet( + awswrangler_functions.generate_csv_from_parquet( manager.s3_bucket_name, enums.BucketPath.LAST_VALID.value, subbucket_path, @@ -324,7 +164,7 @@ def merge_powersets(manager: S3Manager) -> None: enums.TransactionKeys.LAST_AGGREGATION.value, site=latest_site ) except Exception as e: - manager.merge_error_handler( + manager.error_handler( latest_path, subbucket_path, e, @@ -332,7 +172,7 @@ def merge_powersets(manager: S3Manager) -> None: # if a new file fails, we want to replace it with the last valid # for purposes of aggregation if any(x.endswith(site_specific_name) for x in last_valid_file_list): - df = expand_and_concat_sets( + df = expand_and_concat_powersets( df, f"s3://{manager.s3_bucket_name}/{enums.BucketPath.LAST_VALID.value}" f"/{subbucket_path}", @@ -352,7 +192,7 @@ def merge_powersets(manager: S3Manager) -> None: value=column_dict, metadata=manager.types_metadata, meta_type=enums.JsonFilename.COLUMN_TYPES.value, - extra_items={"total": int(df["cnt"][0]), "s3_path": manager.csv_aggerate_path}, + extra_items={"total": int(df["cnt"][0]), "s3_path": manager.csv_aggregate_path}, ) manager.update_local_metadata( enums.ColumnTypesKeys.LAST_DATA_UPDATE.value, @@ -378,7 +218,7 @@ def merge_powersets(manager: S3Manager) -> None: def powerset_merge_handler(event, context): """manages event from SNS, triggers file processing and merge""" del context - manager = S3Manager(event) + manager = s3_manager.S3Manager(event) merge_powersets(manager) res = functions.http_response(200, "Merge successful") return res diff --git a/src/site_upload/process_flat/__init__.py b/src/site_upload/process_flat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/site_upload/process_flat/process_flat.py b/src/site_upload/process_flat/process_flat.py new file mode 100644 index 0000000..b44f65a 
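Before the new lambda, a condensed sketch of the suffix routing that process_upload.py (further down in this diff) applies to parquet uploads; the new process_flat handler below is what the "flat" branch ultimately triggers. The return labels are illustrative stand-ins, not the real SNS topic ARNs.

```python
def route_parquet_upload(s3_key: str) -> str:
    """Mirrors the branch order in the process_upload hunk below (parquet keys only)."""
    if s3_key.endswith(".meta.parquet") or "/discovery__" in s3_key or "__meta_" in s3_key:
        return "study_meta"
    if s3_key.endswith(".flat.parquet"):
        return "flat"
    if s3_key.endswith(".archive.parquet"):
        return "delete"  # archives may hold line-level data, so they are discarded
    return "counts"

assert route_parquet_upload("/covid/covid__enc/site/099/enc.flat.parquet") == "flat"
```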
--- /dev/null +++ b/src/site_upload/process_flat/process_flat.py @@ -0,0 +1,70 @@ +import logging +import os + +import awswrangler +from shared import awswrangler_functions, decorators, enums, functions, pandas_functions, s3_manager + +log_level = os.environ.get("LAMBDA_LOG_LEVEL", "INFO") +logger = logging.getLogger() +logger.setLevel(log_level) + + +def process_flat(manager: s3_manager.S3Manager): + new_file = False + if awswrangler.s3.does_object_exist( + f"s3://{manager.s3_bucket_name}/{manager.parquet_flat_key}" + ): + manager.move_file( + manager.parquet_flat_key, + manager.parquet_flat_key.replace( + enums.BucketPath.FLAT.value, enums.BucketPath.ARCHIVE.value + ), + ) + else: + new_file = True + manager.move_file( + manager.s3_key, + manager.parquet_flat_key, + ) + awswrangler_functions.generate_csv_from_parquet( + manager.s3_bucket_name, + manager.parquet_flat_key.split("/", 1)[0], + manager.parquet_flat_key.split("/", 1)[1], + f"s3://{manager.s3_bucket_name}/{manager.csv_flat_key}", + ) + df = awswrangler.s3.read_parquet(f"s3://{manager.s3_bucket_name}/{manager.parquet_flat_key}") + column_dict = pandas_functions.get_column_datatypes(df.dtypes) + manager.update_local_metadata( + enums.ColumnTypesKeys.COLUMNS.value, + value=column_dict, + site=manager.site, + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + extra_items={ + "s3_path": f"s3://{manager.s3_bucket_name}/{manager.parquet_flat_key}", + "type": "flat", + }, + ) + manager.update_local_metadata(enums.TransactionKeys.LAST_DATA_UPDATE.value, site=manager.site) + manager.update_local_metadata( + enums.ColumnTypesKeys.LAST_DATA_UPDATE.value, + value=column_dict, + metadata=manager.types_metadata, + meta_type=enums.JsonFilename.COLUMN_TYPES.value, + ) + manager.write_local_metadata( + metadata=manager.types_metadata, meta_type=enums.JsonFilename.COLUMN_TYPES.value + ) + manager.write_local_metadata() + if new_file: + manager.cache_api() + + +@decorators.generic_error_handler(msg="Error processing flat upload") +def process_flat_handler(event, context): + """manages event from S3, triggers file processing""" + del context + manager = s3_manager.S3Manager(event) + process_flat(manager) + res = functions.http_response(200, "Merge successful") + return res diff --git a/src/site_upload/process_flat/shared b/src/site_upload/process_flat/shared new file mode 120000 index 0000000..0ab0cb2 --- /dev/null +++ b/src/site_upload/process_flat/shared @@ -0,0 +1 @@ +../../shared \ No newline at end of file diff --git a/src/site_upload/process_upload/process_upload.py b/src/site_upload/process_upload/process_upload.py index b09aa96..b960807 100644 --- a/src/site_upload/process_upload/process_upload.py +++ b/src/site_upload/process_upload/process_upload.py @@ -23,15 +23,36 @@ def process_upload(s3_client, sns_client, s3_bucket_name: str, s3_key: str) -> N metadata = functions.read_metadata(s3_client, s3_bucket_name) path_params = s3_key.split("/") study = path_params[1] - data_package = path_params[2] + # This happens when we're processing flat files, due to having to condense one + # folder layer. 
+ # TODO: revisit on targeted crawling + if "__" in path_params[2]: + data_package = path_params[2].split("__")[1] + else: + data_package = path_params[2] site = path_params[3] version = path_params[4] if s3_key.endswith(".parquet"): - if "__meta_" in s3_key or "/discovery__" in s3_key: + if ( + s3_key.endswith(f".{enums.UploadTypes.META.value}.parquet") + or "/discovery__" in s3_key + or "__meta_" in s3_key + ): new_key = f"{enums.BucketPath.STUDY_META.value}/{s3_key.split('/', 1)[-1]}" topic_sns_arn = os.environ.get("TOPIC_PROCESS_STUDY_META_ARN") sns_subject = "Process study metadata upload event" + elif s3_key.endswith(f".{enums.UploadTypes.FLAT.value}.parquet"): + new_key = f"{enums.BucketPath.LATEST.value}/{s3_key.split('/', 1)[-1]}" + topic_sns_arn = os.environ.get("TOPIC_PROCESS_FLAT_ARN") + sns_subject = "Process flat table upload event" + elif s3_key.endswith(f".{enums.UploadTypes.ARCHIVE.value}.parquet"): + # These may contain line level data, and so we just throw them out as a matter + # of policy + s3_client.delete_object(Bucket=s3_bucket_name, Key=s3_key) + logging.info(f"Deleted archive file at {s3_key}") + return else: + # TODO: Check for .cube.parquet prefix after older versions of the library phase out new_key = f"{enums.BucketPath.LATEST.value}/{s3_key.split('/', 1)[-1]}" topic_sns_arn = os.environ.get("TOPIC_PROCESS_COUNTS_ARN") sns_subject = "Process counts upload event" diff --git a/template.yaml b/template.yaml index c940248..49e711a 100644 --- a/template.yaml +++ b/template.yaml @@ -173,12 +173,15 @@ Resources: Variables: BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' TOPIC_PROCESS_COUNTS_ARN: !Ref SNSTopicProcessCounts + TOPIC_PROCESS_FLAT_ARN: !Ref SNSTopicProcessFlat TOPIC_PROCESS_STUDY_META_ARN: !Ref SNSTopicProcessStudyMeta Policies: - S3CrudPolicy: BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' - SNSPublishMessagePolicy: TopicName: !GetAtt SNSTopicProcessCounts.TopicName + - SNSPublishMessagePolicy: + TopicName: !GetAtt SNSTopicProcessFlat.TopicName - SNSPublishMessagePolicy: TopicName: !GetAtt SNSTopicProcessStudyMeta.TopicName - Statement: @@ -215,7 +218,7 @@ Resources: BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' TOPIC_CACHE_API_ARN: !Ref SNSTopicCacheAPI Events: - ProcessUploadSNSEvent: + ProcessCountsUploadSNSEvent: Type: SNS Properties: Topic: !Ref SNSTopicProcessCounts @@ -238,6 +241,49 @@ Resources: LogGroupName: !Sub "/aws/lambda/${PowersetMergeFunction}" RetentionInDays: !Ref RetentionTime + ProcessFlatFunction: + Type: AWS::Serverless::Function + Properties: + FunctionName: !Sub 'CumulusAggProcessFlat-${DeployStage}' + Layers: [arn:aws:lambda:us-east-1:336392948345:layer:AWSSDKPandas-Python311:17] + CodeUri: ./src/site_upload/process_flat + Handler: process_flat.process_flat_handler + Runtime: "python3.11" + LoggingConfig: + ApplicationLogLevel: !Ref LogLevel + LogFormat: !Ref LogFormat + LogGroup: !Sub "/aws/lambda/CumulusAggProcessFlat-${DeployStage}" + MemorySize: 8192 + Timeout: 800 + Description: Merges and aggregates powerset count data + Environment: + Variables: + BUCKET_NAME: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + TOPIC_CACHE_API_ARN: !Ref SNSTopicCacheAPI + Events: + ProcessFlatUploadSNSEvent: + Type: SNS + Properties: + Topic: !Ref SNSTopicProcessFlat + Policies: + - S3CrudPolicy: + BucketName: !Sub '${BucketNameParameter}-${AWS::AccountId}-${DeployStage}' + - SNSPublishMessagePolicy: + TopicName: !GetAtt 
SNSTopicCacheAPI.TopicName + - Statement: + - Sid: KMSDecryptPolicy + Effect: Allow + Action: + - kms:Decrypt + Resource: + - !ImportValue cumulus-kms-KMSKeyArn + + ProcessFlatLogGroup: + Type: AWS::Logs::LogGroup + Properties: + LogGroupName: !Sub "/aws/lambda/${ProcessFlatFunction}" + RetentionInDays: !Ref RetentionTime + StudyPeriodFunction: Type: AWS::Serverless::Function Properties: @@ -717,6 +763,14 @@ Resources: - Key: Name Value: !Sub 'CumulusProcessCounts-${DeployStage}' + SNSTopicProcessFlat: + Type: AWS::SNS::Topic + Properties: + TopicName: !Sub 'CumulusProcessFlat-${DeployStage}' + Tags: + - Key: Name + Value: !Sub 'CumulusProcessFlat-${DeployStage}' + SNSTopicProcessStudyMeta: Type: AWS::SNS::Topic Properties: @@ -779,6 +833,7 @@ Resources: Targets: S3Targets: - Path: !Sub '${AggregatorBucket}/aggregates' + - Path: !Sub '${AggregatorBucket}/flat' AthenaWorkGroup: Type: AWS::Athena::WorkGroup @@ -845,6 +900,17 @@ Resources: - s3:GetObject - s3:PutObject Resource: !Sub "arn:aws:s3:::${BucketNameParameter}-${AWS::AccountId}-${DeployStage}/aggregates/*" + - Effect: Allow + Action: + - s3:GetObject + - s3:PutObject + Resource: !Sub "arn:aws:s3:::${BucketNameParameter}-${AWS::AccountId}-${DeployStage}/flat/*" + - Sid: KMSDecryptPolicy + Effect: Allow + Action: + - kms:Decrypt + Resource: + - !ImportValue cumulus-kms-KMSKeyArn ### API Gateways diff --git a/tests/conftest.py b/tests/conftest.py index 004619e..711eda2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,6 +40,8 @@ def _init_mock_data(s3_client, bucket, study, data_package, version): The following items are added: - Aggregates, with a site of plainsboro, in parquet and csv, for the study provided + - Flat tables, with a site of plainsboro, in parquet and csv, for the + study provided - a data_package cache for api testing - credentials for the 3 unit test hospitals (princeton, elsewhere, hope) @@ -58,6 +60,20 @@ def _init_mock_data(s3_client, bucket, study, data_package, version): f"{enums.BucketPath.CSVAGGREGATE.value}/{study}/" f"{study}__{data_package}/{version}/{study}__{data_package}__aggregate.csv", ) + s3_client.upload_file( + "./tests/test_data/flat_synthea_q_date_recent.parquet", + bucket, + f"{enums.BucketPath.FLAT.value}/{study}/{mock_utils.EXISTING_SITE}" + f"{study}__{data_package}__{version}/" + f"{study}__{data_package}__flat.parquet", + ) + s3_client.upload_file( + "./tests/test_data/flat_synthea_q_date_recent.csv", + bucket, + f"{enums.BucketPath.CSVFLAT.value}/{study}/{mock_utils.EXISTING_SITE}" + f"{study}__{data_package}__{version}/" + f"{study}__{data_package}__flat.csv", + ) s3_client.upload_file( "./tests/test_data/data_packages_cache.json", bucket, @@ -65,7 +81,7 @@ def _init_mock_data(s3_client, bucket, study, data_package, version): ) -@pytest.fixture(autouse=True) +@pytest.fixture(scope="session", autouse=True) def mock_env(): with mock.patch.dict(os.environ, mock_utils.MOCK_ENV): yield @@ -130,6 +146,7 @@ def mock_notification(): sns.start() sns_client = boto3.client("sns", region_name="us-east-1") sns_client.create_topic(Name="test-counts") + sns_client.create_topic(Name="test-flat") sns_client.create_topic(Name="test-meta") sns_client.create_topic(Name="test-cache") yield diff --git a/tests/mock_utils.py b/tests/mock_utils.py index 4dd5698..ae3795d 100644 --- a/tests/mock_utils.py +++ b/tests/mock_utils.py @@ -4,9 +4,10 @@ TEST_WORKGROUP = "cumulus-aggregator-test-wg" TEST_GLUE_DB = "cumulus-aggregator-test-db" TEST_PROCESS_COUNTS_ARN = 
"arn:aws:sns:us-east-1:123456789012:test-counts" +TEST_PROCESS_FLAT_ARN = "arn:aws:sns:us-east-1:123456789012:test-flat" TEST_PROCESS_STUDY_META_ARN = "arn:aws:sns:us-east-1:123456789012:test-meta" TEST_CACHE_API_ARN = "arn:aws:sns:us-east-1:123456789012:test-cache" -ITEM_COUNT = 9 +ITEM_COUNT = 13 DATA_PACKAGE_COUNT = 3 EXISTING_SITE = "princeton_plainsboro_teaching_hospital" @@ -27,6 +28,7 @@ "GLUE_DB_NAME": TEST_GLUE_DB, "WORKGROUP_NAME": TEST_WORKGROUP, "TOPIC_PROCESS_COUNTS_ARN": TEST_PROCESS_COUNTS_ARN, + "TOPIC_PROCESS_FLAT_ARN": TEST_PROCESS_FLAT_ARN, "TOPIC_PROCESS_STUDY_META_ARN": TEST_PROCESS_STUDY_META_ARN, "TOPIC_CACHE_API_ARN": TEST_CACHE_API_ARN, } @@ -126,7 +128,22 @@ def get_mock_column_types_metadata(): "last_data_update": "2023-02-24T15:08:07.771080+00:00", } } - } + }, + OTHER_STUDY: { + EXISTING_DATA_P: { + EXISTING_VERSION: { + "column_types_format_version": "1", + "columns": { + "cnt": "integer", + "gender": "string", + "age": "integer", + "race_display": "string", + "site": "string", + }, + "last_data_update": "2023-02-24T15:08:07.771080+00:00", + } + } + }, } diff --git a/tests/shared/test_functions.py b/tests/shared/test_functions.py index d096dde..20d7e71 100644 --- a/tests/shared/test_functions.py +++ b/tests/shared/test_functions.py @@ -1,10 +1,21 @@ +"""Unit tests for shared functions. + + +As of this writing, since a lot of this was historically covered by other tests, +this file does not contain a 1-1 set of tests to the source module, +instead focusing only on edge case scenarios (though in those cases, tests +should be comprehensive). 1-1 coverage is a desirable long term goal. +""" + from contextlib import nullcontext as does_not_raise from unittest import mock +import boto3 import pandas import pytest -from src.shared import functions, pandas_functions +from src.shared import enums, functions, pandas_functions +from tests import mock_utils @pytest.mark.parametrize( @@ -60,3 +71,44 @@ def test_column_datatypes(): "bool": "boolean", "string": "string", } + + +def test_update_metadata_error(mock_bucket): + with pytest.raises(ValueError): + enums.JsonFilename.FOO = "foo" + functions.update_metadata( + metadata={}, study="", data_package="", version="", target="", meta_type="foo" + ) + + +def test_get_s3_keys(mock_bucket): + s3_client = boto3.client("s3") + res = functions.get_s3_keys(s3_client, mock_utils.TEST_BUCKET, "") + assert len(res) == mock_utils.ITEM_COUNT + res = functions.get_s3_keys(s3_client, mock_utils.TEST_BUCKET, "", max_keys=2) + assert len(res) == mock_utils.ITEM_COUNT + res = functions.get_s3_keys(s3_client, mock_utils.TEST_BUCKET, "cache") + assert res == ["cache/data_packages.json"] + + +def test_latest_data_package_version(mock_bucket): + version = functions.get_latest_data_package_version( + mock_utils.TEST_BUCKET, f"{enums.BucketPath.AGGREGATE.value}/{mock_utils.EXISTING_STUDY}" + ) + assert version == mock_utils.EXISTING_VERSION + s3_client = boto3.client("s3") + s3_client.upload_file( + "./tests/test_data/count_synthea_patient_agg.parquet", + mock_utils.TEST_BUCKET, + f"{enums.BucketPath.AGGREGATE.value}/{mock_utils.EXISTING_STUDY}/" + f"{mock_utils.EXISTING_STUDY}__{mock_utils.EXISTING_DATA_P}/" + f"{mock_utils.EXISTING_STUDY}__{mock_utils.EXISTING_DATA_P}__{mock_utils.NEW_VERSION}/" + f"{mock_utils.EXISTING_STUDY}__{mock_utils.EXISTING_DATA_P}__aggregate.parquet", + ) + version = functions.get_latest_data_package_version( + mock_utils.TEST_BUCKET, f"{enums.BucketPath.AGGREGATE.value}/{mock_utils.EXISTING_STUDY}" + ) + version = 
functions.get_latest_data_package_version( + mock_utils.TEST_BUCKET, f"{enums.BucketPath.AGGREGATE.value}/not_a_study" + ) + assert version is None diff --git a/tests/shared/test_s3_manager.py b/tests/shared/test_s3_manager.py new file mode 100644 index 0000000..64f8666 --- /dev/null +++ b/tests/shared/test_s3_manager.py @@ -0,0 +1,205 @@ +import io +from unittest import mock + +import pandas +import pytest + +from src.shared import enums, s3_manager +from tests import mock_utils + +SNS_EVENT = { + "Records": [ + { + "Sns": { + "TopicArn": "arn", + "Message": "/study/study__encounter/site/study__encounter__version/file.parquet", + } + } + ] +} + + +def test_init_manager(mock_bucket): + manager = s3_manager.S3Manager(SNS_EVENT) + assert manager.s3_bucket_name == "cumulus-aggregator-site-counts-test" + assert manager.event_source == "arn" + assert manager.s3_key == "/study/study__encounter/site/study__encounter__version/file.parquet" + assert manager.study == "study" + assert manager.data_package == "encounter" + assert manager.site == "site" + assert manager.version == "version" + assert manager.metadata == mock_utils.get_mock_metadata() + assert manager.types_metadata == mock_utils.get_mock_column_types_metadata() + assert ( + manager.parquet_aggregate_path + == "s3://cumulus-aggregator-site-counts-test/aggregates/study/study__encounter/study__encounter__version/study__encounter__aggregate.parquet" + ) + assert ( + manager.csv_aggregate_path + == "s3://cumulus-aggregator-site-counts-test/csv_aggregates/study/study__encounter/version/study__encounter__aggregate.csv" + ) + assert manager.parquet_flat_key == ( + "flat/study/site/study__encounter__version/study__encounter_site__flat.parquet" + ) + assert manager.csv_flat_key == ( + "csv_flat/study/site/study__encounter__version/study__encounter_site__flat.parquet" + ) + + +@pytest.mark.parametrize( + "file,dest", + [ + ( + "s3://cumulus-aggregator-site-counts-test/aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet", + "s3://cumulus-aggregator-site-counts-test/aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate_moved.parquet", + ), + ( + "aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet", + "aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate_moved.parquet", + ), + ], +) +def test_copy_file(mock_bucket, file, dest): + manager = s3_manager.S3Manager(SNS_EVENT) + manager.copy_file(file, dest) + files = [ + file["Key"] + for file in manager.s3_client.list_objects_v2(Bucket=manager.s3_bucket_name)["Contents"] + ] + assert any( + "study/study__encounter/study__encounter__099/study__encounter__aggregate_moved.parquet" + in file + for file in files + ) + assert any( + "study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet" in file + for file in files + ) + + +def test_get_list(mock_bucket): + manager = s3_manager.S3Manager(SNS_EVENT) + assert manager.get_data_package_list(enums.BucketPath.AGGREGATE.value) == [ + "s3://cumulus-aggregator-site-counts-test/aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet" + ] + + +@pytest.mark.parametrize( + "file,dest", + [ + ( + "s3://cumulus-aggregator-site-counts-test/aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet", + "s3://cumulus-aggregator-site-counts-test/aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate_moved.parquet", + 
), + ( + "aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet", + "aggregates/study/study__encounter/study__encounter__099/study__encounter__aggregate_moved.parquet", + ), + ], +) +def test_move_file(mock_bucket, file, dest): + manager = s3_manager.S3Manager(SNS_EVENT) + manager.move_file(file, dest) + files = [ + file["Key"] + for file in manager.s3_client.list_objects_v2(Bucket=manager.s3_bucket_name)["Contents"] + ] + assert any( + "study/study__encounter/study__encounter__099/study__encounter__aggregate_moved.parquet" + in file + for file in files + ) + assert not any( + "study/study__encounter/study__encounter__099/study__encounter__aggregate.parquet" in file + for file in files + ) + + +@mock.patch("boto3.client") +def test_cache_api(mock_client, mock_bucket): + manager = s3_manager.S3Manager(SNS_EVENT) + manager.cache_api() + publish_args = mock_client.mock_calls[-1][2] + assert publish_args["TopicArn"] == mock_utils.TEST_CACHE_API_ARN + assert publish_args["Message"] == "data_packages" + assert publish_args["Subject"] == "data_packages" + + +def test_write_csv(mock_bucket): + df = pandas.DataFrame(data={"foo": [1, 2], "bar": [11, 22]}) + manager = s3_manager.S3Manager(SNS_EVENT) + manager.write_csv(df) + df2 = pandas.read_csv( + io.BytesIO( + manager.s3_client.get_object( + Bucket=manager.s3_bucket_name, + Key=( + "csv_aggregates/study/study__encounter/version/" + "study__encounter__aggregate.csv" + ), + )["Body"].read() + ) + ) + assert df.compare(df2).empty + + +@mock.patch("src.shared.s3_manager.S3Manager.cache_api") +def test_write_parquet(mock_cache, mock_bucket): + df = pandas.DataFrame(data={"foo": [1, 2], "bar": [11, 22]}) + manager = s3_manager.S3Manager(SNS_EVENT) + manager.write_parquet(df, False) + df2 = pandas.read_parquet( + io.BytesIO( + manager.s3_client.get_object( + Bucket=manager.s3_bucket_name, + Key=( + "aggregates/study/study__encounter/study__encounter__version/" + "study__encounter__aggregate.parquet" + ), + )["Body"].read() + ) + ) + assert df.compare(df2).empty + assert not mock_cache.called + manager.write_parquet( + df, + True, + path=( + f"s3://{mock_utils.TEST_BUCKET}/aggregates/study/study__encounter/" + "study__encounter__version/study__encounter__aggregate.parquet" + ), + ) + assert mock_cache.called + + +def test_update_local_metadata(mock_bucket): + manager = s3_manager.S3Manager(SNS_EVENT) + original_transactions = manager.metadata.copy() + original_types = manager.types_metadata.copy() + other_dict = {} + manager.update_local_metadata( + key="foo", + site=mock_utils.NEW_SITE, + value="bar", + metadata=other_dict, + extra_items={"foobar": "baz"}, + ) + assert mock_utils.NEW_SITE in other_dict.keys() + assert "foo" in other_dict[mock_utils.NEW_SITE]["study"]["encounter"]["version"].keys() + assert "foobar" in other_dict[mock_utils.NEW_SITE]["study"]["encounter"]["version"].keys() + assert original_transactions == manager.metadata + assert original_types == manager.types_metadata + manager.update_local_metadata("foo") + assert original_transactions != manager.metadata + assert original_types == manager.types_metadata + assert "foo" in manager.metadata["site"]["study"]["encounter"]["version"].keys() + + +def test_write_local_metadata(mock_bucket): + manager = s3_manager.S3Manager(SNS_EVENT) + manager.metadata = {"foo": "bar"} + manager.write_local_metadata() + metadata = manager.s3_client.get_object( + Bucket=manager.s3_bucket_name, Key="metadata/transactions.json" + )["Body"].read() + assert metadata == b'{\n 
"foo": "bar"\n}' diff --git a/tests/site_upload/test_powerset_merge.py b/tests/site_upload/test_powerset_merge.py index 0c79fc2..4db7648 100644 --- a/tests/site_upload/test_powerset_merge.py +++ b/tests/site_upload/test_powerset_merge.py @@ -1,8 +1,6 @@ import io -import os from contextlib import nullcontext as does_not_raise from datetime import UTC, datetime -from unittest import mock import awswrangler import boto3 @@ -18,7 +16,6 @@ EXISTING_STUDY, EXISTING_VERSION, ITEM_COUNT, - MOCK_ENV, NEW_SITE, NEW_STUDY, NEW_VERSION, @@ -124,7 +121,6 @@ ), ], ) -@mock.patch.dict(os.environ, MOCK_ENV) def test_powerset_merge_single_upload( upload_file, upload_path, @@ -159,7 +155,10 @@ def test_powerset_merge_single_upload( event = { "Records": [ { - "Sns": {"Message": f"{enums.BucketPath.LATEST.value}{event_key}"}, + "Sns": { + "Message": f"{enums.BucketPath.LATEST.value}{event_key}", + "TopicArn": "TOPIC_PROCESS_COUNTS_ARN", + }, } ] } @@ -238,6 +237,8 @@ def test_powerset_merge_single_upload( or item["Key"].startswith(enums.BucketPath.ERROR.value) or item["Key"].startswith(enums.BucketPath.ADMIN.value) or item["Key"].startswith(enums.BucketPath.CACHE.value) + or item["Key"].startswith(enums.BucketPath.FLAT.value) + or item["Key"].startswith(enums.BucketPath.CSVFLAT.value) or item["Key"].endswith("study_periods.json") ) if archives: @@ -258,7 +259,6 @@ def test_powerset_merge_single_upload( ("./tests/test_data/other_schema.parquet", True, 1), ], ) -@mock.patch.dict(os.environ, MOCK_ENV) def test_powerset_merge_join_study_data( upload_file, archives, @@ -298,8 +298,9 @@ def test_powerset_merge_join_study_data( "Sns": { "Message": f"{enums.BucketPath.LATEST.value}/{EXISTING_STUDY}" f"/{EXISTING_STUDY}__{EXISTING_DATA_P}/{NEW_SITE}" - f"/{EXISTING_VERSION}/encounter.parquet" - }, + f"/{EXISTING_VERSION}/encounter.parquet", + "TopicArn": "TOPIC_PROCESS_COUNTS_ARN", + } } ] } @@ -349,25 +350,6 @@ def test_expand_and_concat(mock_bucket, upload_file, load_empty, raises): TEST_BUCKET, s3_path, ) - powerset_merge.expand_and_concat_sets(df, f"s3://{TEST_BUCKET}/{s3_path}", EXISTING_STUDY) - - -def test_parquet_to_csv(mock_bucket): - bucket_root = "test" - subbucket_path = "/uploaded.parquet" - s3_client = boto3.client("s3", region_name="us-east-1") - s3_client.upload_file( - "./tests/test_data/cube_strings_with_commas.parquet", - TEST_BUCKET, - f"{bucket_root}/{subbucket_path}", - ) - powerset_merge.generate_csv_from_parquet(TEST_BUCKET, bucket_root, subbucket_path) - df = awswrangler.s3.read_csv( - f"s3://{TEST_BUCKET}/{bucket_root}/{subbucket_path.replace('.parquet','.csv')}" - ) - assert list(df["race"].dropna().unique()) == [ - "White", - "Black, or African American", - "Asian", - "American Indian, or Alaska Native", - ] + powerset_merge.expand_and_concat_powersets( + df, f"s3://{TEST_BUCKET}/{s3_path}", EXISTING_STUDY + ) diff --git a/tests/site_upload/test_process_flat.py b/tests/site_upload/test_process_flat.py new file mode 100644 index 0000000..7ab8b57 --- /dev/null +++ b/tests/site_upload/test_process_flat.py @@ -0,0 +1,45 @@ +from unittest import mock + +import boto3 + +from src.site_upload.process_flat import process_flat +from tests import mock_utils + + +@mock.patch.object(process_flat.s3_manager.S3Manager, "cache_api") +def test_process_flat(mock_cache, mock_bucket): + event = { + "Records": [ + { + "Sns": { + "TopicArn": "arn", + "Message": ( + "latest/study/study__encounter/site/study__encounter__version/file.parquet" + ), + } + } + ] + } + s3_client = boto3.client("s3") + files = [ + 
file["Key"] for file in s3_client.list_objects_v2(Bucket=mock_utils.TEST_BUCKET)["Contents"] + ] + s3_client.upload_file( + Bucket=mock_utils.TEST_BUCKET, + Key=event["Records"][0]["Sns"]["Message"], + Filename="./tests/test_data/count_synthea_patient_agg.parquet", + ) + process_flat.process_flat_handler(event, {}) + files = [ + file["Key"] for file in s3_client.list_objects_v2(Bucket=mock_utils.TEST_BUCKET)["Contents"] + ] + assert "flat/study/site/study__encounter__version/study__encounter_site__flat.parquet" in files + + mock_cache.reset_mock() + s3_client.upload_file( + Bucket=mock_utils.TEST_BUCKET, + Key=event["Records"][0]["Sns"]["Message"], + Filename="./tests/test_data/count_synthea_patient_agg.parquet", + ) + process_flat.process_flat_handler(event, {}) + assert not mock_cache.called diff --git a/tests/site_upload/test_process_upload.py b/tests/site_upload/test_process_upload.py index b9ede60..3e63387 100644 --- a/tests/site_upload/test_process_upload.py +++ b/tests/site_upload/test_process_upload.py @@ -57,13 +57,24 @@ 200, ITEM_COUNT + 1, ), - ( # Upload of the template study + ( # Upload of a flat file "./tests/test_data/cube_simple_example.parquet", - f"/template/{NEW_DATA_P}/{EXISTING_SITE}" f"/{EXISTING_VERSION}/document.parquet", - f"/template/{NEW_DATA_P}/{EXISTING_SITE}" f"/{EXISTING_VERSION}/document.parquet", + f"/{EXISTING_STUDY}/{NEW_DATA_P}/{EXISTING_SITE}" + f"/{EXISTING_VERSION}/document.flat.parquet", + f"/{EXISTING_STUDY}/{NEW_DATA_P}/{EXISTING_SITE}" + f"/{EXISTING_VERSION}/document.flat.parquet", 200, ITEM_COUNT + 1, ), + ( # Upload of an archive file (which should be deleted) + "./tests/test_data/cube_simple_example.parquet", + f"/{EXISTING_STUDY}/{NEW_DATA_P}/{EXISTING_SITE}" + f"/{EXISTING_VERSION}/document.archive.parquet", + f"/{EXISTING_STUDY}/{NEW_DATA_P}/{EXISTING_SITE}" + f"/{EXISTING_VERSION}/document.archive.parquet", + 200, + ITEM_COUNT, + ), ( # Non-parquet file "./tests/test_data/cube_simple_example.csv", f"/{EXISTING_STUDY}/{NEW_DATA_P}/{EXISTING_SITE}" f"/{EXISTING_VERSION}/document.csv", @@ -136,7 +147,8 @@ def test_process_upload( assert res["statusCode"] == status s3_res = s3_client.list_objects_v2(Bucket=TEST_BUCKET) assert len(s3_res["Contents"]) == expected_contents - found_archive = False + if event_key.endswith(".archive.parquet"): + return for item in s3_res["Contents"]: if item["Key"].endswith("aggregate.parquet"): assert item["Key"].startswith(enums.BucketPath.AGGREGATE.value) @@ -158,8 +170,6 @@ def test_process_upload( ) elif item["Key"].startswith(enums.BucketPath.STUDY_META.value): assert any(x in item["Key"] for x in ["_meta_", "/discovery__"]) - elif item["Key"].startswith(enums.BucketPath.ARCHIVE.value): - found_archive = True else: assert ( item["Key"].startswith(enums.BucketPath.LATEST.value) @@ -167,8 +177,9 @@ def test_process_upload( or item["Key"].startswith(enums.BucketPath.ERROR.value) or item["Key"].startswith(enums.BucketPath.ADMIN.value) or item["Key"].startswith(enums.BucketPath.CACHE.value) + or item["Key"].startswith(enums.BucketPath.FLAT.value) + or item["Key"].startswith(enums.BucketPath.CSVFLAT.value) + or item["Key"].startswith(enums.BucketPath.ARCHIVE.value) or item["Key"].endswith("study_periods.json") or item["Key"].endswith("column_types.json") ) - if found_archive: - assert "template" in upload_path diff --git a/tests/test_data/flat_synthea_q_date_recent.csv b/tests/test_data/flat_synthea_q_date_recent.csv new file mode 100644 index 0000000..567826c --- /dev/null +++ 
b/tests/test_data/flat_synthea_q_date_recent.csv
@@ -0,0 +1,19 @@
+resource,subgroup,numerator,denominator,percentage
+Procedure,,2,5,40.00
+Procedure,cumulus__all,2,5,40.00
+Observation,,0,0,0.00
+Observation,cumulus__all,0,0,0.00
+MedicationRequest,,0,0,0.00
+MedicationRequest,cumulus__all,0,0,0.00
+Immunization,,0,0,0.00
+Immunization,cumulus__all,0,0,0.00
+Encounter,,1,4,25.00
+Encounter,cumulus__all,1,4,25.00
+DocumentReference,,0,0,0.00
+DocumentReference,cumulus__all,0,0,0.00
+DiagnosticReport,,0,0,0.00
+DiagnosticReport,cumulus__all,0,0,0.00
+Condition,,2,4,50.00
+Condition,cumulus__all,2,4,50.00
+AllergyIntolerance,,0,0,0.00
+AllergyIntolerance,cumulus__all,0,0,0.00
diff --git a/tests/test_data/flat_synthea_q_date_recent.parquet b/tests/test_data/flat_synthea_q_date_recent.parquet
new file mode 100644
index 0000000..4e5dd72
Binary files /dev/null and b/tests/test_data/flat_synthea_q_date_recent.parquet differ
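For reference, assuming the binary parquet fixture holds the same table, an equivalent CSV fixture can be produced locally with plain pandas along these lines; this skips the whitespace/quoting normalization that generate_csv_from_parquet applies, and formatting details such as trailing zeros may differ.

```python
import pandas

df = pandas.read_parquet("tests/test_data/flat_synthea_q_date_recent.parquet")
df.to_csv("tests/test_data/flat_synthea_q_date_recent.csv", index=False)
```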