Added vdm_upload_pipe Tests #23

Open · wants to merge 6 commits into base: devin_fea_1453_simplify_vectordb_example
Changes from all commits
17 changes: 8 additions & 9 deletions examples/llm/common/content_extractor_module.py
@@ -89,21 +89,20 @@ def convert(self,
file_path = [file_path]

docs: list[Document] = []
if meta is None:
text_column_name = "content"
else:
text_column_name = meta.get("csv", {}).get("text_column_name", "content")
text_column_names = {"content"}

if meta is not None:
text_column_names = set(meta.get("csv", {}).get("text_column_names", text_column_names))

for path in file_path:
df = pd.read_csv(path, encoding=encoding)
if len(df.columns) == 0 or (text_column_name not in df.columns):
if len(df.columns) == 0 or (not text_column_names.issubset(set(df.columns))):
raise ValueError("The CSV file must either include a 'content' column or have a "
"column specified in the meta configuraton with key 'text_column_name'.")
"columns specified in the meta configuraton with key 'text_column_names'.")

df.fillna(value="", inplace=True)
df[text_column_name] = df[text_column_name].apply(lambda x: x.strip())
df["content"] = df[text_column_names].apply(lambda x: ' '.join(map(str, x)), axis=1)

df = df.rename(columns={text_column_name: "content"})
docs_dicts = df.to_dict(orient="records")

for dictionary in docs_dicts:
@@ -195,7 +194,7 @@ def process_content(docs: list[Document], file_meta: FileMeta, chunk_size: int,

class CSVConverterParamContract(BaseModel):
chunk_size: int = 1024
text_column_name: str = "raw"
text_column_names: list[str] = Field(default_factory=lambda: ["raw"])
chunk_overlap: int = 102 # Example default value


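Note: the convert() change above replaces the single configurable text column with a set of columns whose values are concatenated into a single 'content' column. A minimal standalone sketch of that behaviour; the column names and data below are made up for illustration and are not taken from the example configs:

import pandas as pd

# Hypothetical frame; in the module this comes from pd.read_csv(path).
df = pd.DataFrame({"title": ["Doc A"], "body": ["Some text"]})

# Configured via meta["csv"]["text_column_names"]; defaults to {"content"}.
text_column_names = ["title", "body"]

if not set(text_column_names).issubset(df.columns):
    raise ValueError("CSV is missing one of the configured text columns")

df.fillna(value="", inplace=True)
# Row-wise join of the configured columns into a single 'content' column.
df["content"] = df[text_column_names].apply(lambda row: " ".join(map(str, row)), axis=1)

print(df["content"].iloc[0])  # -> "Doc A Some text"
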
@@ -19,8 +19,8 @@
from morpheus.messages.multi_message import MultiMessage
from morpheus.pipeline.pipeline import Pipeline
from morpheus.stages.general.linear_modules_source import LinearModuleSourceStage
from .module.file_source_pipe import FileSourcePipe
from .module.rss_source_pipe import RSSSourcePipe
from vdb_upload.module.file_source_pipe import FileSourcePipe
from vdb_upload.module.rss_source_pipe import RSSSourcePipe

logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion examples/llm/vdb_upload/langchain.py
@@ -19,7 +19,7 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.milvus import Milvus

from llm.vdb_upload.common import build_rss_urls
from examples.llm.vdb_upload.helper import build_rss_urls
from morpheus.utils.logging_timer import log_time

logger = logging.getLogger(__name__)
5 changes: 2 additions & 3 deletions examples/llm/vdb_upload/module/file_source_pipe.py
@@ -28,9 +28,8 @@
from morpheus.modules.preprocess.deserialize import DeserializeInterface
from morpheus.utils.module_utils import ModuleInterface
from morpheus.utils.module_utils import register_module

from ...common.content_extractor_module import FileContentExtractorInterface
from .schema_transform import SchemaTransformInterface
from vdb_upload.module.schema_transform import SchemaTransformInterface
from common.content_extractor_module import FileContentExtractorInterface

logger = logging.getLogger(__name__)

5 changes: 2 additions & 3 deletions examples/llm/vdb_upload/module/rss_source_pipe.py
@@ -29,9 +29,8 @@
from morpheus.modules.preprocess.deserialize import DeserializeInterface
from morpheus.utils.module_utils import ModuleInterface
from morpheus.utils.module_utils import register_module

from ...common.web_scraper_module import WebScraperInterface
from .schema_transform import SchemaTransformInterface
from vdb_upload.module.schema_transform import SchemaTransformInterface
from common.web_scraper_module import WebScraperInterface

logger = logging.getLogger(__name__)

3 changes: 1 addition & 2 deletions examples/llm/vdb_upload/pipeline.py
@@ -23,8 +23,7 @@
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
from morpheus.stages.output.write_to_vector_db_stage import WriteToVectorDBStage
from morpheus.stages.preprocess.preprocess_nlp_stage import PreprocessNLPStage

from .common import process_vdb_sources
from vdb_upload.helper import process_vdb_sources

logger = logging.getLogger(__name__)

6 changes: 3 additions & 3 deletions examples/llm/vdb_upload/run.py
@@ -17,9 +17,9 @@

import click

from .vdb_utils import build_cli_configs
from .vdb_utils import build_final_config
from .vdb_utils import is_valid_service
from vdb_upload.vdb_utils import build_cli_configs
from vdb_upload.vdb_utils import build_final_config
from vdb_upload.vdb_utils import is_valid_service

logger = logging.getLogger(__name__)

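Note: the imports across these example files switch from relative to absolute (vdb_upload.*, common.*), which appears to assume that examples/llm is importable as a top-level location. A minimal sketch of one way a script or test might satisfy that; the path below is an assumption about the checkout layout, not something defined in this PR:

import sys
from pathlib import Path

# Assumed layout: <repo>/examples/llm/{common,vdb_upload}/...
examples_llm_dir = Path("examples/llm").resolve()

# Make `vdb_upload.*` and `common.*` resolvable as top-level packages.
if str(examples_llm_dir) not in sys.path:
    sys.path.insert(0, str(examples_llm_dir))

from vdb_upload.vdb_utils import build_final_config  # noqa: E402
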
3 changes: 1 addition & 2 deletions examples/llm/vdb_upload/vdb_utils.py
@@ -15,7 +15,6 @@
import logging
import typing

import click
import pymilvus
import yaml

@@ -347,7 +346,7 @@ def build_final_config(vdb_conf_path,
source_conf = vdb_pipeline_config.get('sources', []) + list(cli_source_conf.values())
tokenizer_conf = merge_configs(vdb_pipeline_config.get('tokenizer', {}), cli_tokenizer_conf)
vdb_conf = vdb_pipeline_config.get('vdb', {})
resource_schema = vdb_conf.pop("resource_shema", None)
resource_schema = vdb_conf.pop("resource_schema", None)

if resource_schema:
vdb_conf["resource_kwargs"] = build_milvus_config(resource_schema)
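
Note: the vdb_utils.py change above corrects the "resource_shema" key lookup to "resource_schema". A small sketch of the corrected handling; the contents of the vdb dictionary and the build_milvus_config stub below are illustrative assumptions, not the real helper:

def build_milvus_config(resource_schema: dict) -> dict:
    # Stand-in for the real helper; assume it expands the schema into Milvus kwargs.
    return {"schema_conf": resource_schema}

# Hypothetical 'vdb' section as it might appear after loading the pipeline YAML.
vdb_conf = {
    "uri": "http://localhost:19530",
    "resource_name": "example_collection",
    "resource_schema": {"embedding_size": 384},
}

# With the corrected key, the schema is popped and translated into resource_kwargs.
resource_schema = vdb_conf.pop("resource_schema", None)
if resource_schema:
    vdb_conf["resource_kwargs"] = build_milvus_config(resource_schema)

print(sorted(vdb_conf))  # -> ['resource_kwargs', 'resource_name', 'uri']
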
11 changes: 7 additions & 4 deletions morpheus/modules/output/write_to_vector_db.py
@@ -144,10 +144,13 @@ def on_completed():

# Pushing remaining messages
for key, accum_stats in accumulator_dict.items():
if accum_stats.data:
merged_df = cudf.concat(accum_stats.data)
service.insert_dataframe(name=key, df=merged_df)
final_df_references.append(accum_stats.data)
try:
if accum_stats.data:
merged_df = cudf.concat(accum_stats.data)
service.insert_dataframe(name=key, df=merged_df)
final_df_references.append(accum_stats.data)
except Exception as e:
logger.error(f"Unable to upload dataframe entries to vector database: {e}")
# Close vector database service connection
service.close()

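Note: the new try/except around the final flush logs a failed insert for a resource instead of aborting completion, so service.close() still runs. An isolated sketch of that pattern; the service and accumulator objects here are stand-ins, not the real module types:

import logging

import cudf

logger = logging.getLogger(__name__)


def flush_and_close(accumulator_dict: dict, service) -> None:
    """Flush any accumulated dataframes per resource, but always close the connection."""
    for key, accum_stats in accumulator_dict.items():
        try:
            if accum_stats.data:
                merged_df = cudf.concat(accum_stats.data)
                service.insert_dataframe(name=key, df=merged_df)
        except Exception as e:
            # One failed resource no longer prevents the remaining flushes or the close.
            logger.error("Unable to upload dataframe entries to vector database: %s", e)
    service.close()
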
156 changes: 65 additions & 91 deletions tests/llm/test_vdb_upload_pipe.py
100644 → 100755
@@ -27,93 +27,88 @@
from _utils import mk_async_infer
from _utils.dataset_manager import DatasetManager
from morpheus.config import Config
from morpheus.config import PipelineModes
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.service.vdb.milvus_vector_db_service import MilvusVectorDBService
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
from morpheus.stages.input.rss_source_stage import RSSSourceStage
from morpheus.stages.output.write_to_vector_db_stage import WriteToVectorDBStage
from morpheus.stages.preprocess.deserialize_stage import DeserializeStage
from morpheus.stages.preprocess.preprocess_nlp_stage import PreprocessNLPStage

EMBEDDING_SIZE = 384
MODEL_MAX_BATCH_SIZE = 64
MODEL_FEA_LENGTH = 512


def _run_pipeline(config: Config,
milvus_server_uri: str,
collection_name: str,
rss_files: list[str],
utils_mod: types.ModuleType,
web_scraper_stage_mod: types.ModuleType):

config.mode = PipelineModes.NLP
config.pipeline_batch_size = 1024
config.model_max_batch_size = MODEL_MAX_BATCH_SIZE
config.feature_length = MODEL_FEA_LENGTH
config.edge_buffer_size = 128
config.class_labels = [str(i) for i in range(EMBEDDING_SIZE)]

pipe = LinearPipeline(config)

pipe.set_source(
RSSSourceStage(config, feed_input=rss_files, batch_size=128, run_indefinitely=False, enable_cache=False))
pipe.add_stage(web_scraper_stage_mod.WebScraperStage(config, chunk_size=MODEL_FEA_LENGTH, enable_cache=False))
pipe.add_stage(DeserializeStage(config))

pipe.add_stage(
PreprocessNLPStage(config,
vocab_hash_file=os.path.join(TEST_DIRS.data_dir, 'bert-base-uncased-hash.txt'),
do_lower_case=True,
truncation=True,
add_special_tokens=False,
column='page_content'))

pipe.add_stage(
TritonInferenceStage(config, model_name='test-model', server_url='test:0000', force_convert_inputs=True))

pipe.add_stage(
WriteToVectorDBStage(config,
resource_name=collection_name,
resource_kwargs=utils_mod.build_milvus_config(embedding_size=EMBEDDING_SIZE),
recreate=True,
service="milvus",
uri=milvus_server_uri))
pipe.run()


@pytest.mark.milvus
@pytest.mark.use_python
@pytest.mark.use_pandas
@pytest.mark.import_mod([
os.path.join(TEST_DIRS.examples_dir, 'llm/common/utils.py'),
os.path.join(TEST_DIRS.examples_dir, 'llm/common/web_scraper_stage.py')
os.path.join(TEST_DIRS.examples_dir, 'llm/common'),
os.path.join(TEST_DIRS.examples_dir, 'llm/vdb_upload/helper.py'),
os.path.join(TEST_DIRS.examples_dir, 'llm/vdb_upload/run.py'),
os.path.join(TEST_DIRS.examples_dir, 'llm/vdb_upload/pipeline.py')
])
@mock.patch('requests.Session')
@mock.patch('tritonclient.grpc.InferenceServerClient')
@pytest.mark.parametrize('is_rss_source, exclude_columns, expected_output_path, vdb_conf_file',
[(True, ['id', 'embedding', 'source'],
'service/milvus_rss_data.json',
'examples/llm/vdb_upload/vdb_rss_source_config.yaml'),
(False, ['id', 'embedding'],
'examples/llm/vdb_upload/test_data_output.json',
'examples/llm/vdb_upload/vdb_file_source_config.yaml')])
def test_vdb_upload_pipe(mock_triton_client: mock.MagicMock,
mock_requests_session: mock.MagicMock,
config: Config,
dataset: DatasetManager,
milvus_server_uri: str,
import_mod: list[types.ModuleType]):
mock_requests_session: mock.MagicMock,
dataset: DatasetManager,
milvus_server_uri: str,
import_mod: list[types.ModuleType],
is_rss_source: str,
exclude_columns: list[str],
expected_output_path: str,
vdb_conf_file: str):

# We're going to use this DF both to provide values to the mocked Triton client
# and to verify the values in the Milvus collection.
expected_values_df = dataset["service/milvus_rss_data.json"]

with open(os.path.join(TEST_DIRS.tests_data_dir, 'service/cisa_web_responses.json'), encoding='utf-8') as fh:
web_responses = json.load(fh)
expected_values_df = dataset[expected_output_path]

if is_rss_source:
with open(os.path.join(TEST_DIRS.tests_data_dir, 'service/cisa_web_responses.json'), encoding='utf-8') as fh:
web_responses = json.load(fh)

# Mock requests: since we are feeding the RSSSourceStage with a local file, it won't be using the
# requests lib; only web_scraper_stage.py will use it.
def mock_get_fn(url: str):
mock_response = mock.MagicMock()
mock_response.ok = True
mock_response.status_code = 200
mock_response.text = web_responses[url]
return mock_response

mock_requests_session.return_value = mock_requests_session
mock_requests_session.get.side_effect = mock_get_fn

# As page_content is used by other pipelines, we're just renaming it to content.
expected_values_df = expected_values_df.rename(columns={"page_content": "content"})
expected_values_df["source"] = "rss"

vdb_conf_path = os.path.join(TEST_DIRS.tests_data_dir, vdb_conf_file)

_, _, vdb_upload_run_mod, vdb_upload_pipeline_mod = import_mod

# Building final configuration. Here we're passing empty dictionaries for cli configuration.
vdb_pipeline_config = vdb_upload_run_mod.build_final_config(vdb_conf_path=vdb_conf_path,
cli_source_conf={},
cli_embeddings_conf={},
cli_pipeline_conf={},
cli_tokenizer_conf={},
cli_vdb_conf={})

config: Config = vdb_pipeline_config["pipeline_config"]

# Overwriting uri provided in the config file with milvus_server_uri
vdb_pipeline_config["vdb_config"]["uri"] = milvus_server_uri
collection_name = vdb_pipeline_config["vdb_config"]["resource_name"]

# Mock Triton results
mock_metadata = {
"inputs": [{
"name": "input_ids", "datatype": "INT32", "shape": [-1, MODEL_FEA_LENGTH]
"name": "input_ids", "datatype": "INT32", "shape": [-1, config.feature_length]
}, {
"name": "attention_mask", "datatype": "INT32", "shape": [-1, MODEL_FEA_LENGTH]
"name": "attention_mask", "datatype": "INT32", "shape": [-1, config.feature_length]
}],
"outputs": [{
"name": "output", "datatype": "FP32", "shape": [-1, EMBEDDING_SIZE]
"name": "output", "datatype": "FP32", "shape": [-1, len(config.class_labels)]
}]
}
mock_model_config = {"config": {"max_batch_size": 256}}
@@ -127,36 +122,15 @@ def test_vdb_upload_pipe(mock_triton_client: mock.MagicMock,

mock_result_values = expected_values_df['embedding'].to_list()
inf_results = np.split(mock_result_values,
range(MODEL_MAX_BATCH_SIZE, len(mock_result_values), MODEL_MAX_BATCH_SIZE))
range(config.model_max_batch_size, len(mock_result_values), config.model_max_batch_size))

# The Triton inference stage applies a sigmoid to the model outputs, so apply the inverse (logit) here
inf_results = [np.log((1.0 / x) - 1.0) * -1 for x in inf_results]

async_infer = mk_async_infer(inf_results)
mock_triton_client.async_infer.side_effect = async_infer

# Mock requests, since we are feeding the RSSSourceStage with a local file it won't be using the
# requests lib, only web_scraper_stage.py will use it.
def mock_get_fn(url: str):
mock_response = mock.MagicMock()
mock_response.ok = True
mock_response.status_code = 200
mock_response.text = web_responses[url]
return mock_response

mock_requests_session.return_value = mock_requests_session
mock_requests_session.get.side_effect = mock_get_fn

(utils_mod, web_scraper_stage_mod) = import_mod
collection_name = "test_vdb_upload_pipe"
rss_source_file = os.path.join(TEST_DIRS.tests_data_dir, 'service/cisa_rss_feed.xml')

_run_pipeline(config=config,
milvus_server_uri=milvus_server_uri,
collection_name=collection_name,
rss_files=[rss_source_file],
utils_mod=utils_mod,
web_scraper_stage_mod=web_scraper_stage_mod)
vdb_upload_pipeline_mod.pipeline(**vdb_pipeline_config)

milvus_service = MilvusVectorDBService(uri=milvus_server_uri)
resource_service = milvus_service.load_resource(name=collection_name)
@@ -167,7 +141,7 @@ def mock_get_fn(url: str):
db_df = pd.DataFrame(sorted(db_results, key=lambda k: k['id']))

# The comparison function performs rounding on the values, but is unable to do so for array columns
dataset.assert_compare_df(db_df, expected_values_df[db_df.columns], exclude_columns=['id', 'embedding'])
dataset.assert_compare_df(db_df, expected_values_df[db_df.columns], exclude_columns=exclude_columns)
db_emb = db_df['embedding']
expected_emb = expected_values_df['embedding']

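Note: the inverse transform used for the mocked Triton results can be checked in isolation. A short sketch with arbitrary values in (0, 1); applying a sigmoid to the transformed values recovers the originals, which is what the test relies on:

import numpy as np

expected = np.array([[0.25, 0.5, 0.75]])        # embedding values the test expects back
logits = np.log((1.0 / expected) - 1.0) * -1    # same transform as in the test

recovered = 1.0 / (1.0 + np.exp(-logits))       # sigmoid undoes the logit transform
assert np.allclose(recovered, expected)
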
3 changes: 3 additions & 0 deletions tests/tests_data/examples/llm/vdb_upload/test_data.csv
Git LFS file not shown