Skip to content

Commit

Permalink
Don't mock asyncio.gather because it is also used by jinja
Browse files Browse the repository at this point in the history
  • Loading branch information
dagardner-nv committed Jan 12, 2024
1 parent 27c0a18 commit f6eaab4
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions tests/llm/test_rag_standalone_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,24 +131,18 @@ def _run_pipeline(config: Config,
@pytest.mark.use_cudf
@pytest.mark.parametrize("repeat_count", [5])
@pytest.mark.import_mod(os.path.join(TEST_DIRS.examples_dir, 'llm/common/utils.py'))
@mock.patch("asyncio.wrap_future")
@mock.patch("asyncio.gather", new_callable=mock.AsyncMock)
def test_rag_standalone_pipe_nemo(
mock_asyncio_gather: mock.AsyncMock,
mock_asyncio_wrap_future: mock.MagicMock, # pylint: disable=unused-argument
config: Config,
mock_nemollm: mock.MagicMock,
dataset: DatasetManager,
milvus_server_uri: str,
repeat_count: int,
import_mod: types.ModuleType):
def test_rag_standalone_pipe_nemo(config: Config,
mock_nemollm: mock.MagicMock,
dataset: DatasetManager,
milvus_server_uri: str,
repeat_count: int,
import_mod: types.ModuleType):
collection_name = "test_rag_standalone_pipe_nemo"
populate_milvus(milvus_server_uri=milvus_server_uri,
collection_name=collection_name,
resource_kwargs=import_mod.build_milvus_config(embedding_size=EMBEDDING_SIZE),
df=dataset["service/milvus_rss_data.json"],
overwrite=True)
mock_asyncio_gather.return_value = [mock.MagicMock() for _ in range(repeat_count)]
mock_nemollm.post_process_generate_response.side_effect = [{"text": EXPECTED_RESPONSE} for _ in range(repeat_count)]
results = _run_pipeline(
config=config,
Expand Down

0 comments on commit f6eaab4

Please sign in to comment.