Skip to content

Commit

Permalink
Use FutureOutput for lists of inputs of length 1. Upgraded typing syn…
Browse files Browse the repository at this point in the history
…tax to python 3.9. Removed TeraChem frontend keywords from settings (no longer needed).
  • Loading branch information
coltonbh committed Sep 12, 2024
1 parent 60e38f5 commit 3583934
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 28 deletions.
6 changes: 3 additions & 3 deletions chemcloud/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Optional, Union

from . import __version__
from .config import settings
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
profile=profile,
chemcloud_domain=chemcloud_domain,
)
self._openapi_spec: Optional[Dict] = None
self._openapi_spec: Optional[dict] = None

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -67,7 +67,7 @@ def profile(self) -> str:
return self._client._profile

@property
def supported_programs(self) -> List[str]:
def supported_programs(self) -> list[str]:
"""Compute programs currently supported by ChemCloud.
Returns:
Expand Down
3 changes: 1 addition & 2 deletions chemcloud/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations # Can remove once we drop Python 3.9 support
from __future__ import annotations

from pathlib import Path
from typing import Optional
Expand All @@ -19,7 +19,6 @@ class Settings(BaseSettings):
chemcloud_domain: str = "https://chemcloud.mtzlab.com"
chemcloud_api_version_prefix: str = "/api/v2"
chemcloud_credentials_profile: str = "default"
tcfe_keywords: str = "tcfe:keywords"


settings = Settings()
20 changes: 10 additions & 10 deletions chemcloud/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from getpass import getpass
from pathlib import Path
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Optional, Union

import httpx

Expand Down Expand Up @@ -172,9 +172,9 @@ def _request(
method: str,
route: str,
*,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
data: Optional[dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
api_call: bool = True,
):
"""Make HTTP request"""
Expand Down Expand Up @@ -208,7 +208,7 @@ def _authenticated_request(self, method: str, route: str, **kwargs):

def _tokens_from_username_password(
self, username: str, password: str
) -> Tuple[str, str]:
) -> tuple[str, str]:
"""Exchanges username/password for access_token and refresh_token"""
data = {
"grant_type": "password",
Expand All @@ -227,7 +227,7 @@ def _tokens_from_username_password(

return response["access_token"], response["refresh_token"]

def _refresh_tokens(self, refresh_token: str) -> Tuple[str, str]:
def _refresh_tokens(self, refresh_token: str) -> tuple[str, str]:
"""Get new access and refresh tokens."""
data = {"grant_type": "refresh_token", "refresh_token": refresh_token}
headers = {"content-type": "application/x-www-form-urlencoded"}
Expand Down Expand Up @@ -260,7 +260,7 @@ def _expired_access_token(self, jwt: str) -> bool:
)

@staticmethod
def _decode_access_token(jwt: str) -> Dict[str, Any]:
def _decode_access_token(jwt: str) -> dict[str, Any]:
"""Decode jwt string and return dictionary of payload claims."""
payload = jwt.split(".")[1]
encoded_payload = payload.encode("ascii")
Expand All @@ -275,14 +275,14 @@ def _decode_access_token(jwt: str) -> Dict[str, Any]:
return json.loads(json_string)

def _result_id_to_future_result(self, input_data, result_id):
if isinstance(input_data, list):
if isinstance(input_data, list) and len(input_data) > 1:
return FutureOutputGroup(task_id=result_id, client=self)
return FutureOutput(task_id=result_id, client=self)

def compute(
self,
inp_obj: QCIOInputsOrList,
params: Optional[Dict[str, Any]] = None,
params: Optional[dict[str, Any]] = None,
) -> Union[FutureOutput, FutureOutputGroup]:
"""Submit a computation to ChemCloud"""
result_id = self._authenticated_request(
Expand All @@ -296,7 +296,7 @@ def compute(
def output(
self,
task_id: str,
) -> Tuple[str, Union[Optional[Any], Optional[List[Any]]]]:
) -> tuple[str, Union[Optional[Any], Optional[list[Any]]]]:
"""Check the output of a compute job, returns status and output (if available).
Parameters:
Expand Down
16 changes: 8 additions & 8 deletions chemcloud/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from pathlib import Path
from time import sleep, time
from typing import Any, List, Optional, Type, Union
from typing import Any, Optional, Union

from pydantic import field_validator
from pydantic.main import BaseModel
Expand All @@ -15,9 +15,9 @@

# Convenience types
QCIOInputs: TypeAlias = Union[ProgramInput, FileInput, DualProgramInput]
QCIOInputsOrList: TypeAlias = Union[QCIOInputs, List[QCIOInputs]]
QCIOInputsOrList: TypeAlias = Union[QCIOInputs, list[QCIOInputs]]
QCIOOutputs: TypeAlias = ProgramOutput
QCIOOutputsOrList: TypeAlias = Union[QCIOOutputs, List[QCIOOutputs]]
QCIOOutputsOrList: TypeAlias = Union[QCIOOutputs, list[QCIOOutputs]]


class TaskStatus(str, Enum):
Expand Down Expand Up @@ -121,7 +121,7 @@ class FutureOutput(FutureOutputBase):
class FutureOutputGroup(FutureOutputBase):
"""Group computation result"""

result: Optional[List[QCIOOutputs]] = None
result: Optional[list[QCIOOutputs]] = None

def _output(self):
"""Return result from server. Remove GROUP_ID_PREFIX from id."""
Expand All @@ -142,7 +142,7 @@ def validate_id(cls, val):


def to_file(
future_results: Union[FutureOutputBase, List[FutureOutputBase]],
future_results: Union[FutureOutputBase, list[FutureOutputBase]],
path: Union[str, Path],
*,
append: bool = False,
Expand All @@ -164,18 +164,18 @@ def to_file(
def from_file(
path: Union[str, Path],
client: Any,
) -> List[Union[FutureOutput, FutureOutputGroup]]:
) -> list[Union[FutureOutput, FutureOutputGroup]]:
"""Instantiate FutureOutputs or FutureOutputGroups from file of result ids
Params:
path: Path to file containing the ids
client: Instantiated CCClient object
"""
frs: List[Union[FutureOutput, FutureOutputGroup]] = []
frs: list[Union[FutureOutput, FutureOutputGroup]] = []
with open(path) as f:
for id in f.readlines():
id = id.strip()
model: Union[Type[FutureOutput], Type[FutureOutputGroup]]
model: Union[type[FutureOutput], type[FutureOutputGroup]]
if id.startswith(GROUP_ID_PREFIX):
model = FutureOutputGroup
else:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ target-version = "py39"

[tool.ruff.lint]
isort = { known-first-party = ["tests"] }
select = ["I"]
select = ["I", "F401"]

[tool.coverage.run]
branch = true
Expand Down
4 changes: 2 additions & 2 deletions scripts/release.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
def get_repo_url():
"""Get the repository URL from pyproject.toml or ask the user for it."""
try:
with open("pyproject.toml", "r") as file:
with open("pyproject.toml") as file:
pyproject = toml.load(file)
repo_url = pyproject["tool"]["poetry"]["repository"]
return repo_url
Expand All @@ -25,7 +25,7 @@ def update_version_with_poetry(version):
def update_changelog(version, repo_url):
"""Update the CHANGELOG.md file with the new version and today's date."""
print("Updating CHANGELOG.md...")
with open("docs/CHANGELOG.md", "r") as file:
with open("docs/CHANGELOG.md") as file:
lines = file.readlines()

today = datetime.today().strftime("%Y-%m-%d")
Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from base64 import b64encode
from pathlib import Path
from time import time
from typing import Dict

import pytest
import tomli_w
Expand All @@ -13,7 +12,7 @@
from chemcloud.config import Settings


def _jwt_from_payload(payload: Dict[str, str]) -> str:
def _jwt_from_payload(payload: dict[str, str]) -> str:
"""Convert payload to fake JWT"""
b64_encoded_access_token = b64encode(json.dumps(payload).encode("utf-8")).decode(
"utf-8"
Expand Down

0 comments on commit 3583934

Please sign in to comment.