Skip to content

Commit

Permalink
Merge pull request #1248 from expectedparrot/fix_remote_inference
Browse files Browse the repository at this point in the history
Auth flow, model endpoint, and results visibility
  • Loading branch information
apostolosfilippas authored Nov 8, 2024
2 parents 4036a72 + ce21ce4 commit 85cd75f
Show file tree
Hide file tree
Showing 13 changed files with 402 additions and 59 deletions.
14 changes: 11 additions & 3 deletions edsl/conjure/AgentConstructionMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def to_results(
sample_size: int = None,
seed: str = "edsl",
dryrun=False,
disable_remote_cache: bool = False,
disable_remote_inference: bool = False,
) -> Union[Results, None]:
"""Return the results of the survey.
Expand All @@ -109,7 +111,7 @@ def to_results(
>>> from edsl.conjure.InputData import InputDataABC
>>> id = InputDataABC.example()
>>> r = id.to_results()
>>> r = id.to_results(disable_remote_cache = True, disable_remote_inference = True)
>>> len(r) == id.num_observations
True
"""
Expand All @@ -125,7 +127,10 @@ def to_results(
import time

start = time.time()
_ = survey.by(agent_list.sample(DRYRUN_SAMPLE)).run()
_ = survey.by(agent_list.sample(DRYRUN_SAMPLE)).run(
disable_remote_cache=disable_remote_cache,
disable_remote_inference=disable_remote_inference,
)
end = time.time()
print(f"Time to run {DRYRUN_SAMPLE} agents (s): {round(end - start, 2)}")
time_per_agent = (end - start) / DRYRUN_SAMPLE
Expand All @@ -143,7 +148,10 @@ def to_results(
f"Full sample will take about {round(full_sample_time / 3600, 2)} hours."
)
return None
return survey.by(agent_list).run()
return survey.by(agent_list).run(
disable_remote_cache=disable_remote_cache,
disable_remote_inference=disable_remote_inference,
)


if __name__ == "__main__":
Expand Down
97 changes: 96 additions & 1 deletion edsl/coop/coop.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,77 @@ def _resolve_server_response(self, response: requests.Response) -> None:
if response.status_code >= 400:
message = response.json().get("detail")
# print(response.text)
if "Authorization" in message:
if "The API key you provided is invalid" in message:
import secrets
from edsl.utilities.utilities import write_api_key_to_env

edsl_auth_token = secrets.token_urlsafe(16)

print("Your Expected Parrot API key is invalid.")
print(
"\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
)
print(
f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}\n"
)
api_key = self._poll_for_api_key(edsl_auth_token)

if api_key is None:
print("\nTimed out waiting for login. Please try again.")
return

write_api_key_to_env(api_key)
print("\n✨ API key retrieved and written to .env file.")
print("Rerun your code to try again with a valid API key.")
return

elif "Authorization" in message:
print(message)
message = "Please provide an Expected Parrot API key."

raise CoopServerResponseError(message)

def _poll_for_api_key(
    self, edsl_auth_token: str, timeout: int = 120
) -> Union[str, None]:
    """
    Allows the user to retrieve their Expected Parrot API key by logging in with an EDSL auth token.

    Repeatedly asks the server whether the login associated with
    ``edsl_auth_token`` has completed, showing a terminal spinner between
    checks. Returns the API key string on success, or ``None`` if the
    user does not log in within ``timeout`` seconds.

    :param edsl_auth_token: The EDSL auth token to use for login
    :param timeout: Maximum time to wait for login, in seconds (default: 120)
    """
    import time
    from datetime import datetime

    start_poll_time = time.time()
    waiting_for_login = True
    while waiting_for_login:

        elapsed_time = time.time() - start_poll_time
        if elapsed_time > timeout:
            # Timed out waiting for the user to log in
            # Overwrite the spinner line with spaces so the terminal is clean.
            print("\r" + " " * 80 + "\r", end="")
            return None

        api_key = self._get_api_key(edsl_auth_token)
        if api_key is not None:
            # Login succeeded: clear the spinner line and hand back the key.
            print("\r" + " " * 80 + "\r", end="")
            return api_key
        else:
            # No key yet: animate a spinner for `duration` seconds before
            # polling the server again (one server request per ~5s).
            duration = 5
            time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
            frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
            start_time = time.time()
            i = 0
            while time.time() - start_time < duration:
                # `\r` redraws in place; flush forces each frame to appear
                # immediately despite stdout buffering.
                print(
                    f"\r{frames[i % len(frames)]} Waiting for login. Last checked: {time_checked}",
                    end="",
                    flush=True,
                )
                time.sleep(0.1)
                i += 1

def _json_handle_none(self, value: Any) -> Any:
"""
Handle None values during JSON serialization.
Expand Down Expand Up @@ -489,6 +555,7 @@ def remote_inference_create(
description: Optional[str] = None,
status: RemoteJobStatus = "queued",
visibility: Optional[VisibilityType] = "unlisted",
initial_results_visibility: Optional[VisibilityType] = "unlisted",
iterations: Optional[int] = 1,
) -> dict:
"""
Expand Down Expand Up @@ -517,6 +584,7 @@ def remote_inference_create(
"iterations": iterations,
"visibility": visibility,
"version": self._edsl_version,
"initial_results_visibility": initial_results_visibility,
},
)
self._resolve_server_response(response)
Expand Down Expand Up @@ -664,6 +732,17 @@ def fetch_prices(self) -> dict:
else:
return {}

def fetch_models(self) -> dict:
    """
    Retrieve the models currently available on Coop.

    Returns a dict mapping each inference service name to the list of
    models offered by that service.
    """
    server_response = self._send_server_request(uri="api/v0/models", method="GET")
    self._resolve_server_response(server_response)
    return server_response.json()

def fetch_rate_limit_config_vars(self) -> dict:
"""
Fetch a dict of rate limit config vars from Coop.
Expand All @@ -678,6 +757,22 @@ def fetch_rate_limit_config_vars(self) -> dict:
data = response.json()
return data

def _get_api_key(self, edsl_auth_token: str):
    """
    Given an EDSL auth token, find the corresponding user's API key.

    Returns the API key string, or None when the server response carries
    no "api_key" field (i.e. the login has not completed yet).
    """
    server_response = self._send_server_request(
        uri="api/v0/get-api-key",
        method="POST",
        payload={"edsl_auth_token": edsl_auth_token},
    )
    return server_response.json().get("api_key")


if __name__ == "__main__":
sheet_data = fetch_sheet_data()
Expand Down
2 changes: 1 addition & 1 deletion edsl/data/Cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def store(
>>> c = Cache()
>>> len(c)
0
>>> results = Question.example("free_text").by(m).run(cache = c)
>>> results = Question.example("free_text").by(m).run(cache = c, disable_remote_cache = True, disable_remote_inference = True)
>>> len(c)
1
"""
Expand Down
41 changes: 32 additions & 9 deletions edsl/inference_services/InferenceServicesCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,48 @@ def add_model(cls, service_name, model_name):

@staticmethod
def _get_service_available(service, warn: bool = False) -> list[str]:
from_api = True
try:
service_models = service.available()
except Exception as e:
except Exception:
if warn:
warnings.warn(
f"""Error getting models for {service._inference_service_}.
Check that you have properly stored your Expected Parrot API key and activated remote inference, or stored your own API keys for the language models that you want to use.
See https://docs.expectedparrot.com/en/latest/api_keys.html for instructions on storing API keys.
Relying on cache.""",
Relying on Coop.""",
UserWarning,
)
from edsl.inference_services.models_available_cache import models_available

service_models = models_available.get(service._inference_service_, [])
# cache results
service._models_list_cache = service_models
from_api = False
return service_models # , from_api
# Use the list of models on Coop as a fallback
try:
from edsl import Coop

c = Coop()
models_from_coop = c.fetch_models()
service_models = models_from_coop.get(service._inference_service_, [])

# cache results
service._models_list_cache = service_models

# Finally, use the available models cache from the Python file
except Exception:
if warn:
warnings.warn(
f"""Error getting models for {service._inference_service_}.
Relying on EDSL cache.""",
UserWarning,
)

from edsl.inference_services.models_available_cache import (
models_available,
)

service_models = models_available.get(service._inference_service_, [])

# cache results
service._models_list_cache = service_models

return service_models

def available(self):
total_models = []
Expand Down
Loading

0 comments on commit 85cd75f

Please sign in to comment.