Skip to content

Commit

Permalink
Adding a stream test
Browse files Browse the repository at this point in the history
  • Loading branch information
blast-hardcheese committed Dec 3, 2024
1 parent 187a487 commit f5de68b
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/codegen/stream/generated/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by river.codegen. DO NOT EDIT.
from pydantic import BaseModel
from typing import Literal

import replit_river as river


from .test_service import Test_ServiceService


class StreamClient:
def __init__(self, client: river.Client[Literal[None]]):
self.test_service = Test_ServiceService(client)
40 changes: 40 additions & 0 deletions tests/codegen/stream/generated/test_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError
import replit_river as river


from .stream_method import (
encode_Stream_MethodInput,
Stream_MethodOutput,
Stream_MethodInput,
)


class Test_ServiceService:
def __init__(self, client: river.Client[Any]):
self.client = client

async def stream_method(
self,
inputStream: AsyncIterable[Stream_MethodInput],
) -> AsyncIterator[Stream_MethodOutput | RiverError]:
return self.client.send_stream(
"test_service",
"stream_method",
None,
inputStream,
None,
encode_Stream_MethodInput,
lambda x: TypeAdapter(Stream_MethodOutput).validate_python(
x # type: ignore[arg-type]
),
lambda x: TypeAdapter(RiverError).validate_python(
x # type: ignore[arg-type]
),
)
40 changes: 40 additions & 0 deletions tests/codegen/stream/generated/test_service/stream_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ruff: noqa
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Mapping,
Union,
Tuple,
TypedDict,
)

from pydantic import BaseModel, Field, TypeAdapter
from replit_river.error_schema import RiverError

import replit_river as river


encode_Stream_MethodInput: Callable[["Stream_MethodInput"], Any] = lambda x: {
k: v
for (k, v) in (
{
"data": x.get("data"),
}
).items()
if v is not None
}


class Stream_MethodInput(TypedDict):
data: str


class Stream_MethodOutput(BaseModel):
data: str
32 changes: 32 additions & 0 deletions tests/codegen/stream/schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{
"services": {
"test_service": {
"procedures": {
"stream_method": {
"input": {
"type": "object",
"properties": {
"data": {
"type": "string"
}
},
"required": ["data"]
},
"output": {
"type": "object",
"properties": {
"data": {
"type": "string"
}
},
"required": ["data"]
},
"errors": {
"not": {}
},
"type": "stream"
}
}
}
}
}
46 changes: 46 additions & 0 deletions tests/codegen/stream/test_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import importlib
import shutil
from typing import AsyncIterable

import pytest

from replit_river.client import Client
from replit_river.codegen.client import schema_to_river_client_codegen
from tests.codegen.stream.generated.test_service.stream_method import (
Stream_MethodInput,
Stream_MethodOutput,
)
from tests.common_handlers import basic_stream


@pytest.fixture(scope="session", autouse=True)
def generate_stream_client() -> None:
import tests.codegen.stream.generated

shutil.rmtree("tests/codegen/stream/generated")
schema_to_river_client_codegen(
"tests/codegen/stream/schema.json",
"tests/codegen/stream/generated",
"StreamClient",
True,
)
importlib.reload(tests.codegen.stream.generated)


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_stream}])
async def test_basic_stream(client: Client) -> None:
from tests.codegen.stream.generated import StreamClient

async def emit() -> AsyncIterable[Stream_MethodInput]:
for i in range(5):
yield {"data": str(i)}

res = await StreamClient(client).test_service.stream_method(emit())

i = 0
async for datum in res:
assert isinstance(datum, Stream_MethodOutput)
assert f"Stream response for {i}" == datum.data, f"{i} == {datum.data}"
i = i + 1
assert i == 5

0 comments on commit f5de68b

Please sign in to comment.