Skip to content

Commit

Permalink
sdk test cases (#36)
Browse files Browse the repository at this point in the history
* sdk test cases 1

* add task queue tests

* Add tests for function

* Add test for function map

* add comment explanation

* change function to use mock_coroutine_with_result

* updated for review

* add github actions

* patch actions

* poetry

* poetry

* forgot ot checkut

* fix ruff

* fix ruff

* fix ruff

* last step

* ignore auth for ci

* ignore auth for ci
  • Loading branch information
jsun-m authored Jan 10, 2024
1 parent eb846f3 commit e56fe42
Show file tree
Hide file tree
Showing 17 changed files with 529 additions and 990 deletions.
55 changes: 55 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
name: CI

on: [push]

defaults:
run:
working-directory: sdk

jobs:
lint_and_test:
runs-on: ubuntu-latest
strategy:
max-parallel: 4

steps:
- name: Check out repository
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"

- name: Set up Poetry
env:
ACTIONS_ALLOW_UNSECURE_COMMANDS: true
uses: snok/install-poetry@v1
with:
version: 1.5.1
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true

- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v3
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies
run: poetry install
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'

- name: Code formatting
run: poetry run ruff format .

# todo: remove exit-zero once ruff issues are fixed
- name: Code linting
run: poetry run ruff check .

- name: Run tests
env:
CI: true
run: make tests
1 change: 1 addition & 0 deletions sdk/src/beam/abstractions/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class BaseAbstraction(ABC):
def __init__(self) -> None:
self.loop: AbstractEventLoop = asyncio.get_event_loop()

self.channel: Channel = get_gateway_channel()

def run_sync(self, coroutine: Coroutine) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion sdk/src/beam/abstractions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def _call_remote(self, *args, **kwargs) -> Any:
if r.output != "":
terminal.detail(r.output)

if r.done:
if r.done or r.exit_code != 0:
last_response = r
break

Expand Down
2 changes: 1 addition & 1 deletion sdk/src/beam/abstractions/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def put(self, value: Any) -> bool:
def pop(self) -> Any:
r = self.run_sync(self.stub.pop(name=self.name))
if not r.ok:
return SimpleQueueInternalServerError
raise SimpleQueueInternalServerError

if len(r.value) > 0:
return cloudpickle.loads(r.value)
Expand Down
9 changes: 4 additions & 5 deletions sdk/src/beam/abstractions/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ def __call__(self, *args, **kwargs) -> Any:
if container_id is not None:
return self.local(*args, **kwargs)

if not self.parent.prepare_runtime(
func=self.func,
stub_type=TASKQUEUE_STUB_TYPE,
):
return
raise NotImplementedError(
"Direct calls to TaskQueues are not yet supported."
+ " To enqueue items use .put(*args, **kwargs)"
)

def local(self, *args, **kwargs) -> Any:
return self.func(*args, **kwargs)
Expand Down
12 changes: 10 additions & 2 deletions sdk/src/beam/cli/configure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import click

from beam import terminal
from beam.config import configure_gateway_credentials, load_config_from_file, save_config_to_file
from beam.config import (
configure_gateway_credentials,
load_config_from_file,
save_config_to_file,
)


@click.command()
Expand All @@ -13,7 +17,11 @@ def configure(name: str, token: str, gateway_host: str, gateway_port: str):
config = load_config_from_file()

config = configure_gateway_credentials(
config, name=name, gateway_host=gateway_host, gateway_port=gateway_port, token=token
config,
name=name,
gateway_host=gateway_host,
gateway_port=gateway_port,
token=token,
)

save_config_to_file(
Expand Down
4 changes: 4 additions & 0 deletions sdk/src/beam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ def configure_gateway_credentials(


def get_gateway_channel() -> Channel:
if os.getenv("CI"):
# Ignore auth for CI
return Channel(host="localhost", port=50051, ssl=False)

config: GatewayConfig = get_gateway_config()
channel: Union[AuthenticatedChannel, None] = None

Expand Down
31 changes: 0 additions & 31 deletions sdk/tests/test_build.py

This file was deleted.

94 changes: 94 additions & 0 deletions sdk/tests/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from unittest import TestCase
from unittest.mock import MagicMock

import cloudpickle

from beam import Image
from beam.abstractions.function import Function
from beam.clients.function import FunctionInvokeResponse


class AsyncIterator:
def __init__(self, seq):
self.iter = iter(seq)

def __aiter__(self):
return self

async def __anext__(self):
try:
return next(self.iter)
except StopIteration:
raise StopAsyncIteration


class TestTaskQueue(TestCase):
def test_init(self):
mock_stub = MagicMock()

queue = Function(Image(python_version="python3.8"), cpu=100, memory=128)
queue.stub = mock_stub

self.assertEqual(queue.image.python_version, "python3.8")
self.assertEqual(queue.cpu, 100)
self.assertEqual(queue.memory, 128)

def test_run_local(self):
@Function(Image(python_version="python3.8"), cpu=100, memory=128)
def test_func():
return 1

resp = test_func.local()

self.assertEqual(resp, 1)

def test_function_invoke(self):
@Function(Image(python_version="python3.8"), cpu=100, memory=128)
def test_func(*args, **kwargs):
return 1998

pickled_value = cloudpickle.dumps(1998)

test_func.parent.function_stub = MagicMock()
test_func.parent.syncer = MagicMock()

test_func.parent.function_stub.function_invoke.return_value = AsyncIterator(
[FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value)]
)

test_func.parent.prepare_runtime = MagicMock(return_value=True)

self.assertEqual(test_func(), 1998)

test_func.parent.function_stub.function_invoke.return_value = AsyncIterator(
[FunctionInvokeResponse(done=False, exit_code=1, result=b"")]
)

self.assertRaises(SystemExit, test_func)

def test_map(self):
@Function(Image(python_version="python3.8"), cpu=100, memory=128)
def test_func(*args, **kwargs):
return 1998

pickled_value = cloudpickle.dumps(1998)

test_func.parent.function_stub = MagicMock()
test_func.parent.syncer = MagicMock()

# Since the return value is a reference to this same aysnc iterator, everytime it
# it will iterate to the next value. This iterator in testing is persisted across
# multiple calls to the function, so we can simulate multiple responses.
# (ONLY HAPPENS DURING TESTING)
test_func.parent.function_stub.function_invoke.return_value = AsyncIterator(
[
FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value),
FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value),
FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value),
]
)

test_func.parent.prepare_runtime = MagicMock(return_value=True)

for val in test_func.map([1, 2, 3]):
self.assertEqual(val, 1998)
129 changes: 129 additions & 0 deletions sdk/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from unittest import TestCase
from unittest.mock import MagicMock

import cloudpickle

from beam.abstractions.map import Map
from beam.clients.map import (
MapCountResponse,
MapDeleteResponse,
MapGetResponse,
MapKeysResponse,
MapSetResponse,
)

from .utils import mock_coroutine_with_result


class TestMap(TestCase):
def setUp(self):
pass

def test_set(self):
mock_stub = MagicMock()

mock_stub.map_set = mock_coroutine_with_result(MapSetResponse(ok=True))

map = Map(name="test")
map.stub = mock_stub

self.assertTrue(map.set("test", "test"))

mock_stub.map_set = mock_coroutine_with_result(MapSetResponse(ok=False))

map = Map(name="test")
map.stub = mock_stub

self.assertFalse(map.set("test", "test"))

def test_get(self):
mock_stub = MagicMock()

pickled_value = cloudpickle.dumps("test")

mock_stub.map_get = mock_coroutine_with_result(MapGetResponse(ok=True, value=pickled_value))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(map.get("test"), "test")

mock_stub.map_get = mock_coroutine_with_result(MapGetResponse(ok=False, value=b""))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(map.get("test"), None)

def test_delitem(self):
mock_stub = MagicMock()

mock_stub.map_delete = mock_coroutine_with_result(MapDeleteResponse(ok=True))

map = Map(name="test")
map.stub = mock_stub

del map["test"]

mock_stub.map_delete = mock_coroutine_with_result(MapDeleteResponse(ok=False))

map = Map(name="test")
map.stub = mock_stub

def _del():
del map["test"]

self.assertRaises(KeyError, _del)

def test_len(self):
mock_stub = MagicMock()

mock_stub.map_count = mock_coroutine_with_result(MapCountResponse(ok=True, count=1))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(len(map), 1)

mock_stub.map_count = mock_coroutine_with_result(MapCountResponse(ok=False, count=1))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(len(map), 0)

def test_iter(self):
mock_stub = MagicMock()

mock_stub.map_keys = mock_coroutine_with_result(MapKeysResponse(ok=True, keys=["test"]))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(list(map), ["test"])

mock_stub.map_keys = mock_coroutine_with_result(MapKeysResponse(ok=False, keys=[]))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(list(map), [])

def test_items(self):
mock_stub = MagicMock()

pickled_value = cloudpickle.dumps("test")

mock_stub.map_keys = mock_coroutine_with_result(MapKeysResponse(ok=True, keys=["test"]))
mock_stub.map_get = mock_coroutine_with_result(MapGetResponse(ok=True, value=pickled_value))

map = Map(name="test")
map.stub = mock_stub
self.assertListEqual(list(map.items()), [("test", "test")])

mock_stub.map_keys = mock_coroutine_with_result(MapKeysResponse(ok=False, keys=[]))

map = Map(name="test")
map.stub = mock_stub

self.assertEqual(list(map.items()), [])
Loading

0 comments on commit e56fe42

Please sign in to comment.