diff --git a/llm/models.py b/llm/models.py index 3ed61bf3..8539d083 100644 --- a/llm/models.py +++ b/llm/models.py @@ -416,6 +416,9 @@ async def datetime_utc(self) -> str: await self._force() return self._start_utcnow.isoformat() if self._start_utcnow else "" + def __await__(self): + return self.text().__await__() + @classmethod def fake( cls, diff --git a/tests/test_async.py b/tests/test_async.py index c7d3f9d9..db4cf529 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -8,3 +8,7 @@ async def test_async_model(async_mock_model): async for chunk in async_mock_model.prompt("hello"): gathered.append(chunk) assert gathered == ["hello world"] + # Not as an iterator + async_mock_model.enqueue(["hello world"]) + text = await async_mock_model.prompt("hello") + assert text == "hello world"