From 69e7de73d70c0f84fb7a68bc29f6a2cddd54b210 Mon Sep 17 00:00:00 2001 From: Nayib Gloria <55710092+nayib-jose-gloria@users.noreply.github.com> Date: Fri, 9 Aug 2024 14:28:24 -0400 Subject: [PATCH 1/4] feat: Create a standalone, manual batch job that runs bulk collection and dataset rollbacks for migration error recovery (#7317) --- .happy/terraform/modules/batch/main.tf | 62 +++ backend/layers/business/business.py | 51 ++- backend/layers/business/exceptions.py | 6 + backend/layers/persistence/persistence.py | 13 + .../persistence/persistence_interface.py | 7 + .../layers/persistence/persistence_mock.py | 9 +- backend/layers/processing/rollback.py | 293 +++++++++++++ scripts/cxg_admin.py | 12 - scripts/cxg_admin_scripts/rollback.py | 82 ++++ scripts/cxg_admin_scripts/schema_migration.py | 13 - .../backend/layers/business/test_business.py | 52 +++ tests/unit/processing/test_rollback.py | 406 ++++++++++++++++++ 12 files changed, 979 insertions(+), 27 deletions(-) create mode 100644 backend/layers/processing/rollback.py create mode 100755 scripts/cxg_admin_scripts/rollback.py create mode 100644 tests/unit/processing/test_rollback.py diff --git a/.happy/terraform/modules/batch/main.tf b/.happy/terraform/modules/batch/main.tf index 85f06d64494df..d85b62b7c9a3b 100644 --- a/.happy/terraform/modules/batch/main.tf +++ b/.happy/terraform/modules/batch/main.tf @@ -123,6 +123,68 @@ resource aws_batch_job_definition dataset_metadata_update { }) } +resource aws_batch_job_definition rollback { + type = "container" + name = "dp-${var.deployment_stage}-${var.custom_stack_name}-rollback" + container_properties = jsonencode({ + "command": ["python3", "-m", "backend.layers.processing.rollback"], + "jobRoleArn": "${var.batch_role_arn}", + "image": "${var.image}", + "memory": 8000, + "environment": [ + { + "name": "ARTIFACT_BUCKET", + "value": "${var.artifact_bucket}" + }, + { + "name": "CELLXGENE_BUCKET", + "value": "${var.cellxgene_bucket}" + }, + { + "name": "DATASETS_BUCKET", + "value": "${var.datasets_bucket}" + }, + { + "name": "DEPLOYMENT_STAGE", + "value": "${var.deployment_stage}" + }, + { + "name": "AWS_DEFAULT_REGION", + "value": "${data.aws_region.current.name}" + }, + { + "name": "REMOTE_DEV_PREFIX", + "value": "${var.remote_dev_prefix}" + } + ], + "vcpus": 1, + "linuxParameters": { + "maxSwap": 0, + "swappiness": 0 + }, + "retryStrategy": { + "attempts": 3, + "evaluateOnExit": [ + { + "action": "RETRY", + "onReason": "Task failed to start" + }, + { + "action": "EXIT", + "onReason": "*" + } + ] + }, + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": "${aws_cloudwatch_log_group.cloud_watch_logs_group.id}", + "awslogs-region": "${data.aws_region.current.name}" + } + } +}) +} + resource aws_cloudwatch_log_group cloud_watch_logs_group { retention_in_days = 365 name = "/dp/${var.deployment_stage}/${var.custom_stack_name}/upload" diff --git a/backend/layers/business/business.py b/backend/layers/business/business.py index 123de4ec2bc04..69e564cc02fdd 100644 --- a/backend/layers/business/business.py +++ b/backend/layers/business/business.py @@ -40,6 +40,7 @@ DatasetVersionNotFoundException, InvalidURIException, MaxFileSizeExceededException, + NoPreviousCollectionVersionException, NoPreviousDatasetVersionException, ) from backend.layers.common import validation @@ -273,6 +274,19 @@ def get_unpublished_collection_version_from_canonical( unpublished_collection = collection return unpublished_collection + def get_unpublished_collection_versions_from_canonical( + self, 
collection_id: CollectionId + ) -> List[CollectionVersionWithDatasets]: + """ + Given a canonical collection_id, retrieves its latest unpublished versions (max of 2 with a migration_revision + and non-migration revision) + """ + unpublished_collections = [] + for collection in self.get_collection_versions_from_canonical(collection_id): + if collection.published_at is None: + unpublished_collections.append(collection) + return unpublished_collections + def get_collection_url(self, collection_id: str) -> str: return f"{CorporaConfig().collections_base_url}/collections/{collection_id}" @@ -437,7 +451,7 @@ def update_collection_version( ): for dataset in current_collection_version.datasets: if dataset.status.processing_status != DatasetProcessingStatus.SUCCESS: - self.logger.info( + logger.info( f"Dataset {dataset.version_id.id} is not successfully processed. Skipping metadata update." ) continue @@ -1323,3 +1337,38 @@ def restore_previous_dataset_version( self.database_provider.replace_dataset_in_collection_version( collection_version_id, current_version.version_id, previous_version_id ) + + def restore_previous_collection_version(self, collection_id: CollectionId) -> CollectionVersion: + """ + Restore the previously published collection version for a collection, if any exist. + + Returns CollectionVersion that was replaced, and is no longer linked to canonical collection. + + :param collection_id: The collection id to restore the previous version of. + """ + version_to_replace = self.get_collection_version_from_canonical(collection_id) + all_published_versions = list(self.get_all_published_collection_versions_from_canonical(collection_id)) + if len(all_published_versions) < 2: + raise NoPreviousCollectionVersionException(f"No previous collection version for collection {collection_id}") + + # get most recent previously published version + previous_version = None + previous_published_at = datetime.fromtimestamp(0) + + for version in all_published_versions: + if version.version_id == version_to_replace.version_id: + continue + if version.published_at > previous_published_at: + previous_version = version + previous_published_at = version.published_at + + logger.info( + { + "message": "Restoring previous collection version", + "collection_id": collection_id.id, + "replace_version_id": version_to_replace.version_id.id, + "restored_version_id": previous_version.version_id.id, + } + ) + self.database_provider.replace_collection_version(collection_id, previous_version.version_id) + return version_to_replace diff --git a/backend/layers/business/exceptions.py b/backend/layers/business/exceptions.py index 6c6643c86eb8a..f93348b6ba403 100644 --- a/backend/layers/business/exceptions.py +++ b/backend/layers/business/exceptions.py @@ -43,6 +43,12 @@ class NoPreviousDatasetVersionException(BusinessException): """ +class NoPreviousCollectionVersionException(BusinessException): + """ + Raised when a previous collection version is expected, but does not exist + """ + + class InvalidLinkException(BusinessException): def __init__(self, errors: Optional[List[str]] = None) -> None: self.errors: Optional[List[str]] = errors diff --git a/backend/layers/persistence/persistence.py b/backend/layers/persistence/persistence.py index c7c3bba0f8976..016e4d0647bde 100644 --- a/backend/layers/persistence/persistence.py +++ b/backend/layers/persistence/persistence.py @@ -1106,3 +1106,16 @@ def get_previous_dataset_version_id(self, dataset_id: DatasetId) -> Optional[Dat if version_id is None: return None return 
DatasetVersionId(str(version_id.id)) + + def replace_collection_version( + self, collection_id: CollectionId, new_collection_version_id: CollectionVersionId + ) -> None: + """ + Replaces the version_id of a canonical collection, and deletes the replaced CollectionVersionId. + """ + with self._manage_session() as session: + collection = session.query(CollectionTable).filter_by(id=collection_id.id).one() + replaced_version_id = str(collection.version_id) + collection.version_id = new_collection_version_id.id + replaced_collection_version = session.query(CollectionVersionTable).filter_by(id=replaced_version_id).one() + session.delete(replaced_collection_version) diff --git a/backend/layers/persistence/persistence_interface.py b/backend/layers/persistence/persistence_interface.py index 84743f086c941..2945642ffcbee 100644 --- a/backend/layers/persistence/persistence_interface.py +++ b/backend/layers/persistence/persistence_interface.py @@ -281,6 +281,13 @@ def replace_dataset_in_collection_version( Replaces an existing mapping between a collection version and a dataset version """ + def replace_collection_version( + self, collection_id: CollectionId, new_collection_version_id: CollectionVersionId + ) -> None: + """ + Replaces existing canonical Collection mapping with a new CollectionVersionId + """ + def get_dataset_mapped_version(self, dataset_id: DatasetId, get_tombstoned: bool) -> Optional[DatasetVersion]: """ Returns the dataset version mapped to a canonical dataset_id, or None if not existing diff --git a/backend/layers/persistence/persistence_mock.py b/backend/layers/persistence/persistence_mock.py index fe780220e5c9f..5dc252e46f902 100644 --- a/backend/layers/persistence/persistence_mock.py +++ b/backend/layers/persistence/persistence_mock.py @@ -155,7 +155,7 @@ def get_canonical_collection(self, collection_id: CollectionId) -> Optional[Cano def get_all_mapped_collection_versions( self, get_tombstoned: bool = False ) -> Iterable[CollectionVersion]: # TODO: add filters if needed - for version_id, collection_version in self.collections_versions.items(): + for version_id, collection_version in list(self.collections_versions.items()): if version_id in [c.version_id.id for c in self.collections.values()]: collection_id = collection_version.collection_id.id if not get_tombstoned and self.collections[collection_id].tombstoned: @@ -588,6 +588,13 @@ def replace_dataset_in_collection_version( collection_version.datasets[idx] = new_dataset_version_id return copy.deepcopy(new_dataset_version) + def replace_collection_version( + self, collection_id: CollectionId, new_collection_version_id: CollectionVersionId + ) -> None: + old_version_id = self.collections[collection_id.id].version_id + self.collections[collection_id.id].version_id = new_collection_version_id + del self.collections_versions[old_version_id.id] + def set_collection_version_datasets_order( self, collection_version_id: CollectionVersionId, diff --git a/backend/layers/processing/rollback.py b/backend/layers/processing/rollback.py new file mode 100644 index 0000000000000..e6dff20cc6626 --- /dev/null +++ b/backend/layers/processing/rollback.py @@ -0,0 +1,293 @@ +""" +This batch job is meant for rollback of datasets or collections in response to migration failures-- +either pre- or post- publish. 
It can either:
+1) rollback all datasets to their previous dataset version in all private collections
+2) rollback all datasets in an input list of private collections
+3) rollback an input list of private datasets
+4) rollback all public collections to their previous collection version
+5) rollback an input list of public collections
+depending on the input "ROLLBACK_TYPE" and optional "ENTITY_LIST" environment variables.
+
+Example usages:
+
+1) Rollback all datasets in all private collections:
+$ aws batch submit-job --job-name rollback \
+    --job-queue \
+    --job-definition \
+    --container-overrides '{
+        "environment": [{"name": "ROLLBACK_TYPE", "value": "private_collections"}]
+    }'
+
+2) Rollback all datasets in an input list of private collections:
+$ aws batch submit-job --job-name rollback \
+    --job-queue \
+    --job-definition \
+    --container-overrides '{
+        "environment": [{"name": "ROLLBACK_TYPE", "value": "private_collection_list"},
+                        {"name": "ENTITY_LIST", "value": "collection_version_id1,collection_version_id2"}]
+    }'
+
+3) Rollback an input list of private datasets:
+$ aws batch submit-job --job-name rollback \
+    --job-queue \
+    --job-definition \
+    --container-overrides '{
+        "environment": [{"name": "ROLLBACK_TYPE", "value": "private_dataset_list"},
+                        {"name": "ENTITY_LIST", "value": "dataset_version_id1,dataset_version_id2"}]
+    }'
+
+4) Rollback all public collections to their previous collection version:
+$ aws batch submit-job --job-name rollback \
+    --job-queue \
+    --job-definition \
+    --container-overrides '{
+        "environment": [{"name": "ROLLBACK_TYPE", "value": "public_collections"}]
+    }'
+
+5) Rollback an input list of public collections:
+$ aws batch submit-job --job-name rollback \
+    --job-queue \
+    --job-definition \
+    --container-overrides '{
+        "environment": [{"name": "ROLLBACK_TYPE", "value": "public_collection_list"},
+                        {"name": "ENTITY_LIST", "value": "canonical_collection_id1,canonical_collection_id2"}]
+    }'
+
+The cxg admin CLI commands to trigger the batch calls above are located in scripts/cxg_admin_scripts/rollback.py
+"""
+
+import logging
+import os
+from enum import Enum
+from typing import List
+
+from backend.layers.business.business import BusinessLogic, CollectionQueryFilter
+from backend.layers.business.exceptions import (
+    CollectionNotFoundException,
+)
+from backend.layers.common.entities import (
+    CollectionId,
+    CollectionVersion,
+    CollectionVersionId,
+    DatasetVersion,
+    DatasetVersionId,
+)
+from backend.layers.persistence.persistence import DatabaseProvider
+from backend.layers.thirdparty.s3_provider import S3Provider
+from backend.layers.thirdparty.uri_provider import UriProvider
+
+logger = logging.getLogger(__name__)
+
+
+class RollbackType(Enum):
+    PRIVATE_COLLECTIONS = "private_collections"
+    PUBLIC_COLLECTIONS = "public_collections"
+    PUBLIC_COLLECTION_LIST = "public_collection_list"
+    PRIVATE_COLLECTION_LIST = "private_collection_list"
+    PRIVATE_DATASET_LIST = "private_dataset_list"
+
+
+class RollbackEntity:
+    def __init__(
+        self, business_logic: BusinessLogic, rollback_type: RollbackType, entity_id_list: List[str] = None
+    ) -> None:
+        self.business_logic = business_logic
+        self.rollback_type = rollback_type
+        self.entity_id_list = entity_id_list
+
+    def rollback(self):
+        if self.rollback_type == RollbackType.PRIVATE_COLLECTIONS:
+            self.collections_private_rollback()
+        elif self.rollback_type == RollbackType.PUBLIC_COLLECTIONS:
+            self.collections_public_rollback()
+        elif self.rollback_type == RollbackType.PUBLIC_COLLECTION_LIST:
+            collection_id_list = [CollectionId(entity_id) for entity_id in self.entity_id_list]
+            self.collection_list_public_rollback(collection_id_list)
+        elif self.rollback_type == RollbackType.PRIVATE_COLLECTION_LIST:
+            collection_version_id_list = [CollectionVersionId(entity_id) for entity_id in self.entity_id_list]
+            self.collection_list_private_rollback(collection_version_id_list)
+        elif self.rollback_type == RollbackType.PRIVATE_DATASET_LIST:
+            dataset_version_id_list = [DatasetVersionId(entity_id) for entity_id in self.entity_id_list]
+            self.dataset_list_private_rollback(dataset_version_id_list)
+        else:
+            raise ValueError(f"Invalid rollback type: {self.rollback_type}")
+
+    def dataset_list_private_rollback(self, dataset_version_id_list: List[DatasetVersionId]) -> None:
+        """
+        Rolls back the Datasets in the CollectionVersions associated with each DatasetVersionId passed in, to their
+        respective previous, most recently created DatasetVersion. Then, triggers deletion of the DB references and S3
+        assets for the rolled back DatasetVersions.
+        """
+        rolled_back_datasets = []
+        for dataset_version_id in dataset_version_id_list:
+            try:
+                rolled_back_dataset = self.dataset_private_rollback(dataset_version_id)
+            except Exception:
+                logger.exception(f"Failed to rollback DatasetVersion {dataset_version_id}")
+                continue
+            rolled_back_datasets.append(rolled_back_dataset)
+        self._clean_up_rolled_back_datasets(rolled_back_datasets)
+
+    def dataset_private_rollback(
+        self, dataset_version_id: DatasetVersionId, collection_version_id: CollectionVersionId = None
+    ) -> DatasetVersion:
+        """
+        For a given DatasetVersionId and unpublished CollectionVersionId, rolls back the associated Dataset in the
+        CollectionVersion to its previous, most recently created DatasetVersion and deletes the given DatasetVersion.
+
+        :param dataset_version_id: DatasetVersionId of the DatasetVersion to rollback
+        :param collection_version_id: CollectionVersionId of the CollectionVersion to rollback the DatasetVersion
+        in. If not passed in, the CollectionVersionId will be determined from the DatasetVersionId.
+        :return: DatasetVersion that was rolled back
+        """
+        cv_id = collection_version_id
+        dataset_version = self.business_logic.get_dataset_version(dataset_version_id)
+        if cv_id is None:
+            # account for collection potentially having a migration revision and a non-migration revision
+            collection_versions = self.business_logic.get_unpublished_collection_versions_from_canonical(
+                dataset_version.collection_id
+            )
+            for cv in collection_versions:
+                if cv_id is not None:
+                    break
+                for dataset in cv.datasets:
+                    if dataset.version_id == dataset_version_id:
+                        cv_id = cv.version_id
+                        break
+        if cv_id is None:
+            raise CollectionNotFoundException(
+                f"No associated CollectionVersion found for DatasetVersion {dataset_version_id}"
+            )
+        self.business_logic.restore_previous_dataset_version(cv_id, dataset_version.dataset_id)
+        return dataset_version
+
+    def collections_private_rollback(self) -> None:
+        """
+        Rollback all the datasets in all private collections. This will restore the
+        state of private collections to their pre-migration state. Then, triggers deletion of the DB
+        references and S3 assets for the rolled back DatasetVersions.
+ """ + filter = CollectionQueryFilter(is_published=False) + collections = self.business_logic.get_collections(filter) + self.collection_list_private_rollback([collection.version_id for collection in collections]) + + def collection_list_private_rollback(self, collection_version_id_list: List[CollectionVersionId]) -> None: + """ + Rollback all the datasets from the input list of private collections. This will restore the + state of the list of private collections to their pre-migration state. Then, triggers deletion of the DB + references and S3 assets for the rolled back DatasetVersions. + """ + rolled_back_datasets = [] + for collection_version_id in collection_version_id_list: + try: + collection_dataset_versions = self.collection_private_rollback(collection_version_id) + except Exception: + logger.exception(f"Failed to rollback private CollectionVersion {collection_version_id}") + continue + rolled_back_datasets.extend(collection_dataset_versions) + self._clean_up_rolled_back_datasets(rolled_back_datasets) + + def collection_private_rollback(self, collection_version_id: CollectionVersionId) -> List[DatasetVersion]: + """ + Rolls back the dataset versions for all datasets in the given private collection. This will restore the state of + the private collection to its pre-migration state. + """ + collection_version = self.business_logic.get_collection_version(collection_version_id) + if collection_version is None: + raise CollectionNotFoundException(f"CollectionVersion {collection_version_id} not found") + for dataset in collection_version.datasets: + try: + self.dataset_private_rollback(dataset.version_id, collection_version_id) + except Exception: + logger.exception(f"Failed to rollback DatasetVersion {dataset.version_id}") + continue + return collection_version.datasets + + def collections_public_rollback(self) -> None: + """ + Rollback each public collection to its previous, most recently published CollectionVersion (if one exists). + This will restore the state of public collections to their pre-migration state. Then, triggers deletion of the DB + references and S3 assets for the rolled back CollectionVersions and their associated DatasetVersions. + """ + filter = CollectionQueryFilter(is_published=True) + collections = self.business_logic.get_collections(filter) + self.collection_list_public_rollback([collection.collection_id for collection in collections]) + + def collection_list_public_rollback(self, collection_id_list: List[CollectionId]) -> None: + """ + Rollback each public collection in the input list to its previous, most recently published CollectionVersion (if + one exists). This will restore the state of the list of public collections to their pre-migration state. Then, + triggers deletion of the DB references and S3 assets for the rolled back CollectionVersions and their associated + DatasetVersions. + """ + rolled_back_collection_versions = [] + for collection_id in collection_id_list: + try: + rolled_back_collection_version = self.collection_public_rollback(collection_id) + except Exception as e: + logger.info(f"Failed to rollback Collection {collection_id}, {e}") + continue + rolled_back_collection_versions.append(rolled_back_collection_version) + self._clean_up_published_collection_versions(rolled_back_collection_versions) + + def collection_public_rollback(self, collection_id: CollectionId) -> CollectionVersion: + """ + For a given public CollectionId, rolls back the associated Collection to its previous, most recently published + CollectionVersion (if one exists). 
This will restore the state of the public collection to its pre-migration + state. + + :param collection_id: CollectionId of the Collection to rollback + :return: CollectionVersion that was rolled back + """ + return self.business_logic.restore_previous_collection_version(collection_id) + + def _clean_up_rolled_back_datasets(self, rolled_back_datasets: List[DatasetVersion]) -> None: + """ + Triggers deletion of the DB references and S3 assets for the rolled back DatasetVersions. + """ + self.business_logic.delete_dataset_versions(rolled_back_datasets) + + def _clean_up_published_collection_versions(self, rolled_back_collection_versions: List[CollectionVersion]) -> None: + """ + Triggers deletion of the DB references and S3 assets for the rolled back CollectionVersions' associated + DatasetVersions, if they are not associated with any still-existing CollectionVersions. + """ + datasets_to_rollback = [] + for rolled_back_collection_version in rolled_back_collection_versions: + dataset_rollback_candidates = rolled_back_collection_version.datasets + collection_version_history = self.business_logic.get_collection_versions_from_canonical( + rolled_back_collection_version.collection_id + ) + dataset_version_history = {dv.version_id.id: dv for cv in collection_version_history for dv in cv.datasets} + for dataset_rollback_candidate in dataset_rollback_candidates: + if dataset_rollback_candidate.version_id.id not in dataset_version_history: + datasets_to_rollback.append(dataset_rollback_candidate) + self._clean_up_rolled_back_datasets(datasets_to_rollback) + + +if __name__ == "__main__": + business_logic = BusinessLogic( + DatabaseProvider(), + None, + None, + None, + S3Provider(), + UriProvider(), + ) + rollback_type_str = os.environ.get("ROLLBACK_TYPE") + if rollback_type_str is None: + raise ValueError("ROLLBACK_TYPE is required") + + rollback_type = RollbackType(rollback_type_str) + if rollback_type in ( + RollbackType.PUBLIC_COLLECTION_LIST, + RollbackType.PRIVATE_COLLECTION_LIST, + RollbackType.PRIVATE_DATASET_LIST, + ): + entities_to_rollback = os.environ.get("ENTITY_LIST") + if entities_to_rollback is None: + raise ValueError(f"ENTITY_LIST is required for rollback type {rollback_type_str}") + entities_to_rollback_list = entities_to_rollback.split(",") + RollbackEntity(business_logic, rollback_type, entities_to_rollback_list).rollback() + else: + RollbackEntity(business_logic, rollback_type).rollback() diff --git a/scripts/cxg_admin.py b/scripts/cxg_admin.py index 4b118ffb3bb91..073a0a98d993b 100755 --- a/scripts/cxg_admin.py +++ b/scripts/cxg_admin.py @@ -299,18 +299,6 @@ def schema_migration_cli(ctx): os.environ["ARTIFACT_BUCKET"] = happy_config["s3_buckets"]["artifact"]["name"] -@schema_migration_cli.command() -@click.pass_context -@click.argument("report_path", type=click.Path(exists=True)) -def rollback_datasets(ctx, report_path: str): - """ - Used to rollback a datasets to a previous version. - - ./scripts/cxg_admin.py schema-migration --deployment dev rollback-dataset report.json - """ - schema_migration.rollback_dataset(ctx, report_path) - - @schema_migration_cli.command() @click.pass_context @click.argument("execution_id") diff --git a/scripts/cxg_admin_scripts/rollback.py b/scripts/cxg_admin_scripts/rollback.py new file mode 100755 index 0000000000000..c915518337241 --- /dev/null +++ b/scripts/cxg_admin_scripts/rollback.py @@ -0,0 +1,82 @@ +#! 
/usr/bin/env python + +import os +import sys +from time import time + +import boto3 +import click + +pkg_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) # noqa +sys.path.insert(0, pkg_root) # noqa + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.argument("deployment_stage", type=click.Choice(["dev", "staging", "prod", "rdev"])) +@click.argument( + "rollback_type", + type=click.Choice( + [ + "private_dataset_list", + "private_collection_list", + "private_collections", + "public_collection_list", + "public_collections", + ] + ), +) +@click.option( + "--entity-id-list", + default="", + help="comma-delimited ID list of entities to rollback, if using '*_list' rollback type", +) +@click.option("--rdev-stack-name", default="", help="if deployment stage is rdev, specific stack name") +def trigger_rollback(deployment_stage, rollback_type, entity_id_list, rdev_stack_name): + """ + Used to rollback datasets or collections to a previous version. + + ./scripts/cxg_admin_scripts/rollback.py trigger-rollback --entity-id-list entity_id1,entity_id2 + """ + rollback(rollback_type, deployment_stage, entity_id_list, rdev_stack_name) + + +def rollback(rollback_type: str, deployment_stage: str, entity_id_list: str = "", rdev_stack_name: str = ""): + stack_name = f"{deployment_stage}stack" + if deployment_stage == "staging": + stack_name = "stagestack" + elif deployment_stage == "rdev": + stack_name = rdev_stack_name + + client = boto3.client("batch") + response = client.submit_job( + jobName=f"rollback_{int(time())}", + jobQueue=f"schema_migration-{deployment_stage}", + jobDefinition=f"dp-{deployment_stage}-{stack_name}-rollback", + containerOverrides={ + "environment": [ + { + "name": "ROLLBACK_TYPE", + "value": rollback_type, + }, + { + "name": "ENTITY_LIST", + "value": entity_id_list, + }, + ] + }, + ) + + click.echo( + f"Batch Job executing: " + f"https://us-west-2.console.aws.amazon.com/states/home?region=us-west-2#/executions/details/" + f"{response['jobArn']}" + ) + + +if __name__ == "__main__": + cli() diff --git a/scripts/cxg_admin_scripts/schema_migration.py b/scripts/cxg_admin_scripts/schema_migration.py index 89c92cfb578c5..d55eff3ef9da1 100644 --- a/scripts/cxg_admin_scripts/schema_migration.py +++ b/scripts/cxg_admin_scripts/schema_migration.py @@ -2,22 +2,9 @@ from pathlib import Path from backend.common.utils.json import CustomJSONEncoder -from backend.layers.common.entities import CollectionVersionId, DatasetId from backend.layers.processing.schema_migration import SchemaMigrate -def rollback_dataset(ctx, report_path: Path): - with report_path.open("r") as f: - report = json.load(f) - for entry in report: - if entry["rollback"] is True: - collection_version_id = entry["collection_version_id"] - dataset_id = entry["dataset_id"] - ctx.obj["business_logic"].restore_previous_dataset_version( - CollectionVersionId(collection_version_id), DatasetId(dataset_id) - ) - - def generate_report(ctx, execution_id: str, report_path: str, artifact_bucket: str): schema_migration = SchemaMigrate(ctx.obj["business_logic"], None) report = schema_migration.report(execution_id=execution_id, artifact_bucket=artifact_bucket, dry_run=True) diff --git a/tests/unit/backend/layers/business/test_business.py b/tests/unit/backend/layers/business/test_business.py index 283d9de69a237..9a2a3002a1520 100644 --- a/tests/unit/backend/layers/business/test_business.py +++ b/tests/unit/backend/layers/business/test_business.py @@ -33,6 +33,7 @@ DatasetNotFoundException, 
InvalidMetadataException, InvalidURIException, + NoPreviousCollectionVersionException, NoPreviousDatasetVersionException, ) from backend.layers.common.entities import ( @@ -495,6 +496,26 @@ def test_get_collection_version_for_tombstoned_collection(self): self.assertIsNotNone(past_version_tombstoned) self.assertEqual(True, past_version_tombstoned.canonical_collection.tombstoned) + def test_get_unpublished_collection_versions_from_canonical(self): + """ + Unpublished Collection versions can be retrieved from a canonical Collection + """ + published_collection = self.initialize_published_collection() + non_migration_revision = self.business_logic.create_collection_version( + published_collection.collection_id, is_auto_version=False + ) + migration_revision = self.business_logic.create_collection_version( + published_collection.collection_id, is_auto_version=True + ) + + unpublished_versions = self.business_logic.get_unpublished_collection_versions_from_canonical( + published_collection.collection_id + ) + unpublished_version_ids = [version.version_id.id for version in unpublished_versions] + self.assertCountEqual( + [non_migration_revision.version_id.id, migration_revision.version_id.id], unpublished_version_ids + ) + class TestGetAllCollections(BaseBusinessLogicTestCase): def test_get_all_collections_unfiltered_ok(self): @@ -2877,6 +2898,37 @@ def test__delete_datasets_from_public_access_bucket(self): ) self.assertEqual(expected_delete_keys, actual_delete_keys) + def test__restore_previous_collection_version(self): + """ + Test restoring a previous version of a collection + """ + original_collection_version = self.initialize_published_collection() + collection_id = original_collection_version.collection_id + new_version = self.business_logic.create_collection_version(collection_id) + self.business_logic.publish_collection_version(new_version.version_id) + + # Restore the previous version + self.business_logic.restore_previous_collection_version(collection_id) + + # Fetch the collection + restored_collection = self.business_logic.get_canonical_collection(collection_id) + + # Ensure the collection is pointing back at the original collection version + self.assertEqual(original_collection_version.version_id, restored_collection.version_id) + + # Ensure replaced CollectionVersion is deleted from DB + self.assertIsNone(self.database_provider.get_collection_version(new_version.version_id)) + + def test__restore_previous_collection_version__no_previous_versions(self): + """ + Test restoring a previous version of a collection fails when there are no previous versions + """ + original_collection_version = self.initialize_published_collection() + collection_id = original_collection_version.collection_id + + with self.assertRaises(NoPreviousCollectionVersionException): + self.business_logic.restore_previous_collection_version(collection_id) + class TestGetEntitiesBySchema(BaseBusinessLogicTestCase): def test_get_latest_published_collection_versions_by_schema(self): diff --git a/tests/unit/processing/test_rollback.py b/tests/unit/processing/test_rollback.py new file mode 100644 index 0000000000000..3828e3cf00373 --- /dev/null +++ b/tests/unit/processing/test_rollback.py @@ -0,0 +1,406 @@ +from datetime import datetime +from unittest.mock import Mock + +import pytest + +from backend.layers.business.business import BusinessLogic +from backend.layers.persistence.persistence_mock import DatabaseProviderMock +from backend.layers.processing.rollback import RollbackEntity, RollbackType +from 
backend.layers.thirdparty.s3_provider_mock import MockS3Provider +from backend.layers.thirdparty.uri_provider import UriProviderInterface + + +@pytest.fixture +def business_logic(): + return BusinessLogic( + DatabaseProviderMock(), + None, + None, + None, + MockS3Provider(), + UriProviderInterface(), + ) + + +@pytest.fixture +def rollback_entity_private_collections(business_logic): + return RollbackEntity(business_logic, RollbackType.PRIVATE_COLLECTIONS) + + +@pytest.fixture +def rollback_entity_private_collection_list(business_logic): + return RollbackEntity(business_logic, RollbackType.PRIVATE_COLLECTION_LIST, list()) + + +@pytest.fixture +def rollback_entity_public_collections(business_logic): + return RollbackEntity(business_logic, RollbackType.PUBLIC_COLLECTIONS) + + +@pytest.fixture +def rollback_entity_public_collection_list(business_logic): + return RollbackEntity(business_logic, RollbackType.PUBLIC_COLLECTION_LIST, list()) + + +@pytest.fixture +def rollback_entity_private_dataset_list(business_logic): + return RollbackEntity(business_logic, RollbackType.PRIVATE_DATASET_LIST, list()) + + +def initialize_unpublished_collection(rollback_entity, num_datasets): + version = rollback_entity.business_logic.database_provider.create_canonical_collection( + "owner", + "curator_name", + None, + ) + for _ in range(num_datasets): + dataset_version = rollback_entity.business_logic.database_provider.create_canonical_dataset( + version.version_id, + ) + rollback_entity.business_logic.database_provider.add_dataset_to_collection_version_mapping( + version.version_id, dataset_version.version_id + ) + return rollback_entity.business_logic.database_provider.get_collection_version_with_datasets(version.version_id) + + +def initialize_published_collection(rollback_entity, num_datasets): + version = initialize_unpublished_collection(rollback_entity, num_datasets=num_datasets) + + rollback_entity.business_logic.database_provider.finalize_collection_version( + version.collection_id, + version.version_id, + "5.1.0", + "1.0.0", + published_at=datetime.utcnow(), + ) + return rollback_entity.business_logic.database_provider.get_collection_version_with_datasets(version.version_id) + + +def create_and_publish_collection_revision(rollback_entity, collection_id, update_first_dataset_only=False): + new_collection_version = rollback_entity.business_logic.create_collection_version(collection_id) + + for dataset in new_collection_version.datasets: + rollback_entity.business_logic.create_empty_dataset_version_for_current_dataset( + new_collection_version.version_id, dataset.version_id + ) + if update_first_dataset_only: + break + + rollback_entity.business_logic.database_provider.finalize_collection_version( + collection_id, + new_collection_version.version_id, + "5.1.0", + "1.0.0", + published_at=datetime.utcnow(), + ) + return rollback_entity.business_logic.database_provider.get_collection_version_with_datasets( + new_collection_version.version_id + ) + + +# Tests + + +@pytest.mark.parametrize( + "rollback_args", + [ + ("rollback_entity_private_collections", "collections_private_rollback"), + ("rollback_entity_private_collection_list", "collection_list_private_rollback"), + ("rollback_entity_public_collections", "collections_public_rollback"), + ("rollback_entity_public_collection_list", "collection_list_public_rollback"), + ("rollback_entity_private_dataset_list", "dataset_list_private_rollback"), + ], +) +def test_rollback(request, rollback_args): + rollback_entity_name, rollback_function_name = rollback_args + 
rollback_entity = request.getfixturevalue(rollback_entity_name) + setattr(rollback_entity, rollback_function_name, Mock()) + rollback_entity.rollback() + assert getattr(rollback_entity, rollback_function_name).call_count == 1 + + +def test_rollback__unsupported_rollback_type(): + with pytest.raises(ValueError): + RollbackEntity(Mock(), "unsupported_rollback_type").rollback() + + +# TestPrivateDatasetRollback + + +@pytest.mark.parametrize("pass_arg_collection_version_id", [True, False]) +def test_rollback_private_dataset(rollback_entity_private_collections, pass_arg_collection_version_id): + business_logic = rollback_entity_private_collections.business_logic + + collection_version = initialize_unpublished_collection(rollback_entity_private_collections, num_datasets=1) + original_dataset_version = collection_version.datasets[0] + newer_dataset_version_id = business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, original_dataset_version.version_id + ).version_id + newest_dataset_version_id = business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, newer_dataset_version_id + ).version_id + + # Test with and without optional collection_version_id arg + collection_version_id = collection_version.version_id if pass_arg_collection_version_id else None + rolled_back_version = rollback_entity_private_collections.dataset_private_rollback( + newest_dataset_version_id, collection_version_id + ) + + # Assert returns expected rolled back version + assert rolled_back_version.version_id.id == newest_dataset_version_id.id + + # Assert DatasetVersion is rolled back to most recent previous version, not the "original" + restored_dataset_version = business_logic.get_collection_version(collection_version.version_id).datasets[0] + assert restored_dataset_version.version_id.id == newer_dataset_version_id.id + + +def test_rollback_private_dataset_list(rollback_entity_private_dataset_list): + business_logic = rollback_entity_private_dataset_list.business_logic + + collection_version = initialize_unpublished_collection(rollback_entity_private_dataset_list, num_datasets=3) + original_dataset_versions = collection_version.datasets + new_dataset_version_ids = [ + business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, original_dataset_version.version_id + ).version_id + for original_dataset_version in original_dataset_versions + ] + + # Rollback two of three new dataset versions + rollback_entity_private_dataset_list.dataset_list_private_rollback( + [new_dataset_version_ids[0], new_dataset_version_ids[1]] + ) + + post_rollback_dataset_version_ids = [ + dataset_version.version_id.id + for dataset_version in business_logic.get_collection_version(collection_version.version_id).datasets + ] + + # Assert collection version is pointing to the original dataset version IDs for the rolled back datasets + assert len(post_rollback_dataset_version_ids) == 3 + assert original_dataset_versions[0].version_id.id in post_rollback_dataset_version_ids + assert original_dataset_versions[1].version_id.id in post_rollback_dataset_version_ids + # Assert collection version is still pointing to newest dataset version for the non-rolled back dataset + assert new_dataset_version_ids[2].id in post_rollback_dataset_version_ids + + +# TestPrivateCollectionRollback + + +def test_rollback_private_collection(rollback_entity_private_collections): + business_logic = rollback_entity_private_collections.business_logic + + 
collection_version = initialize_unpublished_collection(rollback_entity_private_collections, num_datasets=2) + original_dataset_versions = collection_version.datasets + newer_dataset_version_ids = [ + business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, original_dataset_version.version_id + ).version_id + for original_dataset_version in original_dataset_versions + ] + # Create a third dataset version for each dataset to roll back from, to test we rollback to "newer", not original + for newer_dataset_version_id in newer_dataset_version_ids: + business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, newer_dataset_version_id + ) + + rollback_entity_private_collections.collection_private_rollback(collection_version.version_id) + + post_rollback_dataset_version_ids = [ + dataset_version.version_id.id + for dataset_version in business_logic.get_collection_version(collection_version.version_id).datasets + ] + + # Assert collection version is pointing to the previous most recent dataset version IDs for all datasets + assert len(post_rollback_dataset_version_ids) == 2 + assert newer_dataset_version_ids[0].id in post_rollback_dataset_version_ids + assert newer_dataset_version_ids[1].id in post_rollback_dataset_version_ids + + +def test_rollback_private_collections(rollback_entity_private_collections): + business_logic = rollback_entity_private_collections.business_logic + + original_collection_versions = [ + initialize_unpublished_collection(rollback_entity_private_collections, num_datasets=1) for _ in range(2) + ] + # Create a newer dataset version for each dataset to roll back from + for collection_version in original_collection_versions: + business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, collection_version.datasets[0].version_id + ) + # Create published collection + original_published_collection_version = initialize_published_collection( + rollback_entity_private_collections, num_datasets=1 + ) + new_published_collection_version = create_and_publish_collection_revision( + rollback_entity_private_collections, original_published_collection_version.collection_id + ) + + rollback_entity_private_collections.collections_private_rollback() + + # Assert unpublished collection versions datasets are all rolled back + for original_collection_version in original_collection_versions: + rolled_back_collection_version = business_logic.get_collection_version(original_collection_version.version_id) + assert ( + rolled_back_collection_version.datasets[0].version_id.id + == original_collection_version.datasets[0].version_id.id + ) + + # Assert published collection version is not rolled back + published_collection = business_logic.get_canonical_collection(new_published_collection_version.collection_id) + published_collection_version = business_logic.get_collection_version(published_collection.version_id) + assert ( + published_collection_version.datasets[0].version_id.id + == new_published_collection_version.datasets[0].version_id.id + ) + + +def test_rollback_private_collection_list(rollback_entity_private_collection_list): + business_logic = rollback_entity_private_collection_list.business_logic + + original_collection_versions = [ + initialize_unpublished_collection(rollback_entity_private_collection_list, num_datasets=1) for _ in range(3) + ] + new_dataset_versions = [ + business_logic.create_empty_dataset_version_for_current_dataset( + collection_version.version_id, 
collection_version.datasets[0].version_id + ).version_id.id + for collection_version in original_collection_versions + ] + + rollback_entity_private_collection_list.collection_list_private_rollback( + [ + original_collection_versions[0].version_id, + original_collection_versions[1].version_id, + ] + ) + + # Assert collection versions are pointing to the original dataset version IDs for the rolled back collections + for original_collection_version in original_collection_versions[:2]: + rolled_back_collection_version = business_logic.get_collection_version(original_collection_version.version_id) + assert ( + rolled_back_collection_version.datasets[0].version_id.id + == original_collection_version.datasets[0].version_id.id + ) + + # Assert collection version is still pointing to newest dataset version for the non-rolled back collection + assert ( + business_logic.get_collection_version(original_collection_versions[2].version_id).datasets[0].version_id.id + == new_dataset_versions[2] + ) + + +# TestPublishedCollectionRollback + + +def test_rollback_public_collection(rollback_entity_public_collections): + business_logic = rollback_entity_public_collections.business_logic + + original_collection_version = initialize_published_collection(rollback_entity_public_collections, num_datasets=1) + newer_collection_version = create_and_publish_collection_revision( + rollback_entity_public_collections, original_collection_version.collection_id + ) + newest_collection_version = create_and_publish_collection_revision( + rollback_entity_public_collections, newer_collection_version.collection_id + ) + + rollback_entity_public_collections.collection_public_rollback(newest_collection_version.collection_id) + + rolled_back_collection_version = business_logic.get_canonical_collection(newest_collection_version.collection_id) + + # Assert CollectionVersionId is rolled back to most recent previous version, not the "original" + assert rolled_back_collection_version.version_id.id == newer_collection_version.version_id.id + + +def test_rollback_published_collections(rollback_entity_public_collections): + business_logic = rollback_entity_public_collections.business_logic + + original_collection_versions = [ + initialize_published_collection(rollback_entity_public_collections, num_datasets=1) for _ in range(2) + ] + + for original_collection_version in original_collection_versions: + create_and_publish_collection_revision( + rollback_entity_public_collections, original_collection_version.collection_id + ) + + private_collection_version = initialize_unpublished_collection(rollback_entity_public_collections, num_datasets=1) + new_dataset_version = business_logic.create_empty_dataset_version_for_current_dataset( + private_collection_version.version_id, private_collection_version.datasets[0].version_id + ) + + rollback_entity_public_collections.collections_public_rollback() + + # Assert published collection versions are all rolled back + for original_collection_version in original_collection_versions: + rolled_back_collection = business_logic.get_canonical_collection(original_collection_version.collection_id) + assert rolled_back_collection.version_id.id == original_collection_version.version_id.id + + # Assert private collection version is not rolled back + private_collection = business_logic.get_collection_version(private_collection_version.version_id) + assert private_collection.datasets[0].version_id.id == new_dataset_version.version_id.id + + +def 
test_rollback_published_collection_list(rollback_entity_public_collection_list): + # init 2-3 public collections with 1 dataset each with prior versions. Rollback 1-2. Check those are rolled back, + # others are pointing to new collection version still + business_logic = rollback_entity_public_collection_list.business_logic + + original_collection_versions = [ + initialize_published_collection(rollback_entity_public_collection_list, num_datasets=1) for _ in range(3) + ] + + new_collection_versions = [ + create_and_publish_collection_revision( + rollback_entity_public_collection_list, original_collection_version.collection_id + ) + for original_collection_version in original_collection_versions + ] + + rollback_entity_public_collection_list.collection_list_public_rollback( + [ + original_collection_versions[0].collection_id, + original_collection_versions[1].collection_id, + ] + ) + + # Assert collection versions are pointing to the original dataset version IDs for the rolled back collections + for original_collection_version in original_collection_versions[:2]: + rolled_back_collection = business_logic.get_canonical_collection(original_collection_version.collection_id) + assert rolled_back_collection.version_id.id == original_collection_version.version_id.id + + # Assert collection is still pointing to newest collection version for the non-rolled back collection + assert ( + business_logic.get_canonical_collection(original_collection_versions[2].collection_id).version_id.id + == new_collection_versions[2].version_id.id + ) + + +# TestRollbackCleanUp + + +def test__clean_up(rollback_entity_public_collections): + business_logic = rollback_entity_public_collections.business_logic + + original_collection_version = initialize_published_collection(rollback_entity_public_collections, num_datasets=2) + + # publish 2 new collection versions for total version history of 3 + # only update 1 of 2 datasets + collection_revisions = [ + create_and_publish_collection_revision( + rollback_entity_public_collections, + original_collection_version.collection_id, + update_first_dataset_only=True, + ) + for _ in range(2) + ] + + rollback_entity_public_collections.collections_public_rollback() + + rolled_back_revision = collection_revisions[-1] + assert business_logic.get_collection_version(rolled_back_revision.version_id) is None + + revised_dataset = rolled_back_revision.datasets[0] + unrevised_dataset = rolled_back_revision.datasets[1] + assert business_logic.get_dataset_version(revised_dataset.version_id) is None + assert business_logic.get_dataset_version(unrevised_dataset.version_id) is not None From 05066e60b54453035f13109f4bb7184f2a5d0968 Mon Sep 17 00:00:00 2001 From: Nayib Gloria <55710092+nayib-jose-gloria@users.noreply.github.com> Date: Thu, 15 Aug 2024 15:19:02 -0400 Subject: [PATCH 2/4] feat: enforce canonical format of anndata for seurat conversion (#7326) --- backend/layers/processing/process_seurat.py | 9 +++ .../layers/processing/utils/matrix_utils.py | 13 +++ tests/unit/processing/test_matrix_utils.py | 79 ++++++++++++++----- 3 files changed, 82 insertions(+), 19 deletions(-) diff --git a/backend/layers/processing/process_seurat.py b/backend/layers/processing/process_seurat.py index 3156fdb8de652..d0a4fc7dddbf9 100644 --- a/backend/layers/processing/process_seurat.py +++ b/backend/layers/processing/process_seurat.py @@ -13,6 +13,7 @@ ) from backend.layers.processing.logger import logit from backend.layers.processing.process_logic import ProcessingLogic +from 
backend.layers.processing.utils.matrix_utils import enforce_canonical_format from backend.layers.processing.utils.rds_citation_from_h5ad import rds_citation_from_h5ad from backend.layers.thirdparty.s3_provider import S3ProviderInterface from backend.layers.thirdparty.uri_provider import UriProviderInterface @@ -74,6 +75,14 @@ def process(self, dataset_version_id: DatasetVersionId, artifact_bucket: str, da adata = anndata.read_h5ad(labeled_h5ad_filename) if "citation" in adata.uns: adata.uns["citation"] = rds_citation_from_h5ad(adata.uns["citation"]) + + # enforce for canonical + logger.info("enforce canonical format in X") + enforce_canonical_format(adata) + if adata.raw: + logger.info("enforce canonical format in raw.X") + enforce_canonical_format(adata.raw) + adata.write_h5ad(labeled_h5ad_filename) # Use Seurat to convert to RDS diff --git a/backend/layers/processing/utils/matrix_utils.py b/backend/layers/processing/utils/matrix_utils.py index c0ab4780ee3ca..0572c63230500 100644 --- a/backend/layers/processing/utils/matrix_utils.py +++ b/backend/layers/processing/utils/matrix_utils.py @@ -2,6 +2,8 @@ import numpy as np +logger: logging.Logger = logging.getLogger("matrix_utils") + def is_matrix_sparse(matrix: np.ndarray, sparse_threshold): """ @@ -57,3 +59,14 @@ def is_matrix_sparse(matrix: np.ndarray, sparse_threshold): is_sparse = (100.0 * number_of_non_zero_elements / total_number_of_matrix_elements) < sparse_threshold return is_sparse + + +def enforce_canonical_format(adata): + """ + Enforce canonical format for an AnnData, if not already in canonical format. This function will modify the + matrix in place. + """ + X = adata.X + if hasattr(X, "has_canonical_format") and not X.has_canonical_format: + logger.warning("noncanonical data found in X; converting to canonical format using sum_duplicates.") + X.sum_duplicates() diff --git a/tests/unit/processing/test_matrix_utils.py b/tests/unit/processing/test_matrix_utils.py index 63f9740244277..58dcc74a2b202 100644 --- a/tests/unit/processing/test_matrix_utils.py +++ b/tests/unit/processing/test_matrix_utils.py @@ -1,23 +1,30 @@ -import unittest +import logging +from unittest.mock import Mock import numpy as np +import pytest +from anndata import AnnData +from scipy.sparse import coo_matrix -from backend.layers.processing.utils.matrix_utils import is_matrix_sparse +from backend.layers.processing.utils.matrix_utils import enforce_canonical_format, is_matrix_sparse +LOGGER = logging.getLogger("matrix_utils") +LOGGER.propagate = True -class TestMatrixUtils(unittest.TestCase): + +class TestMatrixUtils: def test__is_matrix_sparse__zero_and_one_hundred_percent_threshold(self): matrix = np.array([1, 2, 3]) - self.assertFalse(is_matrix_sparse(matrix, 0)) - self.assertTrue(is_matrix_sparse(matrix, 100)) + assert not is_matrix_sparse(matrix, 0) + assert is_matrix_sparse(matrix, 100) def test__is_matrix_sparse__partially_populated_sparse_matrix_returns_true(self): matrix = np.zeros([3, 4]) matrix[2][3] = 1.0 matrix[1][1] = 2.2 - self.assertTrue(is_matrix_sparse(matrix, 50)) + assert is_matrix_sparse(matrix, 50) def test__is_matrix_sparse__partially_populated_dense_matrix_returns_false(self): matrix = np.zeros([2, 2]) @@ -25,24 +32,58 @@ def test__is_matrix_sparse__partially_populated_dense_matrix_returns_false(self) matrix[0][1] = 2.2 matrix[1][1] = 3.7 - self.assertFalse(is_matrix_sparse(matrix, 50)) + assert not is_matrix_sparse(matrix, 50) - def test__is_matrix_sparse__giant_matrix_returns_false_early(self): + def 
test__is_matrix_sparse__giant_matrix_returns_false_early(self, caplog): + caplog.set_level(logging.INFO) matrix = np.ones([20000, 20]) - with self.assertLogs(level="INFO") as logger: - self.assertFalse(is_matrix_sparse(matrix, 1)) + assert not is_matrix_sparse(matrix, 1) - # Because the function returns early a log will output the _estimate_ instead of the _exact_ percentage of - # non-zero elements in the matrix. - self.assertIn("Percentage of non-zero elements (estimate)", logger.output[0]) + # Because the function returns early a log will output the _estimate_ instead of the _exact_ percentage of + # non-zero elements in the matrix. + assert "Percentage of non-zero elements (estimate)" in caplog.text - def test__is_matrix_sparse_with_column_shift_encoding__giant_matrix_returns_false_early(self): + def test__is_matrix_sparse_with_column_shift_encoding__giant_matrix_returns_false_early(self, caplog): + caplog.set_level(logging.INFO) matrix = np.random.rand(20000, 20) - with self.assertLogs(level="INFO") as logger: - self.assertFalse(is_matrix_sparse(matrix, 1)) + assert not is_matrix_sparse(matrix, 1) + + # Because the function returns early a log will output the _estimate_ instead of the _exact_ percentage of + # non-zero elements in the matrix. + assert "Percentage of non-zero elements (estimate)" in caplog.text + + +@pytest.fixture +def noncanonical_matrix(): + array = np.array([[1, 0, 1], [3, 2, 3], [4, 5, 4]]) + return coo_matrix((array[0], (array[1], array[2]))) + + +@pytest.fixture +def canonical_adata(): + return Mock(X=Mock(has_canonical_format=True)) + + +class TestEnforceCanonical: + def test_adata_with_noncanonical_X_and_raw_X(self, noncanonical_matrix, caplog): + assert noncanonical_matrix.has_canonical_format is False + adata = AnnData(noncanonical_matrix) + enforce_canonical_format(adata) + assert adata.X.has_canonical_format is True + assert "noncanonical data found in X; converting to canonical format using sum_duplicates." in caplog.text + + def test_adata_with_noncanonical_raw_X(self, noncanonical_matrix, caplog): + caplog.set_level(logging.WARNING) + assert noncanonical_matrix.has_canonical_format is False + adata = AnnData(raw=AnnData(noncanonical_matrix)) + enforce_canonical_format(adata.raw) + assert adata.raw.X.has_canonical_format is True + assert "noncanonical data found in X; converting to canonical format using sum_duplicates." in caplog.text - # Because the function returns early a log will output the _estimate_ instead of the _exact_ percentage of - # non-zero elements in the matrix. - self.assertIn("Percentage of non-zero elements (estimate)", logger.output[0]) + def test_adata_with_canonical_X(self, canonical_adata, caplog): + caplog.set_level(logging.WARNING) + enforce_canonical_format(canonical_adata) + assert canonical_adata.X.has_canonical_format is True + assert "noncanonical data found in X; converting to canonical format using sum_duplicates." 
not in caplog.text From db2cc478874c567bb129502006d41b43f2d1bed4 Mon Sep 17 00:00:00 2001 From: Nayib Gloria <55710092+nayib-jose-gloria@users.noreply.github.com> Date: Wed, 21 Aug 2024 20:26:33 -0400 Subject: [PATCH 3/4] chore: bump dependencies (#7330) --- python_dependencies/processing/requirements.txt | 6 +++--- tests/unit/processing/test_h5ad_data_file.py | 2 +- tests/unit/processing/test_type_conversion_utils.py | 9 ++++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python_dependencies/processing/requirements.txt b/python_dependencies/processing/requirements.txt index 972c9add3f136..3001fe45b9843 100644 --- a/python_dependencies/processing/requirements.txt +++ b/python_dependencies/processing/requirements.txt @@ -1,4 +1,4 @@ -anndata==0.8.0 +anndata==0.10.8 awscli boto3>=1.11.17 botocore>=1.14.17 @@ -6,8 +6,8 @@ cellxgene-schema dataclasses-json ddtrace==2.1.4 numba==0.59.1 -numpy==1.26.4 -pandas==1.4.4 +numpy<2 +pandas>2,<3 psutil>=5.9.0 psycopg2-binary==2.* pyarrow>=1.0 diff --git a/tests/unit/processing/test_h5ad_data_file.py b/tests/unit/processing/test_h5ad_data_file.py index 7fbeb0e7bcceb..82f7dfae89e07 100644 --- a/tests/unit/processing/test_h5ad_data_file.py +++ b/tests/unit/processing/test_h5ad_data_file.py @@ -304,7 +304,7 @@ def _write_anndata_to_file(self, anndata): def _create_sample_anndata_dataset(self): # Create X - X = np.random.rand(3, 4) + X = np.random.rand(3, 4).astype(np.float32) # Create obs random_string_category = Series(data=["a", "b", "b"], dtype="category") diff --git a/tests/unit/processing/test_type_conversion_utils.py b/tests/unit/processing/test_type_conversion_utils.py index 37f0397fc7ab0..d28185639209f 100644 --- a/tests/unit/processing/test_type_conversion_utils.py +++ b/tests/unit/processing/test_type_conversion_utils.py @@ -128,7 +128,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): "expected_schema_hint": {"type": "float32"}, "logs": None if data.dtype != np.float64 else {"level": logging.WARNING, "output": "may lose precision"}, } - for dtype in [np.float16, np.float32, np.float64] + for dtype in [np.float32, np.float64] for data in [ np.arange(-128, 1000, dtype=dtype), pd.Series(np.arange(-128, 1000, dtype=dtype)), @@ -201,9 +201,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): "data": data, "expected_encoding_dtype": np.float32, "expected_schema_hint": {"type": "categorical"}, - "logs": {"level": logging.WARNING, "output": "may lose precision"}, + "logs": None if dtype == np.float32 else {"level": logging.WARNING, "output": "may lose precision"}, } - for dtype in [np.float16, np.float32, np.float64] + for dtype in [np.float32, np.float64] for data in [ pd.Series(np.array([0, 1, 2], dtype=dtype), dtype="category"), pd.Series(np.array([0, 1, 2], dtype=dtype), dtype="category").cat.remove_categories([1]), @@ -216,7 +216,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): "data": data, "expected_encoding_dtype": np.float32, "expected_schema_hint": {"type": "categorical"}, - "logs": {"level": logging.WARNING, "output": "may lose precision"}, + "logs": None if dtype == np.float32 else {"level": logging.WARNING, "output": "may lose precision"}, } for dtype in [ np.int8, @@ -227,7 +227,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): np.uint32, np.int64, np.uint64, - np.float16, np.float32, np.float64, ] From 0442ec7015d02f1351a3c27ae35a1b065b33a04d Mon Sep 17 00:00:00 2001 From: Ronen Date: Mon, 26 Aug 2024 11:45:14 -0400 Subject: [PATCH 4/4] fix: top DE gene result for e2e test (#7332) --- 
 .../differentialExpression.test.ts            | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/frontend/tests/features/differentialExpression/differentialExpression.test.ts b/frontend/tests/features/differentialExpression/differentialExpression.test.ts
index ad8b00cc52a64..9f64b664cd8ab 100644
--- a/frontend/tests/features/differentialExpression/differentialExpression.test.ts
+++ b/frontend/tests/features/differentialExpression/differentialExpression.test.ts
@@ -586,7 +586,23 @@ describe("Differential Expression", () => {
       ]);

       const newPageUrl2 = newPage2.url();
-      expect(newPageUrl2).toContain("SAA1");
+
+      const effectSizeSort = page.getByTestId(
+        DIFFERENTIAL_EXPRESSION_SORT_DIRECTION
+      );
+
+      // Sort by top negative effect size
+      await effectSizeSort.click();
+
+      const topGene = await page
+        .getByTestId("differential-expression-results-table")
+        .locator("tr:first-child td:first-child")
+        .textContent();
+
+      // Reset to sort by top positive effect size
+      await effectSizeSort.click();
+
+      expect(newPageUrl2).toContain(topGene);
       expect(newPageUrl2).toContain("tissues=UBERON%3A0002048");
       expect(newPageUrl2).toContain("cellTypes=acinar+cell");
       expect(newPageUrl2).toContain("ver=2");