From a913c4707337e0f32fdb51f644043d75488288d9 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 11:50:24 -0800 Subject: [PATCH 01/34] Remove get_data_by_id --- mp_api/client/core/client.py | 85 ++++-------------------------------- 1 file changed, 9 insertions(+), 76 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 8ff9e3758..2151fd69a 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -128,7 +128,9 @@ def __init__( self._s3_resource = None self.document_model = ( - api_sanitize(self.document_model) if self.document_model is not None else None # type: ignore + api_sanitize(self.document_model) + if self.document_model is not None + else None # type: ignore ) @property @@ -237,7 +239,9 @@ def _post_resource( if isinstance(data["data"], dict): data["data"] = self.document_model.parse_obj(data["data"]) # type: ignore elif isinstance(data["data"], list): - data["data"] = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore + data["data"] = [ + self.document_model.parse_obj(d) for d in data["data"] + ] # type: ignore return data @@ -307,7 +311,9 @@ def _patch_resource( if isinstance(data["data"], dict): data["data"] = self.document_model.parse_obj(data["data"]) # type: ignore elif isinstance(data["data"], list): - data["data"] = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore + data["data"] = [ + self.document_model.parse_obj(d) for d in data["data"] + ] # type: ignore return data @@ -967,79 +973,6 @@ def _query_resource_data( num_chunks=1, ).get("data") - def get_data_by_id( - self, - document_id: str, - fields: list[str] | None = None, - ) -> T: - """Query the endpoint for a single document. - - Arguments: - document_id: the unique key for this kind of document, typically a task_id - fields: list of fields to return, by default will return all fields - - Returns: - A single document. - """ - if document_id is None: - raise ValueError( - "Please supply a specific ID. You can use the query method to find " - "ids of interest." - ) - - if self.primary_key in ["material_id", "task_id"]: - validate_ids([document_id]) - - if fields is None: - criteria = {"_all_fields": True, "_limit": 1} # type: dict - else: - criteria = {"_limit": 1} - - if isinstance(fields, str): # pragma: no cover - fields = (fields,) - - results = [] # type: List - - try: - results = self._query_resource_data(criteria=criteria, fields=fields, suburl=document_id) # type: ignore - except MPRestError: - if self.primary_key == "material_id": - # see if the material_id has changed, perhaps a task_id was supplied - # this should likely be re-thought - from mp_api.client.routes.materials.materials import MaterialsRester - - with MaterialsRester( - api_key=self.api_key, - endpoint=self.base_endpoint, - use_document_model=False, - monty_decode=False, - session=self.session, - headers=self.headers, - ) as mpr: - docs = mpr.search(task_ids=[document_id], fields=["material_id"]) - - if len(docs) > 0: - new_document_id = docs[0].get("material_id", None) - - if new_document_id is not None: - warnings.warn( - f"Document primary key has changed from {document_id} to {new_document_id}, " - f"returning data for {new_document_id} in {self.suffix} route. " - ) - - results = self._query_resource_data( - criteria=criteria, fields=fields, suburl=new_document_id # type: ignore - ) - - if not results: - raise MPRestError(f"No result for record {document_id}.") - elif len(results) > 1: # pragma: no cover - raise ValueError( - f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug." - ) - else: - return results[0] - def _search( self, num_chunks: int | None = None, From 8f3d95df9120945cf98c72ec6da6f8185175baee Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 11:50:32 -0800 Subject: [PATCH 02/34] Add search method for DOI rester --- mp_api/client/routes/materials/doi.py | 37 ++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index 2be082d96..b3742c6db 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -1,18 +1,43 @@ from __future__ import annotations +from collections import defaultdict + from emmet.core.dois import DOIDoc from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids class DOIRester(BaseRester[DOIDoc]): suffix = "doi" document_model = DOIDoc # type: ignore - primary_key = "task_id" + primary_key = "material_id" + + def search( + self, + material_ids: str | list[str] | None = None, + num_chunks: int | None = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: list[str] | None = None, + ): + query_params = defaultdict(dict) # type: dict + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } - def search(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - """ - The DOIRester.search method does not exist as no search endpoint is present. Use get_data_by_id instead. - """ + return super()._search( + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, ) From 719722ab36c51be03663049a7cea057e6b75fab4 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 13:09:05 -0800 Subject: [PATCH 03/34] Big endpoint updates --- mp_api/client/routes/materials/absorption.py | 4 +- mp_api/client/routes/materials/alloys.py | 4 +- mp_api/client/routes/materials/bonds.py | 15 +---- .../client/routes/materials/charge_density.py | 4 +- mp_api/client/routes/materials/chemenv.py | 4 +- mp_api/client/routes/materials/dielectric.py | 15 +---- mp_api/client/routes/materials/doi.py | 15 ++++- mp_api/client/routes/materials/elasticity.py | 15 +---- mp_api/client/routes/materials/electrodes.py | 21 ++----- .../routes/materials/electronic_structure.py | 8 +-- mp_api/client/routes/materials/eos.py | 16 +---- mp_api/client/routes/materials/fermi.py | 27 ++++++++- .../client/routes/materials/grain_boundary.py | 16 +---- mp_api/client/routes/materials/magnetism.py | 19 +----- mp_api/client/routes/materials/materials.py | 17 +----- .../routes/materials/oxidation_states.py | 4 +- mp_api/client/routes/materials/phonon.py | 51 ++++++++++++++-- mp_api/client/routes/materials/piezo.py | 18 +----- mp_api/client/routes/materials/provenance.py | 11 +++- mp_api/client/routes/materials/robocrys.py | 58 +++++++++++++++---- mp_api/client/routes/materials/similarity.py | 50 ++++++++++++++-- mp_api/client/routes/materials/substrates.py | 18 +----- mp_api/client/routes/materials/summary.py | 27 +++------ .../routes/materials/surface_properties.py | 28 +++++---- mp_api/client/routes/materials/synthesis.py | 16 +---- mp_api/client/routes/materials/tasks.py | 20 ++----- mp_api/client/routes/materials/thermo.py | 27 ++++----- mp_api/client/routes/materials/xas.py | 38 +++++++----- 28 files changed, 284 insertions(+), 282 deletions(-) diff --git a/mp_api/client/routes/materials/absorption.py b/mp_api/client/routes/materials/absorption.py index 64262fcef..9ef7ff338 100644 --- a/mp_api/client/routes/materials/absorption.py +++ b/mp_api/client/routes/materials/absorption.py @@ -24,7 +24,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ) -> list[AbsorptionDoc]: + ) -> list[AbsorptionDoc] | list[dict]: """Query for optical absorption spectra data. Arguments: @@ -42,7 +42,7 @@ def search( fields (List[str]): List of fields in AbsorptionDoc to return data for. Returns: - ([AbsorptionDoc]) List of optical absorption documents. + ([AbsorptionDoc], [dict]) List of optical absorption documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/alloys.py b/mp_api/client/routes/materials/alloys.py index 11edb630b..53b79064a 100644 --- a/mp_api/client/routes/materials/alloys.py +++ b/mp_api/client/routes/materials/alloys.py @@ -21,7 +21,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ) -> list[AlloyPairDoc]: + ) -> list[AlloyPairDoc] | list[dict]: """Query for hypothetical alloys formed between two commensurate crystal structures, following the methodology in https://doi.org/10.48550/arXiv.2206.10715. @@ -38,7 +38,7 @@ def search( fields (List[str]): List of fields in AlloyPairDoc to return data for. Returns: - ([AlloyPairDoc]) List of alloy pair documents. + ([AlloyPairDoc], [dict]) List of alloy pair documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/bonds.py b/mp_api/client/routes/materials/bonds.py index 03909ad58..bdbe4ca2d 100644 --- a/mp_api/client/routes/materials/bonds.py +++ b/mp_api/client/routes/materials/bonds.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.bonds import BondingDoc @@ -14,16 +13,6 @@ class BondsRester(BaseRester[BondingDoc]): document_model = BondingDoc # type: ignore primary_key = "material_id" - def search_bonds_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.bonds.search_bonds_docs is deprecated. Please use MPRester.bonds.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, material_ids: str | list[str] | None = None, @@ -36,7 +25,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[BondingDoc] | list[dict]: """Query bonding docs using a variety of search criteria. Arguments: @@ -57,7 +46,7 @@ def search( Default is material_id and last_updated if all_fields is False. Returns: - ([BondingDoc]) List of bonding documents. + ([BondingDoc], [dict]) List of bonding documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/charge_density.py b/mp_api/client/routes/materials/charge_density.py index 2d160d159..c1c3a0edb 100644 --- a/mp_api/client/routes/materials/charge_density.py +++ b/mp_api/client/routes/materials/charge_density.py @@ -42,12 +42,12 @@ def download_for_task_ids( num_downloads += 1 return num_downloads - def search( # type: ignore + def search( self, task_ids: list[str] | None = None, num_chunks: int | None = 1, chunk_size: int = 10, - ) -> list[ChgcarDataDoc] | list[dict]: # type: ignore + ) -> list[ChgcarDataDoc] | list[dict]: """A search method to find what charge densities are available via this API. Arguments: diff --git a/mp_api/client/routes/materials/chemenv.py b/mp_api/client/routes/materials/chemenv.py index 384c3aca5..bc05fff85 100644 --- a/mp_api/client/routes/materials/chemenv.py +++ b/mp_api/client/routes/materials/chemenv.py @@ -46,7 +46,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[ChemEnvDoc] | list[dict]: """Query for chemical environment data. Arguments: @@ -73,7 +73,7 @@ def search( fields (List[str]): List of fields in ChemEnvDoc to return data for. Returns: - ([ChemEnvDoc]) List of chemenv documents. + ([ChemEnvDoc], [dict]) List of chemenv documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/dielectric.py b/mp_api/client/routes/materials/dielectric.py index 0b862af17..4adedc37e 100644 --- a/mp_api/client/routes/materials/dielectric.py +++ b/mp_api/client/routes/materials/dielectric.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.polar import DielectricDoc @@ -14,16 +13,6 @@ class DielectricRester(BaseRester[DielectricDoc]): document_model = DielectricDoc # type: ignore primary_key = "material_id" - def search_dielectric_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.dielectric.search_dielectric_docs is deprecated. Please use MPRester.dielectric.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, material_ids: str | list[str] | None = None, @@ -35,7 +24,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[DielectricDoc] | list[dict]: """Query dielectric docs using a variety of search criteria. Arguments: @@ -52,7 +41,7 @@ def search( Default is material_id and last_updated if all_fields is False. Returns: - ([DielectricDoc]) List of dielectric documents. + ([DielectricDoc], [dict]) List of dielectric documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index b3742c6db..bb66f185d 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -20,8 +20,21 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[DOIDoc] | list[dict]: + """Query for DOI data. + + Arguments: + material_ids (str, List[str]): Search for DOI data associated with the specified Material IDs + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in DOIDoc to return data for. + + Returns: + ([DOIDoc], [dict]) List of DOIDoc documents or dictionaries. + """ query_params = defaultdict(dict) # type: dict + if material_ids: if isinstance(material_ids, str): material_ids = [material_ids] diff --git a/mp_api/client/routes/materials/elasticity.py b/mp_api/client/routes/materials/elasticity.py index 074333322..e6c416968 100644 --- a/mp_api/client/routes/materials/elasticity.py +++ b/mp_api/client/routes/materials/elasticity.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.elasticity import ElasticityDoc @@ -14,16 +13,6 @@ class ElasticityRester(BaseRester[ElasticityDoc]): document_model = ElasticityDoc # type: ignore primary_key = "material_id" - def search_elasticity_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.elasticity.search_elasticity_docs is deprecated. Please use MPRester.elasticity.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, material_ids: str | list[str] | None = None, @@ -39,7 +28,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[ElasticityDoc] | list[dict]: """Query elasticity docs using a variety of search criteria. Arguments: @@ -68,7 +57,7 @@ def search( Default is material_id and prett-formula if all_fields is False. Returns: - ([ElasticityDoc]) List of elasticity documents. + ([ElasticityDoc], [dict]) List of elasticity documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/electrodes.py b/mp_api/client/routes/materials/electrodes.py index 70d4bd67a..e2555ac06 100644 --- a/mp_api/client/routes/materials/electrodes.py +++ b/mp_api/client/routes/materials/electrodes.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.electrode import InsertionElectrodeDoc @@ -15,16 +14,6 @@ class ElectrodeRester(BaseRester[InsertionElectrodeDoc]): document_model = InsertionElectrodeDoc # type: ignore primary_key = "battery_id" - def search_electrode_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.electrode.search_electrode_docs is deprecated. Please use MPRester.electrode.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( # pragma: ignore self, material_ids: str | list[str] | None = None, @@ -50,8 +39,8 @@ def search( # pragma: ignore chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): - """Query equations of state docs using a variety of search criteria. + ) -> list[InsertionElectrodeDoc] | list[dict]: + """Query using a variety of search criteria. Arguments: material_ids (str, List[str]): A single Material ID string or list of strings @@ -92,7 +81,7 @@ def search( # pragma: ignore Default is battery_id and last_updated if all_fields is False. Returns: - ([InsertionElectrodeDoc]) List of insertion electrode documents. + ([InsertionElectrodeDoc], [dict]) List of insertion electrode documents or dictionaries. """ query_params = defaultdict(dict) # type: dict @@ -112,7 +101,9 @@ def search( # pragma: ignore if isinstance(working_ion, (str, Element)): working_ion = [working_ion] # type: ignore - query_params.update({"working_ion": ",".join([str(ele) for ele in working_ion])}) # type: ignore + query_params.update( + {"working_ion": ",".join([str(ele) for ele in working_ion])} # type: ignore + ) if formula: if isinstance(formula, str): diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 668abf138..a9d75f142 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -41,8 +41,8 @@ def search( elements: list[str] | None = None, exclude_elements: list[str] | None = None, formula: str | list[str] | None = None, - is_gap_direct: bool = None, - is_metal: bool = None, + is_gap_direct: bool | None = None, + is_metal: bool | None = None, magnetic_ordering: Ordering | None = None, num_elements: tuple[int, int] | None = None, num_chunks: int | None = None, @@ -161,8 +161,8 @@ def search( self, band_gap: tuple[float, float] | None = None, efermi: tuple[float, float] | None = None, - is_gap_direct: bool = None, - is_metal: bool = None, + is_gap_direct: bool | None = None, + is_metal: bool | None = None, magnetic_ordering: Ordering | None = None, path_type: BSPathType = BSPathType.setyawan_curtarolo, num_chunks: int | None = None, diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index e897412b8..b0698ed56 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.eos import EOSDoc @@ -13,17 +12,6 @@ class EOSRester(BaseRester[EOSDoc]): document_model = EOSDoc # type: ignore primary_key = "task_id" - def search_eos_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.eos.search_eos_docs is deprecated. " - "Please use MPRester.eos.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, energies: tuple[float, float] | None = None, @@ -32,7 +20,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[EOSDoc] | list[dict]: """Query equations of state docs using a variety of search criteria. Arguments: @@ -45,7 +33,7 @@ def search( Default is material_id only if all_fields is False. Returns: - ([EOSDoc]) List of eos documents + ([EOSDoc], [dict]) List of equations of state docs or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/fermi.py b/mp_api/client/routes/materials/fermi.py index 8287bf38e..606fed354 100644 --- a/mp_api/client/routes/materials/fermi.py +++ b/mp_api/client/routes/materials/fermi.py @@ -1,25 +1,31 @@ from __future__ import annotations +from collections import defaultdict + from emmet.core.fermi import FermiDoc from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids class FermiRester(BaseRester[FermiDoc]): suffix = "materials/fermi" document_model = FermiDoc # type: ignore - primary_key = "task_id" + primary_key = "material_id" def search( self, + material_ids: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[FermiDoc] | list[dict]: """Query fermi surface docs using a variety of search criteria. Arguments: + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. all_fields (bool): Whether to return all fields in the document. Defaults to True. @@ -27,11 +33,26 @@ def search( Default is material_id, last_updated, and formula_pretty if all_fields is False. Returns: - ([FermiDoc]) List of material documents + ([FermiDoc], [dict]) List of fermi documents or dictionaries. """ + query_params = defaultdict(dict) # type: dict + + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } + return super()._search( num_chunks=num_chunks, chunk_size=chunk_size, all_fields=all_fields, fields=fields, + **query_params, ) diff --git a/mp_api/client/routes/materials/grain_boundary.py b/mp_api/client/routes/materials/grain_boundary.py index 6caef989e..efeea2674 100644 --- a/mp_api/client/routes/materials/grain_boundary.py +++ b/mp_api/client/routes/materials/grain_boundary.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.grain_boundary import GBTypeEnum, GrainBoundaryDoc @@ -14,17 +13,6 @@ class GrainBoundaryRester(BaseRester[GrainBoundaryDoc]): document_model = GrainBoundaryDoc # type: ignore primary_key = "task_id" - def search_grain_boundary_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.grain_boundary.search_grain_boundary_docs is deprecated. " - "Please use MPRester.grain_boundary.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, chemsys: str | None = None, @@ -41,7 +29,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[GrainBoundaryDoc] | list[dict]: """Query grain boundary docs using a variety of search criteria. Arguments: @@ -63,7 +51,7 @@ def search( Default is material_id and last_updated if all_fields is False. Returns: - ([GrainBoundaryDoc]) List of grain boundary documents + ([GrainBoundaryDoc], [dict]) List of grain boundary documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/magnetism.py b/mp_api/client/routes/materials/magnetism.py index 103a23a0d..96c3c59a6 100644 --- a/mp_api/client/routes/materials/magnetism.py +++ b/mp_api/client/routes/materials/magnetism.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.magnetism import MagnetismDoc @@ -15,17 +14,6 @@ class MagnetismRester(BaseRester[MagnetismDoc]): document_model = MagnetismDoc # type: ignore primary_key = "material_id" - def search_magnetism_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.magnetism.search_magnetism_docs is deprecated. " - "Please use MPRester.magnetism.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, material_ids: str | list[str] | None = None, @@ -39,12 +27,11 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[MagnetismDoc] | list[dict]: """Query magnetism docs using a variety of search criteria. Arguments: - material_ids (str, List[str]): A single Material ID string or list of strings - (e.g., mp-149, [mp-149, mp-13]). + material_ids (str, List[str]): A single Material ID string or list of strings (e.g., mp-149, [mp-149, mp-13]). num_magnetic_sites (Tuple[int,int]): Minimum and maximum number of magnetic sites to consider. num_unique_magnetic_sites (Tuple[int,int]): Minimum and maximum number of unique magnetic sites to consider. @@ -61,7 +48,7 @@ def search( Default is material_id and last_updated if all_fields is False. Returns: - ([MagnetismDoc]) List of magnetism documents + ([MagnetismDoc], [dict]) List of magnetism documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index bf5d20e3b..f015fa5f4 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -40,7 +40,7 @@ XASRester, ) -_EMMET_SETTINGS = EmmetSettings() +_EMMET_SETTINGS = EmmetSettings() # type: ignore class MaterialsRester(BaseRester[MaterialsDoc]): @@ -134,17 +134,6 @@ def get_structure_by_material_id( response = self.get_data_by_id(material_id, fields=["initial_structures"]) return response.initial_structures if response is not None else response # type: ignore - def search_material_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.materials.search_material_docs is deprecated. " - "Please use MPRester.materials.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, material_ids: str | list[str] | None = None, @@ -165,7 +154,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[MaterialsDoc] | list[dict]: """Query core material docs using a variety of search criteria. Arguments: @@ -194,7 +183,7 @@ def search( Default is material_id, last_updated, and formula_pretty if all_fields is False. Returns: - ([MaterialsDoc]) List of material documents + ([MaterialsDoc], [dict]) List of material documents or dictionaries. """ query_params = {"deprecated": deprecated} # type: dict diff --git a/mp_api/client/routes/materials/oxidation_states.py b/mp_api/client/routes/materials/oxidation_states.py index 4fb29a577..fca6e93de 100644 --- a/mp_api/client/routes/materials/oxidation_states.py +++ b/mp_api/client/routes/materials/oxidation_states.py @@ -23,7 +23,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[OxidationStateDoc] | list[dict]: """Query oxidation state docs using a variety of search criteria. Arguments: @@ -42,7 +42,7 @@ def search( Default is material_id, last_updated, and formula_pretty if all_fields is False. Returns: - ([OxidationStateDoc]) List of oxidation state documents + ([OxidationStateDoc], [dict]) List of oxidation state documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 0f0ee6e5e..5651a3726 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -1,8 +1,11 @@ from __future__ import annotations +from collections import defaultdict + from emmet.core.phonon import PhononBSDOSDoc from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids class PhononRester(BaseRester[PhononBSDOSDoc]): @@ -10,10 +13,46 @@ class PhononRester(BaseRester[PhononBSDOSDoc]): document_model = PhononBSDOSDoc # type: ignore primary_key = "material_id" - def search(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - """ - The PhononRester.search method does not exist as no search endpoint is present. - Use get_data_by_id instead. - """ + def search( + self, + material_ids: str | list[str] | None = None, + num_chunks: int | None = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: list[str] | None = None, + ) -> list[PhononBSDOSDoc] | list[dict]: + """Query phonon docs using a variety of search criteria. + + Arguments: + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in PhononBSDOSDoc to return data for. + Default is material_id, last_updated, and formula_pretty if all_fields is False. + + Returns: + ([PhononBSDOSDoc], [dict]) List of phonon documents or dictionaries. + """ + query_params = defaultdict(dict) # type: dict + + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } + + return super()._search( + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, ) diff --git a/mp_api/client/routes/materials/piezo.py b/mp_api/client/routes/materials/piezo.py index 496b2cab0..306cf5069 100644 --- a/mp_api/client/routes/materials/piezo.py +++ b/mp_api/client/routes/materials/piezo.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.polar import PiezoelectricDoc @@ -14,17 +13,6 @@ class PiezoRester(BaseRester[PiezoelectricDoc]): document_model = PiezoelectricDoc # type: ignore primary_key = "material_id" - def search_piezoelectric_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.piezoelectric.search_piezoelectric_docs is deprecated. " - "Please use MPRester.piezoelectric.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, material_ids: str | list[str] | None = None, @@ -33,8 +21,8 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): - """Query equations of state docs using a variety of search criteria. + ) -> list[PiezoelectricDoc] | list[dict]: + """Query piezoelectric data using a variety of search criteria. Arguments: material_ids (str, List[str]): A single Material ID string or list of strings @@ -48,7 +36,7 @@ def search( Default is material_id and last_updated if all_fields is False. Returns: - ([PiezoDoc]) List of piezoelectric documents + ([PiezoDoc], [dict]) List of piezoelectric documents """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/provenance.py b/mp_api/client/routes/materials/provenance.py index 4c49414b7..b4ff240fc 100644 --- a/mp_api/client/routes/materials/provenance.py +++ b/mp_api/client/routes/materials/provenance.py @@ -19,7 +19,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[ProvenanceDoc] | list[dict]: """Query provenance docs using a variety of search criteria. Arguments: @@ -29,11 +29,11 @@ def search( num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. all_fields (bool): Whether to return all fields in the document. Defaults to True. - fields (List[str]): List of fields in Provenance to return data for. + fields (List[str]): List of fields in ProvenanceDoc to return data for. Default is material_id, last_updated, and formula_pretty if all_fields is False. Returns: - ([ProvenanceDoc]) List of provenance documents + ([ProvenanceDoc], [dict]) List of provenance documents or dictionaries. """ query_params = {"deprecated": deprecated} # type: dict @@ -43,6 +43,11 @@ def search( query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } return super()._search( num_chunks=num_chunks, chunk_size=chunk_size, diff --git a/mp_api/client/routes/materials/robocrys.py b/mp_api/client/routes/materials/robocrys.py index 845701e3c..b76adf491 100644 --- a/mp_api/client/routes/materials/robocrys.py +++ b/mp_api/client/routes/materials/robocrys.py @@ -1,10 +1,9 @@ from __future__ import annotations -import warnings - from emmet.core.robocrys import RobocrystallogapherDoc from mp_api.client.core import BaseRester, MPRestError +from mp_api.client.core.utils import validate_ids class RobocrysRester(BaseRester[RobocrystallogapherDoc]): @@ -12,16 +11,6 @@ class RobocrysRester(BaseRester[RobocrystallogapherDoc]): document_model = RobocrystallogapherDoc # type: ignore primary_key = "material_id" - def search_robocrys_text(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "search_robocrys_text is deprecated. " "Please use search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, keywords: list[str], @@ -52,3 +41,48 @@ def search( raise MPRestError("Cannot find any matches.") return robocrys_docs + + def search_docs( + self, + material_ids: str | list[str] | None = None, + deprecated: bool | None = False, + num_chunks: int | None = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: list[str] | None = None, + ) -> list[RobocrystallogapherDoc] | list[dict]: + """Query robocrystallographer docs using a variety of search criteria. + + Arguments: + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). + deprecated (bool): Whether the material is tagged as deprecated. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in RobocrystallogapherDoc to return data for. + Default is material_id, last_updated, and formula_pretty if all_fields is False. + + Returns: + ([RobocrystallogapherDoc], [dict]) List of robocrystallographer documents or dictionaries. + """ + query_params = {"deprecated": deprecated} # type: dict + + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } + return super()._search( + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, + ) diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index 33be807ea..fa054f61f 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -3,6 +3,7 @@ from emmet.core.similarity import SimilarityDoc from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids class SimilarityRester(BaseRester[SimilarityDoc]): @@ -10,10 +11,47 @@ class SimilarityRester(BaseRester[SimilarityDoc]): document_model = SimilarityDoc # type: ignore primary_key = "material_id" - def search(*args, **kwargs): # pragma: no cover - raise NotImplementedError( - """ - The SimilarityRester.search method does not exist as no search endpoint is present. - Use get_data_by_id instead. - """ + def search_docs( + self, + material_ids: str | list[str] | None = None, + deprecated: bool | None = False, + num_chunks: int | None = None, + chunk_size: int = 1000, + all_fields: bool = True, + fields: list[str] | None = None, + ) -> list[SimilarityDoc] | list[dict]: + """Query similarity docs using a variety of search criteria. + + Arguments: + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). + deprecated (bool): Whether the material is tagged as deprecated. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in SimilarityDoc to return data for. + Default is material_id, last_updated, and formula_pretty if all_fields is False. + + Returns: + ([SimilarityDoc], [dict]) List of similarity documents or dictionaries. + """ + query_params = {"deprecated": deprecated} # type: dict + + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } + return super()._search( + num_chunks=num_chunks, + chunk_size=chunk_size, + all_fields=all_fields, + fields=fields, + **query_params, ) diff --git a/mp_api/client/routes/materials/substrates.py b/mp_api/client/routes/materials/substrates.py index 4524537a9..ac2ba6c8e 100644 --- a/mp_api/client/routes/materials/substrates.py +++ b/mp_api/client/routes/materials/substrates.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.substrates import SubstratesDoc @@ -13,17 +12,6 @@ class SubstratesRester(BaseRester[SubstratesDoc]): document_model = SubstratesDoc # type: ignore primary_key = "film_id" - def search_substrates_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.substrates.search_substrates_docs is deprecated. " - "Please use MPRester.substrates.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, area: tuple[float, float] | None = None, @@ -37,8 +25,8 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): - """Query equations of state docs using a variety of search criteria. + ) -> list[SubstratesDoc] | list[dict]: + """Query substrate docs using a variety of search criteria. Arguments: area (Tuple[float,float]): Minimum and maximum volume in Ų to consider for the minimum coincident @@ -56,7 +44,7 @@ def search( Default is the film_id and substrate_id only if all_fields is False. Returns: - ([SubstratesDoc]) List of substrate documents + ([SubstratesDoc], [dict]) List of substrate documents or dictionaries. """ query_params = defaultdict(dict) # type: dict diff --git a/mp_api/client/routes/materials/summary.py b/mp_api/client/routes/materials/summary.py index 3f11f27d3..e68782adf 100644 --- a/mp_api/client/routes/materials/summary.py +++ b/mp_api/client/routes/materials/summary.py @@ -1,9 +1,7 @@ from __future__ import annotations -import warnings from collections import defaultdict -from emmet.core.mpid import MPID from emmet.core.summary import HasProps, SummaryDoc from emmet.core.symmetry import CrystalSystem from pymatgen.analysis.magnetism import Ordering @@ -17,17 +15,6 @@ class SummaryRester(BaseRester[SummaryDoc]): document_model = SummaryDoc # type: ignore primary_key = "material_id" - def search_summary_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.summary.search_summary_docs is deprecated. " - "Please use MPRester.summary.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, band_gap: tuple[float, float] | None = None, @@ -58,7 +45,7 @@ def search( k_voigt: tuple[float, float] | None = None, k_vrh: tuple[float, float] | None = None, magnetic_ordering: Ordering | None = None, - material_ids: list[MPID] | None = None, + material_ids: str | list[str] | None = None, n: tuple[float, float] | None = None, num_elements: tuple[int, int] | None = None, num_sites: tuple[int, int] | None = None, @@ -84,7 +71,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[SummaryDoc] | list[dict]: """Query core data using a variety of search criteria. Arguments: @@ -126,7 +113,8 @@ def search( k_vrh (Tuple[float,float]): Minimum and maximum value in GPa to consider for the Voigt-Reuss-Hill average of the bulk modulus. magnetic_ordering (Ordering): Magnetic ordering of the material. - material_ids (List[MPID]): List of Materials Project IDs to return data for. + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). n (Tuple[float,float]): Minimum and maximum refractive index to consider. num_elements (Tuple[int,int]): Minimum and maximum number of elements to consider. num_sites (Tuple[int,int]): Minimum and maximum number of sites to consider. @@ -155,11 +143,11 @@ def search( num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. all_fields (bool): Whether to return all fields in the document. Defaults to True. - fields (List[str]): List of fields in SearchDoc to return data for. + fields (List[str]): List of fields in SummaryDoc to return data for. Default is material_id if all_fields is False. Returns: - ([SummaryDoc]) List of SummaryDoc documents + ([SummaryDoc], [dict]) List of SummaryDoc documents or dictionaries. """ query_params = defaultdict(dict) # type: dict @@ -212,6 +200,9 @@ def search( ) if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) if deprecated is not None: diff --git a/mp_api/client/routes/materials/surface_properties.py b/mp_api/client/routes/materials/surface_properties.py index 7499f4b07..a3c92c247 100644 --- a/mp_api/client/routes/materials/surface_properties.py +++ b/mp_api/client/routes/materials/surface_properties.py @@ -1,31 +1,21 @@ from __future__ import annotations -import warnings from collections import defaultdict from emmet.core.surface_properties import SurfacePropDoc from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids class SurfacePropertiesRester(BaseRester[SurfacePropDoc]): suffix = "materials/surface_properties" document_model = SurfacePropDoc # type: ignore - primary_key = "task_id" - - def search_surface_properties_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.surface_properties.search_surface_properties_docs is deprecated. " - "Please use MPRester.surface_properties.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) + primary_key = "material_id" def search( self, + material_ids: str | list[str] | None = None, has_reconstructed: bool | None = None, shape_factor: tuple[float, float] | None = None, surface_energy_anisotropy: tuple[float, float] | None = None, @@ -35,10 +25,12 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[SurfacePropDoc] | list[dict]: """Query surface properties docs using a variety of search criteria. Arguments: + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). has_reconstructed (bool): Whether the entry has any reconstructed surfaces. shape_factor (Tuple[float,float]): Minimum and maximum shape factor values to consider. surface_energy_anisotropy (Tuple[float,float]): Minimum and maximum surface energy anisotropy values to @@ -53,10 +45,16 @@ def search( Default is material_id only if all_fields is False. Returns: - ([SurfacePropDoc]) List of surface properties documents + ([SurfacePropDoc], [dict]) List of surface properties documents """ query_params = defaultdict(dict) # type: dict + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + if weighted_surface_energy: query_params.update( { diff --git a/mp_api/client/routes/materials/synthesis.py b/mp_api/client/routes/materials/synthesis.py index 0d60170f5..171f5e2b4 100644 --- a/mp_api/client/routes/materials/synthesis.py +++ b/mp_api/client/routes/materials/synthesis.py @@ -1,7 +1,5 @@ from __future__ import annotations -import warnings - from emmet.core.synthesis import ( OperationTypeEnum, SynthesisSearchResultModel, @@ -15,16 +13,6 @@ class SynthesisRester(BaseRester[SynthesisSearchResultModel]): suffix = "materials/synthesis" document_model = SynthesisSearchResultModel # type: ignore - def search_synthesis_text(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "search_synthesis_text is deprecated. " "Please use search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, keywords: list[str] | None = None, @@ -41,7 +29,7 @@ def search( condition_mixing_media: list[str] | None = None, num_chunks: int | None = None, chunk_size: int | None = 10, - ): + ) -> list[SynthesisSearchResultModel] | list[dict]: """Search synthesis recipe text. Arguments: @@ -62,7 +50,7 @@ def search( Returns: - synthesis_docs ([SynthesisDoc]): List of synthesis documents. + ([SynthesisSearchResultModel], [dict]): List of synthesis documents or dictionaries. """ # Turn None and empty list into None keywords = keywords or None diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index b684a2dac..e1a7d8034 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -19,8 +19,9 @@ def get_trajectory(self, task_id): material throughout a calculation. This is most useful for observing how a material relaxes during a geometry optimization. - :param task_id: A specified task_id - :return: List of trajectory objects + Args: + task_id (str): Task ID + """ traj_data = self._query_resource_data( suburl=f"trajectory/{task_id}/", use_document_model=False @@ -31,17 +32,6 @@ def get_trajectory(self, task_id): return traj_data - def search_task_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.tasks.search_task_docs is deprecated. " - "Please use MPRester.tasks.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, task_ids: list[str] | None = None, @@ -54,7 +44,7 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): + ) -> list[TaskDoc] | list[dict]: """Query core task docs using a variety of search criteria. Arguments: @@ -74,7 +64,7 @@ def search( Default is material_id, last_updated, and formula_pretty if all_fields is False. Returns: - ([TaskDoc]) List of task documents + ([TaskDoc], [dict]) List of task documents or dictionaries. """ query_params = {} # type: dict diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 69ddda2cd..e0cecd2bb 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from collections import defaultdict import numpy as np @@ -18,26 +17,15 @@ class ThermoRester(BaseRester[ThermoDoc]): supports_versions = True primary_key = "thermo_id" - def search_thermo_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.thermo.search_thermo_docs is deprecated. " - "Please use MPRester.thermo.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, + material_ids: str | list[str] | None = None, chemsys: str | list[str] | None = None, energy_above_hull: tuple[float, float] | None = None, equilibrium_reaction_energy: tuple[float, float] | None = None, formation_energy: tuple[float, float] | None = None, formula: str | list[str] | None = None, is_stable: bool | None = None, - material_ids: list[str] | None = None, num_elements: tuple[int, int] | None = None, thermo_ids: list[str] | None = None, thermo_types: list[ThermoType | str] | None = None, @@ -47,10 +35,12 @@ def search( chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, - ): - """Query core material docs using a variety of search criteria. + ) -> list[ThermoDoc] | list[dict]: + """Query core thermo docs using a variety of search criteria. Arguments: + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). chemsys (str, List[str]): A chemical system or list of chemical systems (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]). energy_above_hull (Tuple[float,float]): Minimum and maximum energy above the hull in eV/atom to consider. @@ -76,7 +66,7 @@ def search( Default is material_id and last_updated if all_fields is False. Returns: - ([ThermoDoc]) List of thermo documents + ([ThermoDoc], [dict]) List of thermo documents or dictionaries. """ query_params = defaultdict(dict) # type: dict @@ -93,6 +83,9 @@ def search( query_params.update({"chemsys": ",".join(chemsys)}) if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) if thermo_ids: @@ -160,7 +153,7 @@ def get_phase_diagram_from_chemsys( Returns: - phase_diagram (PhaseDiagram): Pymatgen phase diagram object. + (PhaseDiagram): Pymatgen phase diagram object. """ t_type = thermo_type if isinstance(thermo_type, str) else thermo_type.value valid_types = {*map(str, ThermoType.__members__.values())} diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index 7f0592624..b4ee1d82e 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -1,7 +1,5 @@ from __future__ import annotations -import warnings - from emmet.core.xas import Edge, Type, XASDoc from pymatgen.core.periodic_table import Element @@ -14,17 +12,6 @@ class XASRester(BaseRester[XASDoc]): document_model = XASDoc # type: ignore primary_key = "spectrum_id" - def search_xas_docs(self, *args, **kwargs): # pragma: no cover - """Deprecated.""" - warnings.warn( - "MPRester.xas.search_xas_docs is deprecated. " - "Please use MPRester.xas.search instead.", - DeprecationWarning, - stacklevel=2, - ) - - return self.search(*args, **kwargs) - def search( self, edge: Edge | None = None, @@ -34,6 +21,7 @@ def search( elements: list[str] | None = None, material_ids: list[str] | None = None, spectrum_type: Type | None = None, + spectrum_ids: str | list[str] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = True, @@ -49,8 +37,11 @@ def search( chemsys (str, List[str]): A chemical system or list of chemical systems (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]). elements (List[str]): A list of elements. - material_ids (List[str]): List of Materials Project IDs to return data for. + material_ids (str, List[str]): A single Material ID string or list of strings + (e.g., mp-149, [mp-149, mp-13]). spectrum_type (Type): Spectrum type (e.g. EXAFS, XAFS, or XANES). + spectrum_ids (str, List[str]): A single Spectrum ID string or list of strings + (e.g., mp-149-XANES-Li-K, [mp-149-XANES-Li-K, mp-13-XANES-Li-K]). num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. all_fields (bool): Whether to return all fields in the document. Defaults to True. @@ -89,8 +80,23 @@ def search( if elements: query_params.update({"elements": ",".join(elements)}) - if material_ids is not None: - query_params["material_ids"] = ",".join(validate_ids(material_ids)) + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + + if spectrum_ids: + if isinstance(spectrum_ids, str): + spectrum_ids = [spectrum_ids] + + query_params.update({"spectrum_ids": ",".join(spectrum_ids)}) + + query_params = { + entry: query_params[entry] + for entry in query_params + if query_params[entry] is not None + } return super()._search( num_chunks=num_chunks, From 0d38c1abe06c2afda23e85fc653bb9020cf94392 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 13:45:59 -0800 Subject: [PATCH 04/34] Remove charge density rester --- .../client/routes/materials/charge_density.py | 72 ------------------- 1 file changed, 72 deletions(-) delete mode 100644 mp_api/client/routes/materials/charge_density.py diff --git a/mp_api/client/routes/materials/charge_density.py b/mp_api/client/routes/materials/charge_density.py deleted file mode 100644 index c1c3a0edb..000000000 --- a/mp_api/client/routes/materials/charge_density.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Literal - -from emmet.core.charge_density import ChgcarDataDoc -from monty.serialization import dumpfn - -from mp_api.client.core import BaseRester -from mp_api.client.core.utils import validate_ids - - -class ChargeDensityRester(BaseRester[ChgcarDataDoc]): - suffix = "materials/charge_density" - primary_key = "fs_id" - document_model = ChgcarDataDoc # type: ignore - boto_resource = None - - def download_for_task_ids( - self, - path: str, - task_ids: list[str], - ext: Literal["json.gz", "json", "mpk", "mpk.gz"] = "json.gz", # type: ignore - ) -> int: - """Download a set of charge densities. - - :param path: Your local directory to save the charge densities to. Each charge - density will be serialized as a separate JSON file with name given by the relevant - task_id. - :param task_ids: A list of task_ids. - :param ext: Choose from any file type supported by `monty`, e.g. json or msgpack. - :return: An integer for the number of charge densities saved. - """ - num_downloads = 0 - path_obj = Path(path) - chgcar_summary_docs = self.search(task_ids=task_ids) - for entry in chgcar_summary_docs: - fs_id = entry.fs_id # type: ignore - task_id = entry.task_id # type: ignore - doc = self.get_data_by_id(fs_id) - dumpfn(doc, path_obj / f"{task_id}.{ext}") - num_downloads += 1 - return num_downloads - - def search( - self, - task_ids: list[str] | None = None, - num_chunks: int | None = 1, - chunk_size: int = 10, - ) -> list[ChgcarDataDoc] | list[dict]: - """A search method to find what charge densities are available via this API. - - Arguments: - task_ids (List[str]): List of Materials Project IDs to return data for. - num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. - chunk_size (int): Number of data entries per chunk. - - Returns: - A list of ChgcarDataDoc that contain task_id references. - """ - query_params = {} - - if task_ids: - query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) - - return super()._search( - num_chunks=num_chunks, - chunk_size=chunk_size, - all_fields=False, - fields=["last_updated", "task_id", "fs_id"], - **query_params, - ) From 57ac59351ff3efe30b1816743ce124792082be4c Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 13:46:06 -0800 Subject: [PATCH 05/34] Fix materials structure method --- mp_api/client/routes/materials/materials.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index f015fa5f4..36ae699a6 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -127,12 +127,16 @@ def get_structure_by_material_id( structure (Union[Structure, List[Structure]]): Pymatgen structure object or list of pymatgen structure objects. """ - if final: - response = self.get_data_by_id(material_id, fields=["structure"]) - return response.structure if response is not None else response # type: ignore - else: - response = self.get_data_by_id(material_id, fields=["initial_structures"]) - return response.initial_structures if response is not None else response # type: ignore + field = "structure" if final else "initial_structures" + + response = self.search(material_ids=material_id, fields=[field]) + + if response: + response = ( + response[0].model_dump() if self.use_document_model else response[0] # type: ignore + ) + + return response[field] if response else response # type: ignore def search( self, From 00826466fd7e03591dc26ab2713125a6048fa530 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 13:47:02 -0800 Subject: [PATCH 06/34] Remove charge density references --- mp_api/client/mprester.py | 17 ++++++++++------- mp_api/client/routes/materials/materials.py | 2 -- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 6029995c2..880a7df4f 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -31,7 +31,6 @@ AlloysRester, BandStructureRester, BondsRester, - ChargeDensityRester, ChemenvRester, DielectricRester, DOIRester, @@ -101,7 +100,6 @@ class MPRester: robocrys: RobocrysRester synthesis: SynthesisRester insertion_electrodes: ElectrodeRester - charge_density: ChargeDensityRester electronic_structure: ElectronicStructureRester electronic_structure_bandstructure: BandStructureRester electronic_structure_dos: DosRester @@ -207,7 +205,6 @@ def __init__( "robocrys", "synthesis", "insertion_electrodes", - "charge_density", "electronic_structure", "electronic_structure_bandstructure", "electronic_structure_dos", @@ -1225,8 +1222,10 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode + return ( + self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode + ) ) def get_dos_by_material_id(self, material_id: str): @@ -1238,7 +1237,9 @@ def get_dos_by_material_id(self, material_id: str): Returns: dos (CompleteDos): CompleteDos object """ - return self.electronic_structure_dos.get_dos_from_material_id(material_id=material_id) # type: ignore + return self.electronic_structure_dos.get_dos_from_material_id( + material_id=material_id + ) # type: ignore def get_phonon_dos_by_material_id(self, material_id: str): """Get phonon density of states data corresponding to a material_id. @@ -1315,7 +1316,9 @@ def get_charge_density_from_material_id( task_ids = self.get_task_ids_associated_with_material_id( material_id, calc_types=[CalcType.GGA_Static, CalcType.GGA_U_Static] ) - results: list[TaskDoc] = self.tasks.search(task_ids=task_ids, fields=["last_updated", "task_id"]) # type: ignore + results: list[TaskDoc] = self.tasks.search( + task_ids=task_ids, fields=["last_updated", "task_id"] + ) # type: ignore if len(results) == 0: return None diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index 36ae699a6..bb2604f5e 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -14,7 +14,6 @@ AlloysRester, BandStructureRester, BondsRester, - ChargeDensityRester, ChemenvRester, DielectricRester, DosRester, @@ -67,7 +66,6 @@ class MaterialsRester(BaseRester[MaterialsDoc]): "robocrys", "synthesis", "insertion_electrodes", - "charge_density", "electronic_structure", "electronic_structure_bandstructure", "electronic_structure_dos", From 449a81e1905782669ee3ac0d2c6efe00206eba9d Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 13:47:35 -0800 Subject: [PATCH 07/34] More charge density fixes --- mp_api/client/routes/materials/materials.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mp_api/client/routes/materials/materials.py b/mp_api/client/routes/materials/materials.py index bb2604f5e..96cefb914 100644 --- a/mp_api/client/routes/materials/materials.py +++ b/mp_api/client/routes/materials/materials.py @@ -1,7 +1,5 @@ from __future__ import annotations -import warnings - from emmet.core.settings import EmmetSettings from emmet.core.symmetry import CrystalSystem from emmet.core.vasp.material import MaterialsDoc @@ -97,7 +95,6 @@ class MaterialsRester(BaseRester[MaterialsDoc]): robocrys: RobocrysRester synthesis: SynthesisRester insertion_electrodes: ElectrodeRester - charge_density: ChargeDensityRester electronic_structure: ElectronicStructureRester electronic_structure_bandstructure: BandStructureRester electronic_structure_dos: DosRester From 3610a56b5b32c8754232d8b17ec1ed660c341e11 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 13:48:58 -0800 Subject: [PATCH 08/34] Add type ignore --- mp_api/client/routes/materials/tasks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index e1a7d8034..9e10c4f3f 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -1,6 +1,5 @@ from __future__ import annotations -import warnings from datetime import datetime from emmet.core.tasks import TaskDoc @@ -25,7 +24,7 @@ def get_trajectory(self, task_id): """ traj_data = self._query_resource_data( suburl=f"trajectory/{task_id}/", use_document_model=False - )[0].get("trajectories", None) + )[0].get("trajectories", None) # type: ignore if traj_data is None: raise MPRestError(f"No trajectory data for {task_id} found") From 41827d4bb1d7364a64d3904a2430e622de0c6cc7 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 14:01:44 -0800 Subject: [PATCH 09/34] Fix task_id method --- mp_api/client/mprester.py | 47 ++++++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 880a7df4f..99291df4f 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -65,8 +65,8 @@ "methods will be retained until at least January 2022 for backwards compatibility." ) -_EMMET_SETTINGS = EmmetSettings() -_MAPI_SETTINGS = MAPIClientSettings() +_EMMET_SETTINGS = EmmetSettings() # type: ignore +_MAPI_SETTINGS = MAPIClientSettings() # type: ignore DEFAULT_API_KEY = environ.get("MP_API_KEY", None) DEFAULT_ENDPOINT = environ.get("MP_API_ENDPOINT", "https://api.materialsproject.org/") @@ -127,8 +127,8 @@ def __init__( include_user_agent: bool = True, monty_decode: bool = True, use_document_model: bool = True, - session: Session = None, - headers: dict = None, + session: Session | None = None, + headers: dict | None = None, mute_progress_bars: bool = _MAPI_SETTINGS.MUTE_PROGRESS_BARS, ): """Args: @@ -336,8 +336,8 @@ def __molecules_getattr__(_self, attr): rester = __core_custom_getattr(_self, attr, _rester_map) return rester - MaterialsRester.__getattr__ = __materials_getattr__ - MoleculeRester.__getattr__ = __molecules_getattr__ + MaterialsRester.__getattr__ = __materials_getattr__ # type: ignore + MoleculeRester.__getattr__ = __molecules_getattr__ # type: ignore for attr, rester in core_resters.items(): setattr( @@ -353,7 +353,9 @@ def contribs(self): from mpcontribs.client import Client self._contribs = Client( - self.api_key, headers=self.headers, session=self.session + self.api_key, # type: ignore + headers=self.headers, + session=self.session, ) except ImportError: @@ -407,19 +409,34 @@ def __dir__(self): def get_task_ids_associated_with_material_id( self, material_id: str, calc_types: list[CalcType] | None = None ) -> list[str]: - """:param material_id: - :param calc_types: if specified, will restrict to certain task types, e.g. [CalcType.GGA_STATIC] - :return: + """Get Task ID values associated with a specific Material ID. + + Args: + material_id (str): Material ID + calc_types ([CalcType]): If specified, will restrict to a certain task type, e.g. [CalcType.GGA_STATIC] + + Returns: + ([str]): List of Task ID values. """ - tasks = self.materials.get_data_by_id( - material_id, fields=["calc_types"] - ).calc_types + tasks = self.materials.search(material_ids=material_id, fields=["calc_types"]) + + if not tasks: + return [] + + calculations = ( + tasks[0].calc_types # type: ignore + if self.use_document_model + else tasks[0]["calc_types"] # type: ignore + ) + if calc_types: return [ - task for task, calc_type in tasks.items() if calc_type in calc_types + task + for task, calc_type in calculations.items() + if calc_type in calc_types ] else: - return list(tasks.keys()) + return list(calculations.keys()) def get_structure_by_material_id( self, material_id: str, final: bool = True, conventional_unit_cell: bool = False From b940687b3491c9ef6462ce88a08d72835661aba7 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 14:04:31 -0800 Subject: [PATCH 10/34] Fix references method --- mp_api/client/mprester.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 99291df4f..260b3597c 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -545,7 +545,12 @@ def get_material_id_references(self, material_id: str) -> list[str]: Returns: List of BibTeX references ([str]) """ - return self.provenance.get_data_by_id(material_id).references + docs = self.provenance.search(material_ids=material_id) + + if not docs: + return [] + + return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore def get_materials_id_references(self, material_id: str) -> list[str]: """This method is deprecated, please use get_material_id_references.""" From f98c3be3db92707c88df284556e0fd5a0772f6fd Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 14:22:49 -0800 Subject: [PATCH 11/34] Fix phonon methods --- mp_api/client/mprester.py | 77 ++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 45 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 260b3597c..4197c03b6 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -552,14 +552,6 @@ def get_material_id_references(self, material_id: str) -> list[str]: return docs[0].references if self.use_document_model else docs[0]["references"] # type: ignore - def get_materials_id_references(self, material_id: str) -> list[str]: - """This method is deprecated, please use get_material_id_references.""" - warnings.warn( - "This method is deprecated, please use get_material_id_references instead.", - DeprecationWarning, - ) - return self.get_material_id_references(material_id) - def get_material_ids( self, chemsys_formula: str | list[str], @@ -581,7 +573,7 @@ def get_material_ids( input_params = {"formula": chemsys_formula} return sorted( - doc.material_id + doc.material_id if self.use_document_model else doc["material_id"] # type: ignore for doc in self.materials.search( **input_params, # type: ignore all_fields=False, @@ -589,17 +581,6 @@ def get_material_ids( ) ) - def get_materials_ids( - self, - chemsys_formula: str | list[str], - ) -> list[MPID]: - """This method is deprecated, please use get_material_ids.""" - warnings.warn( - "This method is deprecated, please use get_material_ids.", - DeprecationWarning, - ) - return self.get_material_ids(chemsys_formula) - def get_structures( self, chemsys_formula: str | list[str], final=True ) -> list[Structure]: @@ -623,7 +604,7 @@ def get_structures( if final: return [ - doc.structure + doc.structure if self.use_document_model else doc["structure"] # type: ignore for doc in self.materials.search( **input_params, # type: ignore all_fields=False, @@ -638,7 +619,11 @@ def get_structures( all_fields=False, fields=["initial_structures"], ): - structures.extend(doc.initial_structures) + structures.extend( + doc.initial_structures # type: ignore + if self.use_document_model + else doc["initial_structures"] # type: ignore + ) return structures @@ -682,10 +667,10 @@ def get_entries( self, chemsys_formula_mpids: str | list[str], compatible_only: bool = True, - inc_structure: bool = None, - property_data: list[str] = None, + inc_structure: bool | None = None, + property_data: list[str] | None = None, conventional_unit_cell: bool = False, - additional_criteria: dict = None, + additional_criteria: dict | None = None, ) -> list[ComputedStructureEntry]: """Get a list of ComputedEntries or ComputedStructureEntries corresponding to a chemical system or formula. This returns entries for all thermo types @@ -752,19 +737,19 @@ def get_entries( ) docs = self.thermo.search( - **input_params, + **input_params, # type: ignore all_fields=False, - fields=fields, # type: ignore + fields=fields, ) for doc in docs: entry_list = ( - doc.entries.values() + doc.entries.values() # type: ignore if self.use_document_model - else doc["entries"].values() + else doc["entries"].values() # type: ignore ) for entry in entry_list: - entry_dict = entry.as_dict() if self.monty_decode else entry + entry_dict: dict = entry.as_dict() if self.monty_decode else entry # type: ignore if not compatible_only: entry_dict["correction"] = 0.0 entry_dict["energy_adjustments"] = [] @@ -772,9 +757,9 @@ def get_entries( if property_data: for property in property_data: entry_dict["data"][property] = ( - doc.model_dump()[property] + doc.model_dump()[property] # type: ignore if self.use_document_model - else doc[property] + else doc[property] # type: ignore ) if conventional_unit_cell: @@ -904,7 +889,7 @@ def get_pourbaix_entries( ion_ref_entries = GibbsComputedStructureEntry.from_entries( ion_ref_entries, temp=use_gibbs ) - ion_ref_pd = PhaseDiagram(ion_ref_entries) + ion_ref_pd = PhaseDiagram(ion_ref_entries) # type: ignore ion_entries = self.get_ion_entries(ion_ref_pd, ion_ref_data=ion_data) pbx_entries = [PourbaixEntry(e, f"ion-{n}") for n, e in enumerate(ion_entries)] @@ -923,7 +908,7 @@ def get_pourbaix_entries( or extra_elts.intersection(entry_elts) ): # Create new computed entry - form_e = ion_ref_pd.get_form_energy(entry) + form_e = ion_ref_pd.get_form_energy(entry) # type: ignore new_entry = ComputedEntry( entry.composition, form_e, entry_id=entry.entry_id ) @@ -961,11 +946,11 @@ def get_ion_reference_data(self) -> list[dict]: 'reference': 'H. E. Barner and R. V. Scheuerman, Handbook of thermochemical data for compounds and aqueous species, Wiley, New York (1978)'}} """ - return self.contribs.query_contributions( + return self.contribs.query_contributions( # type: ignore query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get("data") + ).get("data") # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1007,7 +992,7 @@ def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: return [d for d in ion_data if d["data"]["MajElements"] in chemsys] def get_ion_entries( - self, pd: PhaseDiagram, ion_ref_data: list[dict] = None + self, pd: PhaseDiagram, ion_ref_data: list[dict] | None = None ) -> list[IonEntry]: """Retrieve IonEntry objects that can be used in the construction of Pourbaix Diagrams. The energies of the IonEntry are calculaterd from @@ -1052,7 +1037,7 @@ def get_ion_entries( # position the ion energies relative to most stable reference state ion_entries = [] - for _n, i_d in enumerate(ion_data): + for _, i_d in enumerate(ion_data): ion = Ion.from_formula(i_d["formula"]) refs = [ e @@ -1100,8 +1085,8 @@ def get_entry_by_material_id( self, material_id: str, compatible_only: bool = True, - inc_structure: bool = None, - property_data: list[str] = None, + inc_structure: bool | None = None, + property_data: list[str] | None = None, conventional_unit_cell: bool = False, ): """Get all ComputedEntry objects corresponding to a material_id. @@ -1142,8 +1127,8 @@ def get_entries_in_chemsys( elements: str | list[str], use_gibbs: int | None = None, compatible_only: bool = True, - inc_structure: bool = None, - property_data: list[str] = None, + inc_structure: bool | None = None, + property_data: list[str] | None = None, conventional_unit_cell: bool = False, additional_criteria=None, ): @@ -1200,7 +1185,7 @@ def get_entries_in_chemsys( for els in itertools.combinations(elements_set, i + 1): all_chemsyses.append("-".join(sorted(els))) - entries = [] # type: List[ComputedEntry] + entries = [] entries.extend( self.get_entries( @@ -1273,7 +1258,8 @@ def get_phonon_dos_by_material_id(self, material_id: str): CompletePhononDos: A phonon DOS object. """ - return self.phonon.get_data_by_id(material_id, fields=["ph_dos"]).ph_dos + doc = self.phonon.search(material_ids=material_id, fields=["ph_dos"]) + return doc.ph_dos if self.use_document_model else doc["ph_dos"] # type: ignore def get_phonon_bandstructure_by_material_id(self, material_id: str): """Get phonon dispersion data corresponding to a material_id. @@ -1284,7 +1270,8 @@ def get_phonon_bandstructure_by_material_id(self, material_id: str): Returns: PhononBandStructureSymmLine: phonon band structure. """ - return self.phonon.get_data_by_id(material_id, fields=["ph_bs"]).ph_bs + doc = self.phonon.search(material_ids=material_id, fields=["ph_bs"]) + return doc.ph_bs if self.use_document_model else doc["ph_bs"] # type: ignore def get_wulff_shape(self, material_id: str): """Constructs a Wulff shape for a material. From c520641c07dcdf447baf11fce5458a56247219c5 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 14:30:34 -0800 Subject: [PATCH 12/34] Fix phonon and task methods --- mp_api/client/mprester.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 4197c03b6..27f10dbe1 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -1259,7 +1259,9 @@ def get_phonon_dos_by_material_id(self, material_id: str): """ doc = self.phonon.search(material_ids=material_id, fields=["ph_dos"]) - return doc.ph_dos if self.use_document_model else doc["ph_dos"] # type: ignore + if not doc: + return None + return doc[0].ph_dos if self.use_document_model else doc[0]["ph_dos"] # type: ignore def get_phonon_bandstructure_by_material_id(self, material_id: str): """Get phonon dispersion data corresponding to a material_id. @@ -1271,7 +1273,10 @@ def get_phonon_bandstructure_by_material_id(self, material_id: str): PhononBandStructureSymmLine: phonon band structure. """ doc = self.phonon.search(material_ids=material_id, fields=["ph_bs"]) - return doc.ph_bs if self.use_document_model else doc["ph_bs"] # type: ignore + if not doc: + return None + + return doc[0].ph_bs if self.use_document_model else doc[0]["ph_bs"] # type: ignore def get_wulff_shape(self, material_id: str): """Constructs a Wulff shape for a material. @@ -1287,9 +1292,15 @@ def get_wulff_shape(self, material_id: str): from pymatgen.symmetry.analyzer import SpacegroupAnalyzer structure = self.get_structure_by_material_id(material_id) - surfaces = surfaces = self.surface_properties.get_data_by_id( - material_id - ).surfaces + doc = self.surface_properties.search(material_ids=material_id) + + if not doc: + return None + + surfaces: list = ( + doc[0].surfaces if self.use_document_model else doc[0]["surfaces"] # type: ignore + ) + lattice = ( SpacegroupAnalyzer(structure).get_conventional_standard_structure().lattice ) @@ -1304,7 +1315,7 @@ def get_wulff_shape(self, material_id: str): def get_charge_density_from_material_id( self, material_id: str, inc_task_doc: bool = False - ) -> Chgcar | None: + ) -> Chgcar | tuple[Chgcar, TaskDoc | dict] | None: """Get charge density data for a given Materials Project ID. Arguments: @@ -1312,7 +1323,7 @@ def get_charge_density_from_material_id( inc_task_doc (bool): Whether to include the task document in the returned data. Returns: - chgcar: Pymatgen Chgcar object. + (Chgcar, (Chgcar, TaskDoc | dict), None): Pymatgen Chgcar object, or tuple with object and TaskDoc """ if not hasattr(self, "charge_density"): raise MPRestError( @@ -1332,7 +1343,12 @@ def get_charge_density_from_material_id( if len(results) == 0: return None - latest_doc = max(results, key=lambda x: x.last_updated) + latest_doc = max( # type: ignore + results, + key=lambda x: x.last_updated # type: ignore + if self.use_document_model + else x["last_updated"], # type: ignore + ) result = ( self.tasks._query_open_data( @@ -1349,7 +1365,12 @@ def get_charge_density_from_material_id( raise MPRestError(f"No charge density fetched for {material_id}.") if inc_task_doc: - task_doc = self.tasks.get_data_by_id(latest_doc.task_id) + task_doc = self.tasks.search( + task_ids=latest_doc.task_id + if self.use_document_model + else latest_doc["task_id"] + )[0] + return chgcar, task_doc return chgcar From cccfdbef31fe0d96f9ed25e92ed39e8104094746 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 14:45:57 -0800 Subject: [PATCH 13/34] Fix remaining type issues in MPRester --- mp_api/client/mprester.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 27f10dbe1..44226e9cc 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -879,7 +879,7 @@ def get_pourbaix_entries( warnings.filterwarnings( "ignore", message="Failed to guess oxidation states.*" ) - ion_ref_entries = compat.process_entries(ion_ref_entries) + ion_ref_entries = compat.process_entries(ion_ref_entries) # type: ignore # TODO - if the commented line above would work, this conditional block # could be removed if use_gibbs: @@ -1399,10 +1399,11 @@ def get_download_info(self, material_ids, calc_types=None, file_patterns=None): task_ids=material_ids, fields=["calc_types", "deprecated_tasks", "material_id"], ): - for task_id, calc_type in doc.calc_types.items(): + doc_dict: dict = doc.model_dump() if self.use_document_model else doc # type: ignore + for task_id, calc_type in doc_dict["calc_types"].items(): if calc_types and calc_type not in calc_types: continue - mp_id = doc.material_id + mp_id = doc_dict["material_id"] if meta.get(mp_id) is None: meta[mp_id] = [{"task_id": task_id, "calc_type": calc_type}] else: From 343bb5ccd29abb5fcec8cd0f9e12a5a19f7b1cbb Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 15:42:43 -0800 Subject: [PATCH 14/34] Remove get_bv refs in es methods --- .../routes/materials/electronic_structure.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index a9d75f142..b6f739657 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -262,16 +262,26 @@ def get_bandstructure_from_material_id( ) if line_mode: - bs_data = es_rester.get_data_by_id( - document_id=material_id, fields=["bandstructure"] - ).bandstructure + bs_doc = es_rester.search( + material_ids=material_id, fields=["bandstructure"] + ) + if not bs_doc: + raise MPRestError("No electronic structure data found.") + + bs_data = ( + bs_doc[0].bandstructure # type: ignore + if self.use_document_model + else bs_doc[0]["bandstructure"] # type: ignore + ) if bs_data is None: raise MPRestError( f"No {path_type.value} band structure data found for {material_id}" ) else: - bs_data = bs_data.model_dump() + bs_data: dict = ( + bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + ) if bs_data.get(path_type.value, None): bs_task_id = bs_data[path_type.value]["task_id"] @@ -280,16 +290,25 @@ def get_bandstructure_from_material_id( f"No {path_type.value} band structure data found for {material_id}" ) else: - bs_data = es_rester.get_data_by_id( - document_id=material_id, fields=["dos"] - ).dos + bs_doc = es_rester.search(material_ids=material_id, fields=["dos"]) + + if not bs_doc: + raise MPRestError("No electronic structure data found.") + + bs_data = ( + bs_doc[0].dos # type: ignore + if self.use_document_model + else bs_doc[0]["dos"] # type: ignore + ) if bs_data is None: raise MPRestError( f"No uniform band structure data found for {material_id}" ) else: - bs_data = bs_data.model_dump() + bs_data: dict = ( + bs_data.model_dump() if self.use_document_model else bs_data # type: ignore + ) if bs_data.get("total", None): bs_task_id = bs_data["total"]["1"]["task_id"] @@ -421,9 +440,13 @@ def get_dos_from_material_id(self, material_id: str): endpoint=self.base_endpoint, api_key=self.api_key ) - dos_data = es_rester.get_data_by_id( - document_id=material_id, fields=["dos"] - ).model_dump() + dos_doc = es_rester.search(material_ids=material_id, fields=["dos"]) + if not dos_doc: + return None + + dos_data: dict = ( + dos_doc[0].model_dump() if self.use_document_model else dos_doc[0] # type: ignore + ) if dos_data["dos"]: dos_task_id = dos_data["dos"]["total"]["1"]["task_id"] From 8a6a7baa2afd5ff9bb9115b578bb369870703b19 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 15:49:26 -0800 Subject: [PATCH 15/34] Fix user settings method --- mp_api/client/routes/_user_settings.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mp_api/client/routes/_user_settings.py b/mp_api/client/routes/_user_settings.py index 5c1ebd2d6..059c2fe96 100644 --- a/mp_api/client/routes/_user_settings.py +++ b/mp_api/client/routes/_user_settings.py @@ -89,4 +89,7 @@ def get_user_settings(self, consumer_id, fields): # pragma: no cover Raises: MPRestError. """ - return self.get_data_by_id(consumer_id, fields) + return super()._search( + consumer_id=consumer_id, + fields=fields, + ) From 32bd0e79fbb3b9098a1ba3f28c029dd5d5e3ae07 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 15:49:37 -0800 Subject: [PATCH 16/34] Fix get_by ref in molecules --- mp_api/client/routes/molecules/molecules.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mp_api/client/routes/molecules/molecules.py b/mp_api/client/routes/molecules/molecules.py index 27d493cc5..8a242d27b 100644 --- a/mp_api/client/routes/molecules/molecules.py +++ b/mp_api/client/routes/molecules/molecules.py @@ -29,12 +29,16 @@ def get_molecule_by_mpculeid( molecule (Union[Molecule, List[Molecule]]): Pymatgen Molecule object or list of pymatgen Molecule objects. """ - if final: - response = self.get_data_by_id(mpcule_id, fields=["molecule"]) - return response.molecule if response is not None else response # type: ignore - else: - response = self.get_data_by_id(mpcule_id, fields=["initial_molecules"]) - return response.initial_molecules if response is not None else response # type: ignore + field = "molecule" if final else "initial_molecules" + + response = self.search(molecule_ids=[mpcule_id], fields=[field]) # type: ignore + + if response: + response = ( + response[0].model_dump() if self.use_document_model else response[0] # type: ignore + ) + + return response[field] if response else response # type: ignore def find_molecule( self, From 1199f97c02eddb731280565d6c7fa8356b2474f4 Mon Sep 17 00:00:00 2001 From: munrojm Date: Wed, 31 Jan 2024 15:53:28 -0800 Subject: [PATCH 17/34] Remove remaining references to get_data_by_id --- tests/test_client.py | 10 +++++----- tests/test_core_client.py | 12 ------------ tests/test_mprester.py | 4 ++-- 3 files changed, 7 insertions(+), 19 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index f3b9ac9c0..3d93c82ef 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -56,7 +56,7 @@ ) @pytest.mark.parametrize("rester", resters_to_test) def test_generic_get_methods(rester): - # -- Test generic search and get_data_by_id methods + # -- Test generic search name = rester.suffix.replace("/", "_") rester = rester( @@ -77,14 +77,14 @@ def test_generic_get_methods(rester): assert isinstance(doc, rester.document_model) if name not in search_only_resters: - doc = rester.get_data_by_id( - doc.model_dump()[rester.primary_key], fields=[rester.primary_key] + doc = rester.search( + **{rester.primary_key: doc.model_dump()[rester.primary_key]}, fields=[rester.primary_key] ) assert isinstance(doc, rester.document_model) elif name not in special_resters: - doc = rester.get_data_by_id( - key_only_resters[name], fields=[rester.primary_key] + doc = rester.search( + **{rester.primary_key: key_only_resters[name]}, fields=[rester.primary_key] ) assert isinstance(doc, rester.document_model) diff --git a/tests/test_core_client.py b/tests/test_core_client.py index c8c6a8157..39f9b5e44 100644 --- a/tests/test_core_client.py +++ b/tests/test_core_client.py @@ -42,18 +42,6 @@ def test_count(mpr): assert count == 1 -@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") -@pytest.mark.xfail -def test_get_document_no_id(mpr): - mpr.materials.get_data_by_id(None) - - -@pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") -@pytest.mark.xfail -def test_get_document_no_doc(mpr): - mpr.materials.get_data_by_id("mp-1a") - - def test_available_fields(rester, mpr): assert len(mpr.materials.available_fields) > 0 assert rester.available_fields == ["Unknown fields."] diff --git a/tests/test_mprester.py b/tests/test_mprester.py index d9e86e66e..382295dd9 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -72,12 +72,12 @@ def test_get_materials_ids_references(self, mpr): def test_get_materials_ids_doc(self, mpr): mp_ids = mpr.get_materials_ids("Al2O3") random.shuffle(mp_ids) - doc = mpr.materials.get_data_by_id(mp_ids.pop(0)) + doc = mpr.materials.search(material_ids=mp_ids.pop(0))[0] assert doc.formula_pretty == "Al2O3" mp_ids = mpr.get_materials_ids("Al-O") random.shuffle(mp_ids) - doc = mpr.materials.get_data_by_id(mp_ids.pop(0)) + doc = mpr.materials.search(material_ids=mp_ids.pop(0))[0] assert doc.chemsys == "Al-O" def test_get_structures(self, mpr): From 200722c26abf0253577b1328233843eb9cb39b74 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 6 Feb 2024 17:28:27 -0800 Subject: [PATCH 18/34] Deprecate get_data_by_id --- mp_api/client/core/client.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 2151fd69a..990310416 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -1011,6 +1011,31 @@ def _search( num_chunks=num_chunks, ) + def get_data_by_id( + self, + document_id: str, + fields: list[str] | None = None, + ) -> T | dict: + warnings.warn( + f"get_data_by_id is deprecated and will be removed soon. Please use the search method instead.", + DeprecationWarning, + stacklevel=2, + ) + + if self.primary_key in ["material_id", "task_id"]: + validate_ids([document_id]) + + if isinstance(fields, str): # pragma: no cover + fields = (fields,) # type: ignore + + return self._search( # type: ignore + **{self.primary_key + "s": document_id}, + num_chunks=1, + chunk_size=1, + all_fields=fields is None, + fields=fields, + ) + def _get_all_documents( self, query_params, From c18a153b6933a802c70dc13594de57ea468e2359 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:13:39 -0800 Subject: [PATCH 19/34] Linting --- mp_api/client/core/client.py | 2 +- mp_api/client/mprester.py | 10 +++++----- mp_api/client/routes/materials/tasks.py | 4 +++- tests/test_client.py | 8 +++++--- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 990310416..4bdd675fb 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -1017,7 +1017,7 @@ def get_data_by_id( fields: list[str] | None = None, ) -> T | dict: warnings.warn( - f"get_data_by_id is deprecated and will be removed soon. Please use the search method instead.", + "get_data_by_id is deprecated and will be removed soon. Please use the search method instead.", DeprecationWarning, stacklevel=2, ) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 44226e9cc..ceb1301db 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -950,7 +950,9 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get("data") # type: ignore + ).get( + "data" + ) # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1229,10 +1231,8 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return ( - self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode - ) + return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode ) def get_dos_by_material_id(self, material_id: str): diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 9e10c4f3f..c1a8828a8 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -24,7 +24,9 @@ def get_trajectory(self, task_id): """ traj_data = self._query_resource_data( suburl=f"trajectory/{task_id}/", use_document_model=False - )[0].get("trajectories", None) # type: ignore + )[0].get( + "trajectories", None + ) # type: ignore if traj_data is None: raise MPRestError(f"No trajectory data for {task_id} found") diff --git a/tests/test_client.py b/tests/test_client.py index 3d93c82ef..4833f0bc3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -56,7 +56,7 @@ ) @pytest.mark.parametrize("rester", resters_to_test) def test_generic_get_methods(rester): - # -- Test generic search + # -- Test generic search name = rester.suffix.replace("/", "_") rester = rester( @@ -78,13 +78,15 @@ def test_generic_get_methods(rester): if name not in search_only_resters: doc = rester.search( - **{rester.primary_key: doc.model_dump()[rester.primary_key]}, fields=[rester.primary_key] + **{rester.primary_key: doc.model_dump()[rester.primary_key]}, + fields=[rester.primary_key], ) assert isinstance(doc, rester.document_model) elif name not in special_resters: doc = rester.search( - **{rester.primary_key: key_only_resters[name]}, fields=[rester.primary_key] + **{rester.primary_key: key_only_resters[name]}, + fields=[rester.primary_key], ) assert isinstance(doc, rester.document_model) From 01cf5c4b63f8bd7571a031f3cb97f1dc935afcdc Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:20:06 -0800 Subject: [PATCH 20/34] Remove charge density rester tests --- tests/materials/test_charge_density.py | 56 -------------------------- 1 file changed, 56 deletions(-) delete mode 100644 tests/materials/test_charge_density.py diff --git a/tests/materials/test_charge_density.py b/tests/materials/test_charge_density.py deleted file mode 100644 index 9c4c90a2d..000000000 --- a/tests/materials/test_charge_density.py +++ /dev/null @@ -1,56 +0,0 @@ -import os - -import pytest -from core_function import client_search_testing - -from mp_api.client.routes.materials.charge_density import ChargeDensityRester - - -@pytest.fixture -def rester(): - rester = ChargeDensityRester() - yield rester - rester.session.close() - - -excluded_params = [ - "sort_fields", - "chunk_size", - "num_chunks", - "all_fields", - "fields", - "return", -] - -sub_doc_fields = [] # type: list - -alt_name_dict = { - "task_ids": "task_id", -} # type: dict - -custom_field_tests = {"task_ids": ["mp-1985345", "mp-1896118"]} # type: dict - - -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) -def test_client(rester): - search_method = rester.search - - client_search_testing( - search_method=search_method, - excluded_params=excluded_params, - alt_name_dict=alt_name_dict, - custom_field_tests=custom_field_tests, - sub_doc_fields=sub_doc_fields, - ) - - -def test_download_for_task_ids(tmpdir, rester): - rester.download_for_task_ids( - task_ids=["mp-655585", "mp-1057373", "mp-1059589", "mp-1440634", "mp-1791788"], - path=tmpdir, - ) - files = [f for f in os.listdir(tmpdir)] - - assert "mp-1791788.json.gz" in files From 796b5bbc3567dcc50ac0a38e186468e8ec14197b Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:29:23 -0800 Subject: [PATCH 21/34] Remove generic get method tests --- tests/test_client.py | 95 -------------------------------------------- 1 file changed, 95 deletions(-) delete mode 100644 tests/test_client.py diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index 4833f0bc3..000000000 --- a/tests/test_client.py +++ /dev/null @@ -1,95 +0,0 @@ -import os - -import pytest - -from mp_api.client import MPRester -from mp_api.client.routes.materials import TaskRester, ProvenanceRester - -# -- Rester name data for generic tests - -key_only_resters = { - "materials_phonon": "mp-11703", - "materials_similarity": "mp-149", - "doi": "mp-149", - "materials_wulff": "mp-149", - "materials_charge_density": "mp-1936745", - "materials_provenance": "mp-149", - "materials_robocrys": "mp-1025395", -} - -search_only_resters = [ - "materials_grain_boundary", - "materials_electronic_structure_bandstructure", - "materials_electronic_structure_dos", - "materials_substrates", - "materials_synthesis", -] - -special_resters = [ - "materials_charge_density", -] - -ignore_generic = [ - "_user_settings", - "_general_store", - "_messages", - # "tasks", - # "bonds", - "materials_xas", - "materials_elasticity", - "materials_fermi", - # "alloys", - # "summary", -] # temp - - -mpr = MPRester() - -# Temporarily ignore molecules resters while molecules query operators are changed -resters_to_test = [ - rester for rester in mpr._all_resters if "molecule" not in rester.suffix -] - - -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) -@pytest.mark.parametrize("rester", resters_to_test) -def test_generic_get_methods(rester): - # -- Test generic search - name = rester.suffix.replace("/", "_") - - rester = rester( - endpoint=mpr.endpoint, - include_user_agent=True, - session=mpr.session, - monty_decode=True - if rester not in [TaskRester, ProvenanceRester] # type: ignore - else False, # Disable monty decode on nested data which may give errors - use_document_model=True, - ) - - if name not in ignore_generic: - if name not in key_only_resters: - doc = rester._query_resource_data( - {"_limit": 1}, fields=[rester.primary_key] - )[0] - assert isinstance(doc, rester.document_model) - - if name not in search_only_resters: - doc = rester.search( - **{rester.primary_key: doc.model_dump()[rester.primary_key]}, - fields=[rester.primary_key], - ) - assert isinstance(doc, rester.document_model) - - elif name not in special_resters: - doc = rester.search( - **{rester.primary_key: key_only_resters[name]}, - fields=[rester.primary_key], - ) - assert isinstance(doc, rester.document_model) - - -if os.getenv("MP_API_KEY", None) is None: - pytest.mark.skip(test_generic_get_methods) From e7d8e25ac515d6b9de44b47e7bd0be3912ac6422 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:36:25 -0800 Subject: [PATCH 22/34] Remove warning expectation --- tests/test_mprester.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 382295dd9..e136106a3 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -48,9 +48,6 @@ def test_get_structure_by_material_id(self, mpr): s1 = mpr.get_structure_by_material_id("mp-149", final=False) assert {s.formula for s in s1} == {"Si2"} - # # requesting via task-id instead of mp-id - with pytest.warns(UserWarning): - mpr.get_structure_by_material_id("mp-698856") def test_get_database_version(self, mpr): db_version = mpr.get_database_version() From 1d79579373196f573678a6472aafcd0d54b04ac7 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:38:27 -0800 Subject: [PATCH 23/34] Linting --- tests/test_mprester.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index e136106a3..d78b9957b 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -48,7 +48,6 @@ def test_get_structure_by_material_id(self, mpr): s1 = mpr.get_structure_by_material_id("mp-149", final=False) assert {s.formula for s in s1} == {"Si2"} - def test_get_database_version(self, mpr): db_version = mpr.get_database_version() assert db_version is not None From 48e2f3a050f250041ad5f3501d93dea743d3a878 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:47:23 -0800 Subject: [PATCH 24/34] Fix method reference in test --- tests/test_mprester.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index d78b9957b..ce55c4e33 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -62,7 +62,7 @@ def test_get_task_ids_associated_with_material_id(self, mpr): assert len(results) > 0 def test_get_materials_ids_references(self, mpr): - data = mpr.get_materials_id_references("mp-123") + data = mpr.get_material_id_references("mp-123") assert len(data) > 5 def test_get_materials_ids_doc(self, mpr): From c387ac35519c4f69303792e38b9b61c0d6d13f89 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:49:00 -0800 Subject: [PATCH 25/34] More method typo fixes --- tests/test_mprester.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index ce55c4e33..0ca4e2db1 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -66,7 +66,7 @@ def test_get_materials_ids_references(self, mpr): assert len(data) > 5 def test_get_materials_ids_doc(self, mpr): - mp_ids = mpr.get_materials_ids("Al2O3") + mp_ids = mpr.get_materials_id("Al2O3") random.shuffle(mp_ids) doc = mpr.materials.search(material_ids=mp_ids.pop(0))[0] assert doc.formula_pretty == "Al2O3" From ec019625a371a9cd91a23accf87183d30342799a Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:55:53 -0800 Subject: [PATCH 26/34] More materials typos --- mp_api/client/mprester.py | 18 +++++------------- tests/test_mprester.py | 12 ++++++------ 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index ceb1301db..69c045a10 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -528,14 +528,6 @@ def get_material_id_from_task_id(self, task_id: str) -> str | None: ) return None - def get_materials_id_from_task_id(self, task_id: str) -> str | None: - """This method is deprecated, please use get_material_id_from_task_id.""" - warnings.warn( - "This method is deprecated, please use get_material_id_from_task_id.", - DeprecationWarning, - ) - return self.get_material_id_from_task_id(task_id) - def get_material_id_references(self, material_id: str) -> list[str]: """Returns all references for a material id. @@ -950,9 +942,7 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get( - "data" - ) # type: ignore + ).get("data") # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1231,8 +1221,10 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode + return ( + self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode + ) ) def get_dos_by_material_id(self, material_id: str): diff --git a/tests/test_mprester.py b/tests/test_mprester.py index 0ca4e2db1..fe7df4de8 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -52,8 +52,8 @@ def test_get_database_version(self, mpr): db_version = mpr.get_database_version() assert db_version is not None - def test_get_materials_id_from_task_id(self, mpr): - assert mpr.get_materials_id_from_task_id("mp-540081") == "mp-19017" + def test_get_material_id_from_task_id(self, mpr): + assert mpr.get_material_id_from_task_id("mp-540081") == "mp-19017" def test_get_task_ids_associated_with_material_id(self, mpr): results = mpr.get_task_ids_associated_with_material_id( @@ -61,17 +61,17 @@ def test_get_task_ids_associated_with_material_id(self, mpr): ) assert len(results) > 0 - def test_get_materials_ids_references(self, mpr): + def test_get_material_id_references(self, mpr): data = mpr.get_material_id_references("mp-123") assert len(data) > 5 - def test_get_materials_ids_doc(self, mpr): - mp_ids = mpr.get_materials_id("Al2O3") + def test_get_material_id_doc(self, mpr): + mp_ids = mpr.get_material_ids("Al2O3") random.shuffle(mp_ids) doc = mpr.materials.search(material_ids=mp_ids.pop(0))[0] assert doc.formula_pretty == "Al2O3" - mp_ids = mpr.get_materials_ids("Al-O") + mp_ids = mpr.get_material_ids("Al-O") random.shuffle(mp_ids) doc = mpr.materials.search(material_ids=mp_ids.pop(0))[0] assert doc.chemsys == "Al-O" From 70bd6445e9f794fbe065fd967ff73bf36cf268fc Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 15:57:57 -0800 Subject: [PATCH 27/34] Linting --- mp_api/client/mprester.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index 69c045a10..d912a6492 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -942,7 +942,9 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get("data") # type: ignore + ).get( + "data" + ) # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1221,10 +1223,8 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return ( - self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode - ) + return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode ) def get_dos_by_material_id(self, material_id: str): From 180a27da4c80b7261dbe2603cfa6fadf1c53ed4f Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 16:14:00 -0800 Subject: [PATCH 28/34] Fix boto3 warning --- mp_api/client/mprester.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index d912a6492..f19f4a3e4 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -942,9 +942,7 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get( - "data" - ) # type: ignore + ).get("data") # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1223,8 +1221,10 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode + return ( + self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode + ) ) def get_dos_by_material_id(self, material_id: str): @@ -1317,12 +1317,6 @@ def get_charge_density_from_material_id( Returns: (Chgcar, (Chgcar, TaskDoc | dict), None): Pymatgen Chgcar object, or tuple with object and TaskDoc """ - if not hasattr(self, "charge_density"): - raise MPRestError( - "boto3 not installed. " - "To query charge density data install the boto3 package." - ) - # TODO: really we want a recommended task_id for charge densities here # this could potentially introduce an ambiguity task_ids = self.get_task_ids_associated_with_material_id( From 138dfc5b2627b4a4c45a243b29465e363566bffa Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 16:18:54 -0800 Subject: [PATCH 29/34] Fix generic search tests --- tests/materials/core_function.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/materials/core_function.py b/tests/materials/core_function.py index 19c2ac2bd..2d3c7ecd8 100644 --- a/tests/materials/core_function.py +++ b/tests/materials/core_function.py @@ -28,9 +28,10 @@ def client_search_testing( for entry in param_tuples: param = entry[0] - if param not in excluded_params: + if param not in excluded_params + ["return"]: param_type = entry[1] q = None + if "tuple[int, int]" in param_type: project_field = alt_name_dict.get(param, None) q = { @@ -60,6 +61,10 @@ def client_search_testing( "num_chunks": 1, } + if q is None: + raise ValueError(f"Parameter '{param}' with type '{param_type}' was not " + "properly identified in the generic search method test.") + doc = search_method(**q)[0].model_dump() for sub_field in sub_doc_fields: From 75aa6833110f66a2625356e18bb8cc05bde38540 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 16:20:01 -0800 Subject: [PATCH 30/34] Linting --- mp_api/client/mprester.py | 10 +++++----- tests/materials/core_function.py | 6 ++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index f19f4a3e4..c357e2667 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -942,7 +942,9 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get("data") # type: ignore + ).get( + "data" + ) # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1221,10 +1223,8 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return ( - self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode - ) + return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode ) def get_dos_by_material_id(self, material_id: str): diff --git a/tests/materials/core_function.py b/tests/materials/core_function.py index 2d3c7ecd8..bb7138c62 100644 --- a/tests/materials/core_function.py +++ b/tests/materials/core_function.py @@ -62,8 +62,10 @@ def client_search_testing( } if q is None: - raise ValueError(f"Parameter '{param}' with type '{param_type}' was not " - "properly identified in the generic search method test.") + raise ValueError( + f"Parameter '{param}' with type '{param_type}' was not " + "properly identified in the generic search method test." + ) doc = search_method(**q)[0].model_dump() From e79f4de22eb90514376019ecec1e2c11eb56a937 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 16:27:28 -0800 Subject: [PATCH 31/34] Update task_ids arg --- mp_api/client/routes/materials/tasks.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index c1a8828a8..60b5f8466 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -24,9 +24,7 @@ def get_trajectory(self, task_id): """ traj_data = self._query_resource_data( suburl=f"trajectory/{task_id}/", use_document_model=False - )[0].get( - "trajectories", None - ) # type: ignore + )[0].get("trajectories", None) # type: ignore if traj_data is None: raise MPRestError(f"No trajectory data for {task_id} found") @@ -35,7 +33,7 @@ def get_trajectory(self, task_id): def search( self, - task_ids: list[str] | None = None, + task_ids: str | list[str] | None = None, chemsys: str | list[str] | None = None, elements: list[str] | None = None, exclude_elements: list[str] | None = None, @@ -49,7 +47,7 @@ def search( """Query core task docs using a variety of search criteria. Arguments: - task_ids (List[str]): List of Materials Project IDs to return data for. + task_ids (str, List[str]): List of Materials Project IDs to return data for. chemsys (str, List[str]): A chemical system or list of chemical systems (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]). elements (List[str]): A list of elements. @@ -70,6 +68,9 @@ def search( query_params = {} # type: dict if task_ids: + if isinstance(task_ids, str): + task_ids = [task_ids] + query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) if formula: From f5e80ea01f9b7f3713aeaaec6813ca64ec6ec85c Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 16:37:20 -0800 Subject: [PATCH 32/34] Fix remaining search methods --- mp_api/client/mprester.py | 10 +++++----- mp_api/client/routes/materials/eos.py | 11 ++++++++++- mp_api/client/routes/materials/grain_boundary.py | 12 ++++++++---- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index c357e2667..f19f4a3e4 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -942,9 +942,7 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get( - "data" - ) # type: ignore + ).get("data") # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1223,8 +1221,10 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode + return ( + self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode + ) ) def get_dos_by_material_id(self, material_id: str): diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index b0698ed56..d77dfd459 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -5,15 +5,17 @@ from emmet.core.eos import EOSDoc from mp_api.client.core import BaseRester +from mp_api.client.core.utils import validate_ids class EOSRester(BaseRester[EOSDoc]): suffix = "materials/eos" document_model = EOSDoc # type: ignore - primary_key = "task_id" + primary_key = "material_id" def search( self, + material_ids: str | list[str] | None = None, energies: tuple[float, float] | None = None, volumes: tuple[float, float] | None = None, num_chunks: int | None = None, @@ -24,6 +26,7 @@ def search( """Query equations of state docs using a variety of search criteria. Arguments: + material_ids (str, List[str]): Search for equation of states associated with the specified Material IDs energies (Tuple[float,float]): Minimum and maximum energy in eV/atom to consider for EOS plot range. volumes (Tuple[float,float]): Minimum and maximum volume in A³/atom to consider for EOS plot range. num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. @@ -37,6 +40,12 @@ def search( """ query_params = defaultdict(dict) # type: dict + if material_ids: + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + if volumes: query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]}) diff --git a/mp_api/client/routes/materials/grain_boundary.py b/mp_api/client/routes/materials/grain_boundary.py index efeea2674..1c73ac970 100644 --- a/mp_api/client/routes/materials/grain_boundary.py +++ b/mp_api/client/routes/materials/grain_boundary.py @@ -9,16 +9,16 @@ class GrainBoundaryRester(BaseRester[GrainBoundaryDoc]): - suffix = "materials/grain_boundary" + suffix = "materials/grain_boundaries" document_model = GrainBoundaryDoc # type: ignore - primary_key = "task_id" + primary_key = "material_id" def search( self, + material_ids: str | list[str] | None = None, chemsys: str | None = None, gb_plane: list[str] | None = None, gb_energy: tuple[float, float] | None = None, - material_ids: list[str] | None = None, pretty_formula: str | None = None, rotation_axis: list[str] | None = None, rotation_angle: tuple[float, float] | None = None, @@ -33,6 +33,7 @@ def search( """Query grain boundary docs using a variety of search criteria. Arguments: + material_ids (str, List[str]): Search for grain boundary data associated with the specified Material IDs chemsys (str): Dash-delimited string of elements in the material. gb_plane(List[str]): The Miller index of grain boundary plane. e.g., [1, 1, 1] gb_energy (Tuple[float,float]): Minimum and maximum grain boundary energy in J/m³ to consider. @@ -56,7 +57,10 @@ def search( query_params = defaultdict(dict) # type: dict if material_ids: - query_params.update({"task_ids": ",".join(validate_ids(material_ids))}) + if isinstance(material_ids, str): + material_ids = [material_ids] + + query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) if gb_plane: query_params.update({"gb_plane": ",".join([str(n) for n in gb_plane])}) From 964272730c18ca182477f0a5e7b5e3e0f4874342 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 17:06:21 -0800 Subject: [PATCH 33/34] Fix thermo and tasks --- mp_api/client/routes/materials/tasks.py | 2 +- mp_api/client/routes/materials/thermo.py | 4 +++- tests/materials/test_surface_properties.py | 4 ++-- tests/materials/test_tasks.py | 3 ++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 60b5f8466..78ca1d63f 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -23,7 +23,7 @@ def get_trajectory(self, task_id): """ traj_data = self._query_resource_data( - suburl=f"trajectory/{task_id}/", use_document_model=False + {"task_ids": [task_id]}, suburl=f"trajectory/", use_document_model=False )[0].get("trajectories", None) # type: ignore if traj_data is None: diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index e0cecd2bb..00ec686d5 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -164,9 +164,11 @@ def get_phase_diagram_from_chemsys( sorted_chemsys = "-".join(sorted(chemsys.split("-"))) phase_diagram_id = f"{sorted_chemsys}_{t_type}" + response = self._query_resource( + criteria={"phase_diagram_ids": phase_diagram_id}, fields=["phase_diagram"], - suburl=f"phase_diagram/{phase_diagram_id}", + suburl=f"phase_diagram/", use_document_model=False, num_chunks=1, chunk_size=1, diff --git a/tests/materials/test_surface_properties.py b/tests/materials/test_surface_properties.py index 22b11a59b..a4e72cb03 100644 --- a/tests/materials/test_surface_properties.py +++ b/tests/materials/test_surface_properties.py @@ -23,9 +23,9 @@ def rester(): sub_doc_fields = [] # type: list -alt_name_dict = {"surface_energy_anisotropy": "surface_anisotropy"} # type: dict +alt_name_dict = {"surface_energy_anisotropy": "surface_anisotropy", "material_ids": "material_id"} # type: dict -custom_field_tests = {} # type: dict +custom_field_tests = {"material_ids": ["mp-23152"]} # type: dict @pytest.mark.skipif(os.getenv("MP_API_KEY", None) is None, reason="No API key found.") diff --git a/tests/materials/test_tasks.py b/tests/materials/test_tasks.py index c87fc2a83..744399a47 100644 --- a/tests/materials/test_tasks.py +++ b/tests/materials/test_tasks.py @@ -1,6 +1,6 @@ import os from core_function import client_search_testing - +from datetime import datetime import pytest from mp_api.client.routes.materials.tasks import TaskRester @@ -34,6 +34,7 @@ def rester(): custom_field_tests = { "chemsys": "Si-O", + "last_updated": (None, datetime.utcnow()), "task_ids": ["mp-149"], } # type: dict From 893027c32e6bb82f08bc1147cbab4b379762bbbf Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 13 Feb 2024 17:08:18 -0800 Subject: [PATCH 34/34] Linting --- mp_api/client/mprester.py | 10 +++++----- mp_api/client/routes/materials/tasks.py | 6 ++++-- mp_api/client/routes/materials/thermo.py | 2 +- tests/materials/test_surface_properties.py | 5 ++++- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index f19f4a3e4..c357e2667 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -942,7 +942,9 @@ def get_ion_reference_data(self) -> list[dict]: query={"project": "ion_ref_data"}, fields=["identifier", "formula", "data"], paginate=True, - ).get("data") # type: ignore + ).get( + "data" + ) # type: ignore def get_ion_reference_data_for_chemsys(self, chemsys: str | list) -> list[dict]: """Download aqueous ion reference data used in the construction of Pourbaix diagrams. @@ -1221,10 +1223,8 @@ def get_bandstructure_by_material_id( Returns: bandstructure (Union[BandStructure, BandStructureSymmLine]): BandStructure or BandStructureSymmLine object """ - return ( - self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore - material_id=material_id, path_type=path_type, line_mode=line_mode - ) + return self.electronic_structure_bandstructure.get_bandstructure_from_material_id( # type: ignore + material_id=material_id, path_type=path_type, line_mode=line_mode ) def get_dos_by_material_id(self, material_id: str): diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 78ca1d63f..c7b860d6e 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -23,8 +23,10 @@ def get_trajectory(self, task_id): """ traj_data = self._query_resource_data( - {"task_ids": [task_id]}, suburl=f"trajectory/", use_document_model=False - )[0].get("trajectories", None) # type: ignore + {"task_ids": [task_id]}, suburl="trajectory/", use_document_model=False + )[0].get( + "trajectories", None + ) # type: ignore if traj_data is None: raise MPRestError(f"No trajectory data for {task_id} found") diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index 00ec686d5..b1b10beee 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -168,7 +168,7 @@ def get_phase_diagram_from_chemsys( response = self._query_resource( criteria={"phase_diagram_ids": phase_diagram_id}, fields=["phase_diagram"], - suburl=f"phase_diagram/", + suburl="phase_diagram/", use_document_model=False, num_chunks=1, chunk_size=1, diff --git a/tests/materials/test_surface_properties.py b/tests/materials/test_surface_properties.py index a4e72cb03..c34f8bcd1 100644 --- a/tests/materials/test_surface_properties.py +++ b/tests/materials/test_surface_properties.py @@ -23,7 +23,10 @@ def rester(): sub_doc_fields = [] # type: list -alt_name_dict = {"surface_energy_anisotropy": "surface_anisotropy", "material_ids": "material_id"} # type: dict +alt_name_dict = { + "surface_energy_anisotropy": "surface_anisotropy", + "material_ids": "material_id", +} # type: dict custom_field_tests = {"material_ids": ["mp-23152"]} # type: dict