Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdk test cases #36

Merged
merged 19 commits into from
Jan 10, 2024
2 changes: 1 addition & 1 deletion sdk/src/beam/abstractions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def _call_remote(self, *args, **kwargs) -> Any:
if r.output != "":
terminal.detail(r.output)

if r.done:
if r.done or r.exit_code != 0:
last_response = r
break

Expand Down
2 changes: 1 addition & 1 deletion sdk/src/beam/abstractions/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def put(self, value: Any) -> bool:
def pop(self) -> Any:
r = self.run_sync(self.stub.pop(name=self.name))
if not r.ok:
return SimpleQueueInternalServerError
raise SimpleQueueInternalServerError

if len(r.value) > 0:
return cloudpickle.loads(r.value)
Expand Down
9 changes: 4 additions & 5 deletions sdk/src/beam/abstractions/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,10 @@ def __call__(self, *args, **kwargs) -> Any:
if container_id is not None:
return self.local(*args, **kwargs)

if not self.parent.prepare_runtime(
func=self.func,
stub_type=TASKQUEUE_STUB_TYPE,
):
return
raise NotImplementedError(
"Direct calls to TaskQueues are not yet supported."
+ " To enqueue items use .put(*args, **kwargs)"
)

def local(self, *args, **kwargs) -> Any:
return self.func(*args, **kwargs)
Expand Down
3 changes: 3 additions & 0 deletions sdk/src/beam/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@
cli = click.Group()
cli.add_command(configure.configure)
cli.add_command(tasks.cli)

luke-lombardi marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == "__main__":
cli()
31 changes: 0 additions & 31 deletions sdk/tests/test_build.py

This file was deleted.

98 changes: 98 additions & 0 deletions sdk/tests/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from unittest import TestCase
from unittest.mock import MagicMock

import cloudpickle

from beam import Image
from beam.abstractions.function import Function
from beam.clients.function import FunctionInvokeResponse

from .utils import override_run_sync


class AsyncIterator:
def __init__(self, seq):
self.iter = iter(seq)

def __aiter__(self):
return self

async def __anext__(self):
try:
return next(self.iter)
except StopIteration:
raise StopAsyncIteration


class TestTaskQueue(TestCase):
def test_init(self):
mock_stub = MagicMock()

queue = Function(Image(python_version="python3.7"), cpu=100, memory=128)
luke-lombardi marked this conversation as resolved.
Show resolved Hide resolved
queue.stub = mock_stub

self.assertEqual(queue.image.python_version, "python3.7")
self.assertEqual(queue.cpu, 100)
self.assertEqual(queue.memory, 128)

def test_run_local(self):
@Function(Image(python_version="python3.7"), cpu=100, memory=128)
def test_func():
return 1

resp = test_func.local()

self.assertEqual(resp, 1)

def test_function_invoke(self):
@Function(Image(python_version="python3.7"), cpu=100, memory=128)
def test_func(*args, **kwargs):
return 1998

pickled_value = cloudpickle.dumps(1998)

test_func.parent.function_stub = MagicMock()
test_func.parent.syncer = MagicMock()

test_func.parent.function_stub.function_invoke.return_value = AsyncIterator(
[FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value)]
)

test_func.parent.prepare_runtime = MagicMock(return_value=True)
test_func.parent.run_sync = override_run_sync

self.assertEqual(test_func(), 1998)

test_func.parent.function_stub.function_invoke.return_value = AsyncIterator(
[FunctionInvokeResponse(done=False, exit_code=1, result=b"")]
)

self.assertRaises(SystemExit, test_func)

def test_map(self):
@Function(Image(python_version="python3.7"), cpu=100, memory=128)
def test_func(*args, **kwargs):
return 1998

pickled_value = cloudpickle.dumps(1998)

test_func.parent.function_stub = MagicMock()
test_func.parent.syncer = MagicMock()

# Since the return value is a reference to this same aysnc iterator, everytime it
# it will iterate to the next value. This iterator in testing is persisted across
# multiple calls to the function, so we can simulate multiple responses.
# (ONLY HAPPENS DURING TESTING)
test_func.parent.function_stub.function_invoke.return_value = AsyncIterator(
[
FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value),
FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value),
FunctionInvokeResponse(done=True, exit_code=0, result=pickled_value),
]
)

test_func.parent.prepare_runtime = MagicMock(return_value=True)
test_func.parent.run_sync = override_run_sync

for val in test_func.map([1, 2, 3]):
self.assertEqual(val, 1998)
142 changes: 142 additions & 0 deletions sdk/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from unittest import TestCase
from unittest.mock import MagicMock

import cloudpickle

from beam.abstractions.map import Map
from beam.clients.map import (
MapCountResponse,
MapDeleteResponse,
MapGetResponse,
MapKeysResponse,
MapSetResponse,
)

from .utils import override_run_sync


class TestMap(TestCase):
def setUp(self):
pass

def test_set(self):
mock_stub = MagicMock()

mock_stub.map_set.return_value = MapSetResponse(ok=True)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertTrue(map.set("test", "test"))

mock_stub.map_set.return_value = MapSetResponse(ok=False)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertFalse(map.set("test", "test"))

def test_get(self):
mock_stub = MagicMock()

pickled_value = cloudpickle.dumps("test")

mock_stub.map_get.return_value = MapGetResponse(ok=True, value=pickled_value)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(map.get("test"), "test")

mock_stub.map_get.return_value = MapGetResponse(ok=False, value=b"")

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(map.get("test"), None)

def test_delitem(self):
mock_stub = MagicMock()

mock_stub.map_delete.return_value = MapDeleteResponse(ok=True)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

del map["test"]

mock_stub.map_delete.return_value = MapDeleteResponse(ok=False)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

def _del():
del map["test"]

self.assertRaises(KeyError, _del)

def test_len(self):
mock_stub = MagicMock()

mock_stub.map_count.return_value = MapCountResponse(ok=True, count=1)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(len(map), 1)

mock_stub.map_count.return_value = MapCountResponse(ok=False, count=1)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(len(map), 0)

def test_iter(self):
mock_stub = MagicMock()

mock_stub.map_keys.return_value = MapKeysResponse(ok=True, keys=["test"])

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(list(map), ["test"])

mock_stub.map_keys.return_value = MapKeysResponse(ok=False, keys=[])

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(list(map), [])

def test_items(self):
mock_stub = MagicMock()

pickled_value = cloudpickle.dumps("test")

mock_stub.map_keys.return_value = MapKeysResponse(ok=True, keys=["test"])
mock_stub.map_get.return_value = MapGetResponse(ok=True, value=pickled_value)

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(list(map.items()), [("test", "test")])

mock_stub.map_keys.return_value = MapKeysResponse(ok=False, keys=[])

map = Map(name="test")
map.stub = mock_stub
map.run_sync = override_run_sync

self.assertEqual(list(map.items()), [])
Loading