Skip to content

Commit 5d31832

Browse files
authored
Data and request fixes for real data / chat_completions pathways (#131)
## Changes

- Enable config settings for controlling the route (`text_completions` or `chat_completions`) used by the backend through `GUIDELLM__PREFERRED_ROUTE`
- Fix data files erroring out due to `Path` objects being passed into HF `load_dataset`
- Fix a silent bug where a batched request is treated as a single request, leading to issues in the output as well as slow times
1 parent 35fc95a commit 5d31832

File tree

4 files changed

+23
-8
lines changed

4 files changed

+23
-8
lines changed

src/guidellm/backend/openai.py

+7
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,13 @@ async def text_completions( # type: ignore[override]
202202
and a ResponseSummary for the final response.
203203
"""
204204
logger.debug("{} invocation with args: {}", self.__class__.__name__, locals())
205+
206+
if isinstance(prompt, list):
207+
raise ValueError(
208+
"List prompts (batching) is currently not supported for "
209+
f"text_completions OpenAI pathways. Received: {prompt}"
210+
)
211+
205212
headers = self._headers()
206213
payload = self._completions_payload(
207214
orig_kwargs=kwargs,

src/guidellm/config.py

+3
Original file line number | Diff line number | Diff line change
@@ -134,6 +134,9 @@ class Settings(BaseSettings):
134134
Literal["request", "response", "local"]
135135
] = "response"
136136
preferred_backend: Literal["openai"] = "openai"
137+
preferred_route: Literal["text_completions", "chat_completions"] = (
138+
"text_completions"
139+
)
137140
openai: OpenAISettings = OpenAISettings()
138141

139142
# Output settings

src/guidellm/dataset/file.py

+9-7
Original file line number | Diff line number | Diff line change
@@ -71,19 +71,21 @@ def load_dataset(
7171

7272
dataset = Dataset.from_dict({"text": items}, **(data_args or {}))
7373
elif path.suffix.lower() == ".csv":
74-
dataset = load_dataset("csv", data_files=path, **(data_args or {}))
74+
dataset = load_dataset("csv", data_files=str(path), **(data_args or {}))
7575
elif path.suffix.lower() in {".json", ".jsonl"}:
76-
dataset = load_dataset("json", data_files=path, **(data_args or {}))
76+
dataset = load_dataset("json", data_files=str(path), **(data_args or {}))
7777
elif path.suffix.lower() == ".parquet":
78-
dataset = load_dataset("parquet", data_files=path, **(data_args or {}))
78+
dataset = load_dataset("parquet", data_files=str(path), **(data_args or {}))
7979
elif path.suffix.lower() == ".arrow":
80-
dataset = load_dataset("arrow", data_files=path, **(data_args or {}))
80+
dataset = load_dataset("arrow", data_files=str(path), **(data_args or {}))
8181
elif path.suffix.lower() == ".hdf5":
82-
dataset = Dataset.from_pandas(pd.read_hdf(path), **(data_args or {}))
82+
dataset = Dataset.from_pandas(pd.read_hdf(str(path)), **(data_args or {}))
8383
elif path.suffix.lower() == ".db":
84-
dataset = Dataset.from_sql(con=path, **(data_args or {}))
84+
dataset = Dataset.from_sql(con=str(path), **(data_args or {}))
8585
elif path.suffix.lower() == ".tar":
86-
dataset = load_dataset("webdataset", data_files=path, **(data_args or {}))
86+
dataset = load_dataset(
87+
"webdataset", data_files=str(path), **(data_args or {})
88+
)
8789
else:
8890
raise ValueError(f"Unsupported file type: {path.suffix} given for {path}. ")
8991

src/guidellm/request/loader.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
1212
from transformers import PreTrainedTokenizerBase # type: ignore[import]
1313

14+
from guidellm.config import settings
1415
from guidellm.dataset import ColumnInputTypes, load_dataset
1516
from guidellm.objects import StandardBaseModel
1617
from guidellm.request.request import GenerationRequest
@@ -61,6 +62,8 @@ class GenerativeRequestLoader(RequestLoader):
6162
"content",
6263
"conversation",
6364
"conversations",
65+
"turn",
66+
"turns",
6467
"text",
6568
]
6669

@@ -270,7 +273,7 @@ def _create_request(self, item: dict[str, Any]) -> GenerationRequest:
270273
)
271274

272275
return GenerationRequest(
273-
request_type="text_completions",
276+
request_type=settings.preferred_route,
274277
content=item[self.column_mappings["prompt_column"]],
275278
stats=(
276279
{"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}

0 commit comments

Comments (0)