Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPRester lazily get endpoint and api_key #936

Merged
merged 25 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b076a8e
get env var lazily
DanielYang59 Sep 28, 2024
dee31a0
remove seemingly unused deprecation warn
DanielYang59 Sep 28, 2024
48fbc54
remove unused ignore tag
DanielYang59 Sep 28, 2024
b7b3d19
fix URL case
DanielYang59 Sep 28, 2024
1214f2e
access self.endpoint for updated entry
DanielYang59 Sep 28, 2024
c1fa506
only patch api key env var in CI env
DanielYang59 Sep 28, 2024
b11178c
add unit test for lazy mp api key
DanielYang59 Sep 28, 2024
d0afe41
remove skip decorator
DanielYang59 Sep 28, 2024
5540253
also check endpoint
DanielYang59 Sep 28, 2024
479a9ba
add more tests for endpoint
DanielYang59 Sep 28, 2024
c31ef7d
don't patch api_key and recover skip mark
DanielYang59 Sep 28, 2024
b3ff75a
avoid duplicate endpoint
DanielYang59 Sep 28, 2024
f740843
also test default and invalid api key
DanielYang59 Sep 28, 2024
2232e5a
os.environ.get -> os.getenv
DanielYang59 Sep 30, 2024
5648026
BaseRester also get lazily
DanielYang59 Sep 30, 2024
1b7510d
make sure self.endpoint is set
tschaume Oct 2, 2024
7edcc7f
remove duplicated pytest skip mark
DanielYang59 Oct 2, 2024
e77077b
turn off fail-fast
DanielYang59 Oct 2, 2024
38b149e
NEED CONFIRM: filter get_data_by_id deprecation warning
DanielYang59 Oct 2, 2024
0f55062
cleanup
tschaume Oct 2, 2024
b466c31
linting
tschaume Oct 2, 2024
284b18d
try upper casing secrets name
tschaume Oct 2, 2024
66f2355
use tr for uppercase
tschaume Oct 2, 2024
6435643
test api key in header
tschaume Oct 2, 2024
63dd563
Revert "test api key in header"
tschaume Oct 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ jobs:
- name: Format API key name (Linux/MacOS)
if: matrix.os == 'ubuntu-latest' || matrix.os == 'macos-latest'
run: |
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}')" >> $GITHUB_ENV
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" >> $GITHUB_ENV

- name: Format API key name (Windows)
if: matrix.os == 'windows-latest'
run: |
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
echo "API_KEY_NAME=$(echo ${{ format('MP_API_KEY_{0}_{1}', matrix.os, matrix.python-version) }} | awk '{gsub(/-|\./, "_"); print}' | tr '[:lower:]' '[:upper:]')" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append

- name: Test with pytest
env:
Expand Down
24 changes: 12 additions & 12 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from importlib.metadata import PackageNotFoundError, version
from json import JSONDecodeError
from math import ceil
from typing import Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, TypeVar
from urllib.parse import quote, urljoin

import requests
Expand All @@ -42,18 +42,16 @@
except ImportError:
boto3 = None

if TYPE_CHECKING:
from typing import Any, Callable

try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")

# TODO: think about how to migrate from PMG_MAPI_KEY
DEFAULT_API_KEY = os.environ.get("MP_API_KEY", None)
DEFAULT_ENDPOINT = os.environ.get(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)

settings = MAPIClientSettings() # type: ignore
SETTINGS = MAPIClientSettings() # type: ignore

T = TypeVar("T")

Expand All @@ -69,7 +67,7 @@ class BaseRester(Generic[T]):
def __init__(
self,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
endpoint: str | None = None,
include_user_agent: bool = True,
session: requests.Session | None = None,
s3_client: Any | None = None,
Expand All @@ -78,7 +76,7 @@ def __init__(
use_document_model: bool = True,
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = settings.MUTE_PROGRESS_BARS,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
):
"""Initialize the REST API helper class.

Expand Down Expand Up @@ -111,9 +109,11 @@ def __init__(
headers: Custom headers for localhost connections.
mute_progress_bars: Whether to disable progress bars.
"""
self.api_key = api_key or DEFAULT_API_KEY
self.base_endpoint = endpoint
self.endpoint = endpoint
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
self.base_endpoint = self.endpoint = endpoint or os.getenv(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
self.debug = debug
self.include_user_agent = include_user_agent
self.monty_decode = monty_decode
Expand Down
35 changes: 15 additions & 20 deletions mp_api/client/mprester.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from functools import cache, lru_cache
from json import loads
from typing import Literal
from typing import TYPE_CHECKING

from emmet.core.electronic_structure import BSPathType
from emmet.core.mpid import MPID
Expand Down Expand Up @@ -60,19 +60,12 @@
from mp_api.client.routes.materials.materials import MaterialsRester
from mp_api.client.routes.molecules import MoleculeRester

_DEPRECATION_WARNING = (
tschaume marked this conversation as resolved.
Show resolved Hide resolved
"MPRester is being modernized. Please use the new method suggested and "
"read more about these changes at https://docs.materialsproject.org/api. The current "
"methods will be retained until at least January 2022 for backwards compatibility."
)
if TYPE_CHECKING:
from typing import Literal

_EMMET_SETTINGS = EmmetSettings() # type: ignore
_MAPI_SETTINGS = MAPIClientSettings() # typeL ignore # type: ignore

DEFAULT_API_KEY = os.environ.get("MP_API_KEY", None)
DEFAULT_ENDPOINT = os.environ.get(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
_EMMET_SETTINGS = EmmetSettings()
_MAPI_SETTINGS = MAPIClientSettings()


class MPRester:
Expand Down Expand Up @@ -124,7 +117,7 @@ class MPRester:
def __init__(
self,
api_key: str | None = None,
endpoint: str = DEFAULT_ENDPOINT,
endpoint: str | None = None,
notify_db_version: bool = False,
include_user_agent: bool = True,
monty_decode: bool = True,
Expand All @@ -143,10 +136,10 @@ def __init__(
If so, it will use that environment variable. This makes
easier for heavy users to simply add this environment variable to
their setups and MPRester can then be called without any arguments.
endpoint (str): Url of endpoint to access the MaterialsProject REST
endpoint (str): URL of endpoint to access the MaterialsProject REST
interface. Defaults to the standard Materials Project REST
address at "https://api.materialsproject.org", but
can be changed to other urls implementing a similar interface.
can be changed to other URLs implementing a similar interface.
notify_db_version (bool): If True, the current MP database version will
be retrieved and logged locally in the ~/.mprester.log.yaml. If the database
version changes, you will be notified. The current database version is
Expand All @@ -169,7 +162,7 @@ def __init__(

"""
# SETTINGS tries to read API key from ~/.config/.pmgrc.yaml
api_key = api_key or DEFAULT_API_KEY or SETTINGS.get("PMG_MAPI_KEY")
api_key = api_key or os.getenv("MP_API_KEY") or SETTINGS.get("PMG_MAPI_KEY")

if api_key and len(api_key) != 32:
raise ValueError(
Expand All @@ -179,7 +172,9 @@ def __init__(
)

self.api_key = api_key
self.endpoint = endpoint
self.endpoint = endpoint or os.getenv(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
self.headers = headers or {}
self.session = session or BaseRester._create_session(
api_key=self.api_key,
Expand Down Expand Up @@ -257,7 +252,7 @@ def __init__(
core_resters = {
cls.suffix.split("/")[0]: cls(
api_key=api_key,
endpoint=endpoint,
endpoint=self.endpoint,
include_user_agent=include_user_agent,
session=self.session,
monty_decode=monty_decode,
Expand All @@ -280,7 +275,7 @@ def __init__(
if len(suffix_split) == 1:
rester = cls(
api_key=api_key,
endpoint=endpoint,
endpoint=self.endpoint,
include_user_agent=include_user_agent,
session=self.session,
monty_decode=monty_decode
Expand Down Expand Up @@ -310,7 +305,7 @@ def __core_custom_getattr(_self, _attr, _rester_map):
cls = _rester_map[_attr]
rester = cls(
api_key=api_key,
endpoint=endpoint,
endpoint=self.endpoint,
include_user_agent=include_user_agent,
session=self.session,
monty_decode=monty_decode
Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client():
search_method = SummaryRester().search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/materials/test_xas.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def rester():


@pytest.mark.skip(reason="Temp skip until timeout update.")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/molecules/test_jcesr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def rester():
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand Down
4 changes: 1 addition & 3 deletions tests/molecules/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@


@pytest.mark.skip(reason="Temporary until data adjustments")
@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
def test_client():
search_method = MoleculesSummaryRester().search

Expand Down
14 changes: 4 additions & 10 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings

import pytest

Expand Down Expand Up @@ -49,9 +50,7 @@
]


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.getenv("MP_API_KEY") is None, reason="No API key found.")
@pytest.mark.parametrize("rester", resters_to_test)
def test_generic_get_methods(rester):
# -- Test generic search and get_data_by_id methods
Expand All @@ -61,9 +60,8 @@ def test_generic_get_methods(rester):
endpoint=mpr.endpoint,
include_user_agent=True,
session=mpr.session,
monty_decode=True
if rester not in [TaskRester, ProvenanceRester] # type: ignore
else False, # Disable monty decode on nested data which may give errors
# Disable monty decode on nested data which may give errors
monty_decode=rester not in [TaskRester, ProvenanceRester],
use_document_model=True,
)

Expand All @@ -85,7 +83,3 @@ def test_generic_get_methods(rester):
key_only_resters[name], fields=[rester.primary_key]
)
assert isinstance(doc, rester.document_model)


if os.getenv("MP_API_KEY", None) is None:
pytest.mark.skip(test_generic_get_methods)
Loading
Loading