From dca45f532be4d5c017ff09dcd4f66f72a94c19b8 Mon Sep 17 00:00:00 2001 From: Kori Kuzma Date: Thu, 5 Oct 2023 14:31:21 -0400 Subject: [PATCH] build: update ga4gh.vrsatile.pydantic version (#281) - Upgrades to pydantic v2 --- Pipfile | 2 +- docs/scripts/generate_normalize_figure.py | 2 +- pyproject.toml | 2 +- src/gene/database/dynamodb.py | 2 +- .../etl/vrs_locations/chromosome_location.py | 2 +- .../etl/vrs_locations/sequence_location.py | 2 +- src/gene/main.py | 3 + src/gene/query.py | 4 +- src/gene/schemas.py | 191 ++++++------------ src/gene/version.py | 2 +- tests/unit/test_ensembl_source.py | 2 +- tests/unit/test_hgnc_source.py | 2 +- tests/unit/test_ncbi_source.py | 2 +- tests/unit/test_query.py | 80 ++++---- 14 files changed, 115 insertions(+), 183 deletions(-) diff --git a/Pipfile b/Pipfile index e2c57184..28ff142f 100644 --- a/Pipfile +++ b/Pipfile @@ -9,7 +9,7 @@ fastapi = "*" uvicorn = "*" click = "*" boto3 = "*" -"ga4gh.vrsatile.pydantic" = "~=0.0.12" +"ga4gh.vrsatile.pydantic" = "~=0.2.0" "ga4gh.vrs" = "~=0.8.1" [dev-packages] diff --git a/docs/scripts/generate_normalize_figure.py b/docs/scripts/generate_normalize_figure.py index abc71ede..1a39a085 100644 --- a/docs/scripts/generate_normalize_figure.py +++ b/docs/scripts/generate_normalize_figure.py @@ -49,7 +49,7 @@ def create_gjgf(result: UnmergedNormalizationService) -> Dict: "metadata": { "color": COLORS[i], "hover": f"{match.concept_id}\n{match.symbol}\n{match.label}", # noqa: E501 - "click": f"

{json.dumps(match.dict(), indent=2)}

", # noqa: E501 + "click": f"

{json.dumps(match.model_dump(), indent=2)}

", # noqa: E501 } } for xref in match.xrefs: diff --git a/pyproject.toml b/pyproject.toml index 36ec0837..72b3730f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "uvicorn", "click", "boto3", - "ga4gh.vrsatile.pydantic~=0.0.12", + "ga4gh.vrsatile.pydantic~=0.2.0", "ga4gh.vrs~=0.8.1" ] dynamic = ["version"] diff --git a/src/gene/database/dynamodb.py b/src/gene/database/dynamodb.py index 4d2ae889..5df9e0d0 100644 --- a/src/gene/database/dynamodb.py +++ b/src/gene/database/dynamodb.py @@ -356,7 +356,7 @@ def add_source_metadata(self, src_name: SourceName, metadata: SourceMeta) -> Non :raise DatabaseWriteException: if write fails """ src_name_value = src_name.value - metadata_item = metadata.dict() + metadata_item = metadata.model_dump() metadata_item["src_name"] = src_name_value metadata_item["label_and_type"] = f"{str(src_name_value).lower()}##source" metadata_item["concept_id"] = f"source:{str(src_name_value).lower()}" diff --git a/src/gene/etl/vrs_locations/chromosome_location.py b/src/gene/etl/vrs_locations/chromosome_location.py index cd6d1409..464f87f3 100644 --- a/src/gene/etl/vrs_locations/chromosome_location.py +++ b/src/gene/etl/vrs_locations/chromosome_location.py @@ -32,7 +32,7 @@ def get_location(self, location: Dict, gene: Dict) -> Optional[Dict]: try: chr_location = GeneChromosomeLocation( chr=location["chr"], start=location["start"], end=location["end"] - ).dict() + ).model_dump() except ValidationError as e: logger.info(f"{e} for {gene['symbol']}") else: diff --git a/src/gene/etl/vrs_locations/sequence_location.py b/src/gene/etl/vrs_locations/sequence_location.py index b47eba7e..5546f9b7 100644 --- a/src/gene/etl/vrs_locations/sequence_location.py +++ b/src/gene/etl/vrs_locations/sequence_location.py @@ -52,7 +52,7 @@ def add_location( start=gene.start - 1, # type: ignore end=gene.end, # type: ignore sequence_id=sequence_id, - ).dict() # type: ignore + ).model_dump() # type: ignore else: logger.info( f"{params['concept_id']} has invalid interval:" diff --git a/src/gene/main.py b/src/gene/main.py index 67d6e013..02bb7215 100644 --- a/src/gene/main.py +++ b/src/gene/main.py @@ -61,6 +61,7 @@ summary=read_query_summary, response_description=response_description, response_model=SearchService, + response_model_exclude_none=True, description=search_description, tags=["Query"], ) @@ -99,6 +100,7 @@ def search( summary=normalize_summary, response_description=normalize_response_descr, response_model=NormalizeService, + response_model_exclude_none=True, description=normalize_descr, tags=["Query"], ) @@ -134,6 +136,7 @@ def normalize(q: str = Query(..., description=normalize_q_descr)) -> NormalizeSe operation_id="getUnmergedRecords", response_description=unmerged_response_descr, response_model=UnmergedNormalizationService, + response_model_exclude_none=True, description=unmerged_normalize_description, tags=["Query"], ) diff --git a/src/gene/query.py b/src/gene/query.py index 75aeb2ee..6e85fb5d 100644 --- a/src/gene/query.py +++ b/src/gene/query.py @@ -24,6 +24,7 @@ RefType, SearchService, ServiceMeta, + SourceMeta, SourceName, SourcePriority, UnmergedNormalizationService, @@ -370,7 +371,8 @@ def _add_merged_meta(self, response: NormalizeService) -> NormalizeService: prefix = concept_id.split(":")[0] src_name = PREFIX_LOOKUP[prefix.lower()] if src_name not in sources_meta: - sources_meta[src_name] = self.db.get_source_metadata(src_name) + _sources_meta = self.db.get_source_metadata(src_name) + sources_meta[SourceName(src_name)] = SourceMeta(**_sources_meta) response.source_meta_ = sources_meta return response diff --git a/src/gene/schemas.py b/src/gene/schemas.py index 032fda0b..15d5e548 100644 --- a/src/gene/schemas.py +++ b/src/gene/schemas.py @@ -1,6 +1,6 @@ """Contains data models for representing VICC normalized gene records.""" from enum import Enum, IntEnum -from typing import Any, Dict, List, Literal, Optional, Type, Union +from typing import Dict, List, Literal, Optional, Union from ga4gh.vrsatile.pydantic import return_value from ga4gh.vrsatile.pydantic.vrs_models import ( @@ -10,8 +10,14 @@ VRSTypes, ) from ga4gh.vrsatile.pydantic.vrsatile_models import GeneDescriptor -from pydantic import BaseModel, StrictBool, validator -from pydantic.types import StrictInt, StrictStr +from pydantic import ( + BaseModel, + ConfigDict, + StrictBool, + StrictInt, + StrictStr, + field_validator, +) class SymbolStatus(str, Enum): @@ -85,9 +91,9 @@ class BaseGene(BaseModel): concept_id: CURIE symbol: StrictStr - symbol_status: Optional[SymbolStatus] - label: Optional[StrictStr] - strand: Optional[Strand] + symbol_status: Optional[SymbolStatus] = None + label: Optional[StrictStr] = None + strand: Optional[Strand] = None location_annotations: Optional[List[StrictStr]] = [] locations: Optional[ Union[ @@ -99,13 +105,11 @@ class BaseGene(BaseModel): previous_symbols: Optional[List[StrictStr]] = [] xrefs: Optional[List[CURIE]] = [] associated_with: Optional[List[CURIE]] = [] - gene_type: Optional[StrictStr] + gene_type: Optional[StrictStr] = None - _get_concept_id_val = validator("concept_id", allow_reuse=True)(return_value) - _get_xrefs_val = validator("xrefs", allow_reuse=True)(return_value) - _get_associated_with_val = validator("associated_with", allow_reuse=True)( - return_value - ) + _get_concept_id_val = field_validator("concept_id")(return_value) + _get_xrefs_val = field_validator("xrefs")(return_value) + _get_associated_with_val = field_validator("associated_with")(return_value) class Gene(BaseGene): @@ -113,19 +117,9 @@ class Gene(BaseGene): match_type: MatchType - class Config: - """Configure model example""" - - use_enum_values = True - - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type["Gene"]) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for p in schema.get("properties", {}).values(): - p.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "label": None, "concept_id": "ensembl:ENSG00000157764", "symbol": "BRAF", @@ -136,6 +130,8 @@ def schema_extra(schema: Dict[str, Any], model: Type["Gene"]) -> None: "strand": "-", "location": [], } + } + ) class GeneGroup(Gene): @@ -236,23 +232,13 @@ class SourceMeta(BaseModel): data_license_url: StrictStr version: StrictStr data_url: Dict[str, str] - rdp_url: Optional[StrictStr] + rdp_url: Optional[StrictStr] = None data_license_attributes: Dict[StrictStr, StrictBool] - genome_assemblies: Optional[List[StrictStr]] - - class Config: - """Configure model example""" + genome_assemblies: Optional[List[StrictStr]] = None - use_enum_values = True - - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type["SourceMeta"]) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "data_license": "custom", "data_license_url": "https://www.ncbi.nlm.nih.gov/home/about/policies/", # noqa: E501 "version": "20201215", @@ -269,6 +255,8 @@ def schema_extra(schema: Dict[str, Any], model: Type["SourceMeta"]) -> None: }, "genome_assemblies": None, } + } + ) class SourceSearchMatches(BaseModel): @@ -277,21 +265,9 @@ class SourceSearchMatches(BaseModel): records: List[Gene] source_meta_: SourceMeta - class Config: - """Configure model example""" - - use_enum_values = True - - @staticmethod - def schema_extra( - schema: Dict[str, Any], model: Type["SourceSearchMatches"] - ) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "NCBI": { "match_type": 0, "records": [], @@ -314,57 +290,43 @@ def schema_extra( }, } } + } + ) class ServiceMeta(BaseModel): """Metadata regarding the gene-normalization service.""" - name = "gene-normalizer" + name: Literal["gene-normalizer"] = "gene-normalizer" version: StrictStr response_datetime: StrictStr - url = "https://github.com/cancervariants/gene-normalization" - - class Config: - """Configure model example""" - - use_enum_values = True + url: Literal[ + "https://github.com/cancervariants/gene-normalization" + ] = "https://github.com/cancervariants/gene-normalization" - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type["ServiceMeta"]) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "name": "gene-normalizer", "version": "0.1.0", "response_datetime": "2022-03-23 15:57:14.180908", "url": "https://github.com/cancervariants/gene-normalization", } + } + ) class SearchService(BaseModel): """Define model for returning highest match typed concepts from sources.""" query: StrictStr - warnings: Optional[List[Dict]] + warnings: Optional[List[Dict]] = None source_matches: Dict[SourceName, SourceSearchMatches] service_meta_: ServiceMeta - class Config: - """Configure model example""" - - use_enum_values = True - - @staticmethod - def schema_extra(schema: Dict[str, Any], model: Type["SearchService"]) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "query": "NCBIgene:293", "warnings": [], "source_matches": { @@ -420,6 +382,8 @@ def schema_extra(schema: Dict[str, Any], model: Type["SearchService"]) -> None: "url": "https://github.com/cancervariants/gene-normalization", }, } + } + ) class GeneTypeFieldName(str, Enum): @@ -436,7 +400,7 @@ class BaseNormalizationService(BaseModel): """Base method providing shared attributes to Normalization service classes.""" query: StrictStr - warnings: Optional[List[Dict]] + warnings: Optional[List[Dict]] = None match_type: MatchType service_meta_: ServiceMeta @@ -444,24 +408,12 @@ class BaseNormalizationService(BaseModel): class NormalizeService(BaseNormalizationService): """Define model for returning normalized concept.""" - gene_descriptor: Optional[GeneDescriptor] - source_meta_: Optional[Dict[SourceName, SourceMeta]] - - class Config: - """Configure model example""" + gene_descriptor: Optional[GeneDescriptor] = None + source_meta_: Optional[Dict[SourceName, SourceMeta]] = None - use_enum_values = True - - @staticmethod - def schema_extra( - schema: Dict[str, Any], model: Type["NormalizeService"] - ) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "query": "BRAF", "warnings": [], "match_type": 100, @@ -575,6 +527,8 @@ def schema_extra( "url": "https://github.com/cancervariants/gene-normalization", # noqa: E501 }, } + } + ) class MatchesNormalized(BaseModel): @@ -583,19 +537,6 @@ class MatchesNormalized(BaseModel): records: List[BaseGene] source_meta_: SourceMeta - class Config: - """Configure OpenAPI schema""" - - @staticmethod - def schema_extra( - schema: Dict[str, Any], model: Type["MatchesNormalized"] - ) -> None: - """Configure OpenAPI schema""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - class UnmergedNormalizationService(BaseNormalizationService): """Response providing source records corresponding to normalization of user query. @@ -603,22 +544,12 @@ class UnmergedNormalizationService(BaseNormalizationService): attributes. """ - normalized_concept_id: Optional[CURIE] + normalized_concept_id: Optional[CURIE] = None source_matches: Dict[SourceName, MatchesNormalized] - class Config: - """Configure OpenAPI schema""" - - @staticmethod - def schema_extra( - schema: Dict[str, Any], model: Type["UnmergedNormalizationService"] - ) -> None: - """Configure OpenAPI schema example""" - if "title" in schema.keys(): - schema.pop("title", None) - for prop in schema.get("properties", {}).values(): - prop.pop("title", None) - schema["example"] = { + model_config = ConfigDict( + json_schema_extra={ + "example": { "query": "hgnc:108", "warnings": [], "match_type": 100, @@ -802,3 +733,5 @@ def schema_extra( }, }, } + } + ) diff --git a/src/gene/version.py b/src/gene/version.py index bc1f5ed2..4b01c103 100644 --- a/src/gene/version.py +++ b/src/gene/version.py @@ -1,2 +1,2 @@ """Gene normalizer version""" -__version__ = "0.1.40-dev0" +__version__ = "0.1.40-dev1" diff --git a/tests/unit/test_ensembl_source.py b/tests/unit/test_ensembl_source.py index 0f81d646..9e40fda2 100644 --- a/tests/unit/test_ensembl_source.py +++ b/tests/unit/test_ensembl_source.py @@ -16,7 +16,7 @@ def __init__(self): def search(self, query_str, incl="ensembl"): resp = self.query_handler.search(query_str, incl=incl) - return resp.source_matches[SourceName.ENSEMBL.value] + return resp.source_matches[SourceName.ENSEMBL] e = QueryGetter() return e diff --git a/tests/unit/test_hgnc_source.py b/tests/unit/test_hgnc_source.py index dd5af8b5..dc445d57 100644 --- a/tests/unit/test_hgnc_source.py +++ b/tests/unit/test_hgnc_source.py @@ -18,7 +18,7 @@ def __init__(self): def search(self, query_str, incl="hgnc"): resp = self.query_handler.search(query_str, incl=incl) - return resp.source_matches[SourceName.HGNC.value] + return resp.source_matches[SourceName.HGNC] h = QueryGetter() return h diff --git a/tests/unit/test_ncbi_source.py b/tests/unit/test_ncbi_source.py index f5cfa738..7f1dc9dc 100644 --- a/tests/unit/test_ncbi_source.py +++ b/tests/unit/test_ncbi_source.py @@ -22,7 +22,7 @@ def __init__(self): def search(self, query_str, incl="ncbi"): resp = self.query_handler.search(query_str, incl=incl) - return resp.source_matches[SourceName.NCBI.value] + return resp.source_matches[SourceName.NCBI] n = QueryGetter() return n diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 5428fe0c..17c25356 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -875,7 +875,7 @@ def num_sources(): @pytest.fixture(scope="module") def source_meta(): """Create test fixture for source meta""" - return [SourceName.HGNC.value, SourceName.ENSEMBL.value, SourceName.NCBI.value] + return [SourceName.HGNC, SourceName.ENSEMBL, SourceName.NCBI] def compare_warnings(actual_warnings, expected_warnings): @@ -976,8 +976,8 @@ def compare_gene_descriptor(test, actual): if actual.xrefs or test.xrefs: assert set(actual.xrefs) == set(test.xrefs), "xrefs" assert set(actual.alternate_labels) == set(test.alternate_labels), "alt labels" - extensions_present = "extensions" in test.__fields__.keys() - assert ("extensions" in actual.__fields__.keys()) == extensions_present + extensions_present = "extensions" in test.model_fields.keys() + assert ("extensions" in actual.model_fields.keys()) == extensions_present if extensions_present: assert len(actual.extensions) == len(test.extensions), "len of extensions" n_ext_correct = 0 @@ -1022,15 +1022,15 @@ def test_search_query_inc_exc(query_handler, num_sources): resp = query_handler.search("BRAF", incl=sources) matches = resp.source_matches assert len(matches) == len(sources.split()) - assert SourceName.HGNC.value in matches - assert SourceName.NCBI.value in matches + assert SourceName.HGNC in matches + assert SourceName.NCBI in matches sources = "HGnC" resp = query_handler.search("BRAF", excl=sources) matches = resp.source_matches assert len(matches) == num_sources - len(sources.split()) - assert SourceName.ENSEMBL.value in matches - assert SourceName.NCBI.value in matches + assert SourceName.ENSEMBL in matches + assert SourceName.NCBI in matches def test_search_invalid_parameter_exception(query_handler): @@ -1048,25 +1048,23 @@ def test_ache_query(query_handler, num_sources, normalized_ache, source_meta): resp = query_handler.search("ncbigene:43") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.XREF - assert len(matches[SourceName.ENSEMBL.value].records) == 0 - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.HGNC].records[0].match_type == MatchType.XREF + assert len(matches[SourceName.ENSEMBL].records) == 0 + assert matches[SourceName.NCBI].records[0].match_type == MatchType.CONCEPT_ID resp = query_handler.search("hgnc:108") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.CONCEPT_ID - assert matches[SourceName.ENSEMBL.value].records[0].match_type == MatchType.XREF - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.XREF + assert matches[SourceName.HGNC].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.ENSEMBL].records[0].match_type == MatchType.XREF + assert matches[SourceName.NCBI].records[0].match_type == MatchType.XREF resp = query_handler.search("ensembl:ENSG00000087085") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.XREF - assert ( - matches[SourceName.ENSEMBL.value].records[0].match_type == MatchType.CONCEPT_ID - ) - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.XREF + assert matches[SourceName.HGNC].records[0].match_type == MatchType.XREF + assert matches[SourceName.ENSEMBL].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.NCBI].records[0].match_type == MatchType.XREF # Normalize q = "ACHE" @@ -1170,25 +1168,23 @@ def test_braf_query(query_handler, num_sources, normalized_braf, source_meta): resp = query_handler.search("ncbigene:673") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.XREF - assert len(matches[SourceName.ENSEMBL.value].records) == 0 - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.HGNC].records[0].match_type == MatchType.XREF + assert len(matches[SourceName.ENSEMBL].records) == 0 + assert matches[SourceName.NCBI].records[0].match_type == MatchType.CONCEPT_ID resp = query_handler.search("hgnc:1097") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.CONCEPT_ID - assert matches[SourceName.ENSEMBL.value].records[0].match_type == MatchType.XREF - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.XREF + assert matches[SourceName.HGNC].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.ENSEMBL].records[0].match_type == MatchType.XREF + assert matches[SourceName.NCBI].records[0].match_type == MatchType.XREF resp = query_handler.search("ensembl:ENSG00000157764") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.XREF - assert ( - matches[SourceName.ENSEMBL.value].records[0].match_type == MatchType.CONCEPT_ID - ) - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.XREF + assert matches[SourceName.HGNC].records[0].match_type == MatchType.XREF + assert matches[SourceName.ENSEMBL].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.NCBI].records[0].match_type == MatchType.XREF # Normalize q = "BRAF" @@ -1270,25 +1266,23 @@ def test_abl1_query(query_handler, num_sources, normalized_abl1, source_meta): resp = query_handler.search("ncbigene:25") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.XREF - assert len(matches[SourceName.ENSEMBL.value].records) == 0 - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.HGNC].records[0].match_type == MatchType.XREF + assert len(matches[SourceName.ENSEMBL].records) == 0 + assert matches[SourceName.NCBI].records[0].match_type == MatchType.CONCEPT_ID resp = query_handler.search("hgnc:76") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.CONCEPT_ID - assert matches[SourceName.ENSEMBL.value].records[0].match_type == MatchType.XREF - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.XREF + assert matches[SourceName.HGNC].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.ENSEMBL].records[0].match_type == MatchType.XREF + assert matches[SourceName.NCBI].records[0].match_type == MatchType.XREF resp = query_handler.search("ensembl:ENSG00000097007") matches = resp.source_matches assert len(matches) == num_sources - assert matches[SourceName.HGNC.value].records[0].match_type == MatchType.XREF - assert ( - matches[SourceName.ENSEMBL.value].records[0].match_type == MatchType.CONCEPT_ID - ) - assert matches[SourceName.NCBI.value].records[0].match_type == MatchType.XREF + assert matches[SourceName.HGNC].records[0].match_type == MatchType.XREF + assert matches[SourceName.ENSEMBL].records[0].match_type == MatchType.CONCEPT_ID + assert matches[SourceName.NCBI].records[0].match_type == MatchType.XREF # Normalize q = "ABL1" @@ -1426,7 +1420,7 @@ def test_normalize_single_entry(query_handler, normalized_loc_653303): q, MatchType.SYMBOL, normalized_loc_653303, - expected_source_meta=[SourceName.NCBI.value], + expected_source_meta=[SourceName.NCBI], ) @@ -1441,7 +1435,7 @@ def test_normalize_no_locations(query_handler, normalized_ifnr): q, MatchType.SYMBOL, normalized_ifnr, - expected_source_meta=[SourceName.HGNC.value, SourceName.NCBI.value], + expected_source_meta=[SourceName.HGNC, SourceName.NCBI], ) @@ -1534,7 +1528,7 @@ def test_normalize_unmerged( def test_invalid_queries(query_handler): """Test invalid queries""" resp = query_handler.normalize("B R A F") - assert resp.match_type is MatchType.NO_MATCH.value + assert resp.match_type is MatchType.NO_MATCH with pytest.raises(TypeError): resp["match_type"]