generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
Dec 8, 2023
1 parent
ef919cb
commit a0c4d4f
Showing
7 changed files
with
340 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,3 @@ | ||
DEBUG:jaxlib.mlir._mlir_libs:Initializing MLIR with module: _site_initialize_0 | ||
Initializing MLIR with module: _site_initialize_0 | ||
DEBUG:jaxlib.mlir._mlir_libs:Registering dialects from initializer <module 'jaxlib.mlir._mlir_libs._site_initialize_0' from '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/jaxlib/mlir/_mlir_libs/_site_initialize_0.so'> | ||
Registering dialects from initializer <module 'jaxlib.mlir._mlir_libs._site_initialize_0' from '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/jaxlib/mlir/_mlir_libs/_site_initialize_0.so'> | ||
DEBUG:jax._src.path:etils.epath was not found. Using pathlib for file I/O. | ||
etils.epath was not found. Using pathlib for file I/O. | ||
INFO:swarms_cloud.main:Registering POST endpoint at /agent | ||
Registering POST endpoint at /agent | ||
[32mINFO[0m: Started server process [[36m28866[0m] | ||
[32mINFO[0m: Waiting for application startup. | ||
[32mINFO[0m: Application startup complete. | ||
[32mINFO[0m: Uvicorn running on [1mhttp://0.0.0.0:8000[0m (Press CTRL+C to quit) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import requests | ||
import json | ||
import os | ||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
|
||
# Get the API key from the environment | ||
api_key = os.getenv("OPENAI_API_KEY") | ||
|
||
# Define the API endpoint | ||
url = "http://localhost:8000/agent" | ||
|
||
# Define the input parameters for the agent | ||
agent_parameters = { | ||
"temperature": 0.5, | ||
"model_name": "gpt-4", | ||
"openai_api_key": api_key, | ||
"max_loops": 1, | ||
"autosave": True, | ||
"dashboard": True, | ||
} | ||
|
||
# Define the task for the agent | ||
task = "Generate a 10,000 word blog on health and wellness." | ||
|
||
# Define the payload for the POST request | ||
payload = { | ||
"task": task, | ||
"parameters": agent_parameters, | ||
"args": [], # Add your args here | ||
"kwargs": {}, # Add your kwargs here | ||
} | ||
|
||
# Send the POST request to the API | ||
response = requests.post(url, data=json.dumps(payload)) | ||
|
||
# Print the response from the API | ||
print(response.json()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,15 +3,15 @@ requires = ["poetry-core>=1.0.0"] | |
build-backend = "poetry.core.masonry.api" | ||
|
||
[tool.poetry] | ||
name = "paper" | ||
name = "swarms-cloud" | ||
version = "0.0.1" | ||
description = "Paper - Pytorch" | ||
description = "Swarms Cloud - Pytorch" | ||
license = "MIT" | ||
authors = ["Kye Gomez <[email protected]>"] | ||
homepage = "https://github.com/kyegomez/paper" | ||
documentation = "" # Add this if you have documentation. | ||
homepage = "https://github.com/kyegomez/swarms-cloud" | ||
documentation = "https://github.com/kyegomez/swarms-cloud" # Add this if you have documentation. | ||
readme = "README.md" # Assuming you have a README.md | ||
repository = "https://github.com/kyegomez/paper" | ||
repository = "https://github.com/kyegomez/swarms-cloud" | ||
keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"] | ||
classifiers = [ | ||
"Development Status :: 4 - Beta", | ||
|
@@ -23,6 +23,10 @@ classifiers = [ | |
|
||
[tool.poetry.dependencies] | ||
python = "^3.9" | ||
swarms = "*" | ||
fastapi = "*" | ||
skypilot = "*" | ||
|
||
|
||
|
||
[tool.poetry.dev-dependencies] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
swarms | ||
sky | ||
skypilot | ||
fastapi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,109 @@ | ||
import logging | ||
from typing import Callable, Type, Optional | ||
import pytest | ||
from unittest.mock import MagicMock | ||
|
||
from fastapi import FastAPI, HTTPException | ||
from swarms.structs.agent import Agent | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
def agent_api_wrapper( | ||
agent_class: Type[Agent], | ||
app: FastAPI, | ||
path: Optional[str] = None, | ||
http_method: Optional[str] = "get", | ||
logging: bool = False, | ||
*args, | ||
**kwargs, | ||
): | ||
"""Expose agent methods as API endpoints | ||
Args: | ||
agent_class (Type[Agent]): _description_ | ||
app (FastAPI): _description_ | ||
path (str): _description_ | ||
http_method (str, optional): _description_. Defaults to "get". | ||
Example: | ||
>>> from swarms.agents import Agent | ||
>>> from fastapi import FastAPI | ||
>>> app = FastAPI() | ||
>>> @agent_api_wrapper(Agent, app, "/agent") | ||
... def agent_method(): | ||
... return "Hello World" | ||
""" | ||
|
||
def decorator(func: Callable): | ||
async def endpoint_wrapper(*args, **kwargs): | ||
try: | ||
logger.info( | ||
f"Creating instance of {agent_class.__name__} with args: {args} and kwargs: {kwargs}" | ||
) | ||
agent_instance = agent_class(*args, **kwargs) | ||
logger.info(f"Calling method {func.__name__} of {agent_class.__name__}") | ||
result = getattr(agent_instance, func.__name__)() | ||
logger.info( | ||
f"Method {func.__name__} of {agent_class.__name__} returned: {result}" | ||
) | ||
return result | ||
except Exception as error: | ||
logger.error(f"An error occurred: {str(error)}") | ||
raise HTTPException(status_code=500, detail=str(error)) | ||
|
||
if http_method.lower() == "get": | ||
logger.info(f"Registering GET endpoint at {path}") | ||
app.get(path)(endpoint_wrapper) | ||
elif http_method.lower() == "post": | ||
logger.info(f"Registering POST endpoint at {path}") | ||
app.post(path)(endpoint_wrapper) | ||
return func | ||
|
||
return decorator | ||
from swarms_cloud.main import agent_api_wrapper | ||
|
||
|
||
class MockAgent(Agent): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.method_called = False | ||
|
||
def mock_method(self): | ||
self.method_called = True | ||
return "Hello World" | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"http_method", | ||
["get", "post"], | ||
) | ||
def test_decorator_registers_endpoint(http_method): | ||
# Arrange | ||
app = FastAPI() | ||
path = "/agent" | ||
|
||
# Act | ||
@agent_api_wrapper(MockAgent, app, path, http_method=http_method) | ||
def mock_method(): | ||
return "Hello World" | ||
|
||
# Assert | ||
if http_method == "get": | ||
assert "GET" in app.routes | ||
else: | ||
assert "POST" in app.routes | ||
|
||
assert path in app.routes | ||
|
||
|
||
def test_decorator_logs_correctly(): | ||
# Arrange | ||
app = FastAPI() | ||
path = "/agent" | ||
logger = MagicMock() | ||
|
||
# Act | ||
@agent_api_wrapper(MockAgent, app, path, logging=True) | ||
def mock_method(): | ||
return "Hello World" | ||
|
||
# Assert | ||
assert logger.info.call_count == 3 | ||
assert logger.error.call_count == 0 | ||
|
||
|
||
def test_endpoint_wrapper_calls_agent_method(): | ||
# Arrange | ||
app = FastAPI() | ||
path = "/agent" | ||
agent_instance = MockAgent() | ||
|
||
# Act | ||
@agent_api_wrapper(MockAgent, app, path) | ||
def mock_method(): | ||
return "Hello World" | ||
|
||
# Assert | ||
assert agent_instance.mock_method.called | ||
|
||
|
||
def test_endpoint_wrapper_returns_result(): | ||
# Arrange | ||
app = FastAPI() | ||
path = "/agent" | ||
agent_instance = MockAgent() | ||
expected_result = "Hello World" | ||
|
||
# Act | ||
@agent_api_wrapper(MockAgent, app, path) | ||
def mock_method(): | ||
return expected_result | ||
|
||
# Assert | ||
response = app.get(path) | ||
assert response.status_code == 200 | ||
assert response.json() == expected_result | ||
|
||
|
||
def test_endpoint_wrapper_raises_exception(): | ||
# Arrange | ||
app = FastAPI() | ||
path = "/agent" | ||
agent_instance = MockAgent() | ||
expected_error = Exception("Test error") | ||
|
||
# Act | ||
@agent_api_wrapper(MockAgent, app, path) | ||
def mock_method(): | ||
raise expected_error | ||
|
||
# Assert | ||
with pytest.raises(HTTPException) as error: | ||
app.get(path) | ||
|
||
assert error.value.status_code == 500 | ||
assert error.value.detail == str(expected_error) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import pytest | ||
from fastapi import FastAPI | ||
from fastapi.testclient import TestClient | ||
from swarms.structs.agent import Agent | ||
from swarms_cloud.main import agent_api_wrapper | ||
|
||
|
||
@pytest.fixture | ||
def app(): | ||
return FastAPI() | ||
|
||
|
||
@pytest.fixture | ||
def client(app): | ||
return TestClient(app) | ||
|
||
|
||
def test_get_endpoint_registration(app): | ||
@agent_api_wrapper(Agent, app, path="/test-get", http_method="get") | ||
def dummy_method(): | ||
pass | ||
|
||
assert "/test-get" in [ | ||
route.path for route in app.routes if route.methods == {"GET"} | ||
] | ||
|
||
|
||
def test_post_endpoint_registration(app): | ||
@agent_api_wrapper(Agent, app, path="/test-post", http_method="post") | ||
def dummy_method(): | ||
pass | ||
|
||
assert "/test-post" in [ | ||
route.path for route in app.routes if route.methods == {"POST"} | ||
] | ||
|
||
|
||
def test_agent_instantiation_with_args(app, monkeypatch): | ||
test_args = ("arg1", "arg2") | ||
|
||
def mock_init(self, *args, **kwargs): | ||
assert args == test_args | ||
|
||
monkeypatch.setattr(Agent, "__init__", mock_init) | ||
|
||
@agent_api_wrapper(Agent, app, path="/test-args", http_method="get") | ||
def dummy_method(): | ||
pass | ||
|
||
client(app).get("/test-args", params={"args": test_args}) | ||
|
||
|
||
def test_successful_method_execution(app, monkeypatch): | ||
def mock_method(self): | ||
return "success" | ||
|
||
monkeypatch.setattr(Agent, "dummy_method", mock_method) | ||
|
||
@agent_api_wrapper(Agent, app, path="/test-success", http_method="get") | ||
def dummy_method(): | ||
pass | ||
|
||
response = client(app).get("/test-success") | ||
assert response.text == "success" | ||
|
||
|
||
def test_http_exception_on_failure(app, monkeypatch): | ||
def mock_method(self): | ||
raise ValueError("error") | ||
|
||
monkeypatch.setattr(Agent, "dummy_method", mock_method) | ||
|
||
@agent_api_wrapper(Agent, app, path="/test-exception", http_method="get") | ||
def dummy_method(): | ||
pass | ||
|
||
response = client(app).get("/test-exception") | ||
assert response.status_code == 500 | ||
assert "error" in response.text |
Oops, something went wrong.