Skip to content

Commit

Permalink
[TESTS for agent_api_wrapper]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 8, 2023
1 parent ef919cb commit a0c4d4f
Show file tree
Hide file tree
Showing 7 changed files with 340 additions and 74 deletions.
9 changes: 0 additions & 9 deletions errors.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
DEBUG:jaxlib.mlir._mlir_libs:Initializing MLIR with module: _site_initialize_0
Initializing MLIR with module: _site_initialize_0
DEBUG:jaxlib.mlir._mlir_libs:Registering dialects from initializer <module 'jaxlib.mlir._mlir_libs._site_initialize_0' from '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/jaxlib/mlir/_mlir_libs/_site_initialize_0.so'>
Registering dialects from initializer <module 'jaxlib.mlir._mlir_libs._site_initialize_0' from '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/jaxlib/mlir/_mlir_libs/_site_initialize_0.so'>
DEBUG:jax._src.path:etils.epath was not found. Using pathlib for file I/O.
etils.epath was not found. Using pathlib for file I/O.
INFO:swarms_cloud.main:Registering POST endpoint at /agent
Registering POST endpoint at /agent
INFO: Started server process [28866]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
40 changes: 40 additions & 0 deletions example_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import requests
import json
import os
from dotenv import load_dotenv

load_dotenv()


# Get the API key from the environment
api_key = os.getenv("OPENAI_API_KEY")

# Define the API endpoint
url = "http://localhost:8000/agent"

# Define the input parameters for the agent
agent_parameters = {
"temperature": 0.5,
"model_name": "gpt-4",
"openai_api_key": api_key,
"max_loops": 1,
"autosave": True,
"dashboard": True,
}

# Define the task for the agent
task = "Generate a 10,000 word blog on health and wellness."

# Define the payload for the POST request
payload = {
"task": task,
"parameters": agent_parameters,
"args": [], # Add your args here
"kwargs": {}, # Add your kwargs here
}

# Send the POST request to the API
response = requests.post(url, data=json.dumps(payload))

# Print the response from the API
print(response.json())
14 changes: 9 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "paper"
name = "swarms-cloud"
version = "0.0.1"
description = "Paper - Pytorch"
description = "Swarms Cloud - Pytorch"
license = "MIT"
authors = ["Kye Gomez <[email protected]>"]
homepage = "https://github.com/kyegomez/paper"
documentation = "" # Add this if you have documentation.
homepage = "https://github.com/kyegomez/swarms-cloud"
documentation = "https://github.com/kyegomez/swarms-cloud" # Add this if you have documentation.
readme = "README.md" # Assuming you have a README.md
repository = "https://github.com/kyegomez/paper"
repository = "https://github.com/kyegomez/swarms-cloud"
keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
classifiers = [
"Development Status :: 4 - Beta",
Expand All @@ -23,6 +23,10 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.9"
swarms = "*"
fastapi = "*"
skypilot = "*"



[tool.poetry.dev-dependencies]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
swarms
sky
skypilot
fastapi
164 changes: 105 additions & 59 deletions swarms_cloud/main.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,109 @@
import logging
from typing import Callable, Type, Optional
import pytest
from unittest.mock import MagicMock

from fastapi import FastAPI, HTTPException
from swarms.structs.agent import Agent

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def agent_api_wrapper(
agent_class: Type[Agent],
app: FastAPI,
path: Optional[str] = None,
http_method: Optional[str] = "get",
logging: bool = False,
*args,
**kwargs,
):
"""Expose agent methods as API endpoints
Args:
agent_class (Type[Agent]): _description_
app (FastAPI): _description_
path (str): _description_
http_method (str, optional): _description_. Defaults to "get".
Example:
>>> from swarms.agents import Agent
>>> from fastapi import FastAPI
>>> app = FastAPI()
>>> @agent_api_wrapper(Agent, app, "/agent")
... def agent_method():
... return "Hello World"
"""

def decorator(func: Callable):
async def endpoint_wrapper(*args, **kwargs):
try:
logger.info(
f"Creating instance of {agent_class.__name__} with args: {args} and kwargs: {kwargs}"
)
agent_instance = agent_class(*args, **kwargs)
logger.info(f"Calling method {func.__name__} of {agent_class.__name__}")
result = getattr(agent_instance, func.__name__)()
logger.info(
f"Method {func.__name__} of {agent_class.__name__} returned: {result}"
)
return result
except Exception as error:
logger.error(f"An error occurred: {str(error)}")
raise HTTPException(status_code=500, detail=str(error))

if http_method.lower() == "get":
logger.info(f"Registering GET endpoint at {path}")
app.get(path)(endpoint_wrapper)
elif http_method.lower() == "post":
logger.info(f"Registering POST endpoint at {path}")
app.post(path)(endpoint_wrapper)
return func

return decorator
from swarms_cloud.main import agent_api_wrapper


class MockAgent(Agent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.method_called = False

def mock_method(self):
self.method_called = True
return "Hello World"


@pytest.mark.parametrize(
"http_method",
["get", "post"],
)
def test_decorator_registers_endpoint(http_method):
# Arrange
app = FastAPI()
path = "/agent"

# Act
@agent_api_wrapper(MockAgent, app, path, http_method=http_method)
def mock_method():
return "Hello World"

# Assert
if http_method == "get":
assert "GET" in app.routes
else:
assert "POST" in app.routes

assert path in app.routes


def test_decorator_logs_correctly():
# Arrange
app = FastAPI()
path = "/agent"
logger = MagicMock()

# Act
@agent_api_wrapper(MockAgent, app, path, logging=True)
def mock_method():
return "Hello World"

# Assert
assert logger.info.call_count == 3
assert logger.error.call_count == 0


def test_endpoint_wrapper_calls_agent_method():
# Arrange
app = FastAPI()
path = "/agent"
agent_instance = MockAgent()

# Act
@agent_api_wrapper(MockAgent, app, path)
def mock_method():
return "Hello World"

# Assert
assert agent_instance.mock_method.called


def test_endpoint_wrapper_returns_result():
# Arrange
app = FastAPI()
path = "/agent"
agent_instance = MockAgent()
expected_result = "Hello World"

# Act
@agent_api_wrapper(MockAgent, app, path)
def mock_method():
return expected_result

# Assert
response = app.get(path)
assert response.status_code == 200
assert response.json() == expected_result


def test_endpoint_wrapper_raises_exception():
# Arrange
app = FastAPI()
path = "/agent"
agent_instance = MockAgent()
expected_error = Exception("Test error")

# Act
@agent_api_wrapper(MockAgent, app, path)
def mock_method():
raise expected_error

# Assert
with pytest.raises(HTTPException) as error:
app.get(path)

assert error.value.status_code == 500
assert error.value.detail == str(expected_error)
79 changes: 79 additions & 0 deletions tests/test_agent_api_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from swarms.structs.agent import Agent
from swarms_cloud.main import agent_api_wrapper


@pytest.fixture
def app():
return FastAPI()


@pytest.fixture
def client(app):
return TestClient(app)


def test_get_endpoint_registration(app):
@agent_api_wrapper(Agent, app, path="/test-get", http_method="get")
def dummy_method():
pass

assert "/test-get" in [
route.path for route in app.routes if route.methods == {"GET"}
]


def test_post_endpoint_registration(app):
@agent_api_wrapper(Agent, app, path="/test-post", http_method="post")
def dummy_method():
pass

assert "/test-post" in [
route.path for route in app.routes if route.methods == {"POST"}
]


def test_agent_instantiation_with_args(app, monkeypatch):
test_args = ("arg1", "arg2")

def mock_init(self, *args, **kwargs):
assert args == test_args

monkeypatch.setattr(Agent, "__init__", mock_init)

@agent_api_wrapper(Agent, app, path="/test-args", http_method="get")
def dummy_method():
pass

client(app).get("/test-args", params={"args": test_args})


def test_successful_method_execution(app, monkeypatch):
def mock_method(self):
return "success"

monkeypatch.setattr(Agent, "dummy_method", mock_method)

@agent_api_wrapper(Agent, app, path="/test-success", http_method="get")
def dummy_method():
pass

response = client(app).get("/test-success")
assert response.text == "success"


def test_http_exception_on_failure(app, monkeypatch):
def mock_method(self):
raise ValueError("error")

monkeypatch.setattr(Agent, "dummy_method", mock_method)

@agent_api_wrapper(Agent, app, path="/test-exception", http_method="get")
def dummy_method():
pass

response = client(app).get("/test-exception")
assert response.status_code == 500
assert "error" in response.text
Loading

0 comments on commit a0c4d4f

Please sign in to comment.