
Commit cbf6e81

[python] added more examples and fixed requirements version (#199)
1 parent 47e3b2f

8 files changed: +96 −40 lines

python/scalellm/examples/async_stream_chat.py

Lines changed: 21 additions & 19 deletions
@@ -3,7 +3,7 @@
 
 def main():
     # Create an LLM engine.
-    engine = AsyncLLMEngine(model="meta-llama/Meta-Llama-3-8B-Instruct")
+    engine = AsyncLLMEngine(model="google/gemma-1.1-2b-it")
 
     # start the engine loop
     engine.start()

@@ -25,20 +25,25 @@ def main():
 
         # append the user message
         messages.append(Message(role="user", content=prompt))
-
-        output_stream = engine.schedule_chat(
-            messages=messages,
-            sampling_params=sampling_params,
-            stream=True,
-        )
-        assistant_response = ""
-        print("\n[Assistant]: ", end="", flush=True)
-        for output in output_stream:
-            if len(output.outputs) > 0:
-                response = output.outputs[0].text
-                assistant_response += response
-                print(response, end="", flush=True)
-        print()
+
+        try:
+            output_stream = engine.schedule_chat(
+                messages=messages,
+                sampling_params=sampling_params,
+                stream=True,
+            )
+            assistant_response = ""
+            print("\n[Assistant]: ", end="", flush=True)
+            for output in output_stream:
+                if len(output.outputs) > 0:
+                    response = output.outputs[0].text
+                    assistant_response += response
+                    print(response, end="", flush=True)
+            print()
+        except KeyboardInterrupt:
+            # cancel the request
+            output_stream.cancel()
+            break
 
         # append the assistant message
         messages.append(Message(role="assistant", content=assistant_response))

@@ -48,7 +53,4 @@ def main():
 
 
 if __name__ == "__main__":
-    try:
-        main()
-    except KeyboardInterrupt:
-        pass
+    main()
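Note the shape of the change: instead of catching KeyboardInterrupt once around main(), each scheduled request now gets its own try/except, so a Ctrl-C cancels the in-flight stream via output_stream.cancel() and breaks out of the chat loop cleanly rather than aborting the whole program.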

python/scalellm/examples/async_stream_complete.py

Lines changed: 17 additions & 15 deletions
@@ -7,32 +7,34 @@ def main():
     # start the engine loop
     engine.start()
 
-    prompt = input("Enter a prompt: ")
+    prompt = input("\n[Prompt]: ")
     while True:
         if prompt == "exit":
             break
         sampling_params = SamplingParams(
             temperature=0, top_p=1.0, max_tokens=100, echo=True
         )
-        output_stream = engine.schedule(
-            prompt=prompt,
-            sampling_params=sampling_params,
-            stream=True,
-        )
-        for output in output_stream:
-            if len(output.outputs) > 0:
-                print(output.outputs[0].text, end="", flush=True)
-        print()
+        try:
+            output_stream = engine.schedule(
+                prompt=prompt,
+                sampling_params=sampling_params,
+                stream=True,
+            )
+            for output in output_stream:
+                if len(output.outputs) > 0:
+                    print(output.outputs[0].text, end="", flush=True)
+            print()
+        except KeyboardInterrupt:
+            # cancel the request
+            output_stream.cancel()
+            break
 
         # Get the next prompt.
-        prompt = input("Enter a prompt: ")
+        prompt = input("\n[Prompt]: ")
 
     # stop the engine
     engine.stop()
 
 
 if __name__ == "__main__":
-    try:
-        main()
-    except KeyboardInterrupt:
-        pass
+    main()
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from scalellm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, echo=True)
+
+# Create an LLM.
+llm = LLM(model="gpt2", devices="cpu")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for i, output in enumerate(outputs):
+    generated_text = output.outputs[0].text
+    print(f"Generated text: {generated_text!r}")

python/scalellm/examples/offline_inference.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95, echo=True)
 
 # Create an LLM.
-llm = LLM(model="gpt2")
+llm = LLM(model="gpt2", devices="cuda")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+from scalellm import LLM, SamplingParams
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95, echo=True)
+
+# Create an LLM.
+llm = LLM(
+    model="google/gemma-7b",
+    devices="cuda",
+    draft_model="google/gemma-2b",
+    draft_devices="cuda",
+    num_speculative_tokens=4,
+)
+
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for i, output in enumerate(outputs):
+    generated_text = output.outputs[0].text
+    print(f"Generated text: {generated_text!r}")

python/scalellm/llm_engine.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ def error(self, error: str) -> bool:
     # cancel the stream
     def cancel(self) -> None:
         self._cancelled = True
-        self._queue.put_nowait(None)
+        self._queue.put_nowait(StopIteration())
 
     def __iter__(self):
         return self

@@ -92,7 +92,7 @@ def error(self, error: str) -> bool:
     # cancel the stream
     def cancel(self) -> None:
         self._cancelled = True
-        self._queue.put_nowait(None)
+        self._queue.put_nowait(StopAsyncIteration())
 
     def __aiter__(self):
         return self
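Why enqueue an exception instance instead of None? The consumer's __next__ (or __anext__) re-raises whatever exception it pulls off the queue, and raising StopIteration (or StopAsyncIteration) from there is exactly how Python iterators signal exhaustion, so a plain for-loop over a cancelled stream simply ends. A minimal sketch of the pattern for the synchronous case (class and method names below are illustrative assumptions, not ScaleLLM's actual API):

import queue

class SketchOutputStream:
    # Illustrative queue-backed stream; mirrors the pattern in
    # llm_engine.py but is not ScaleLLM's actual class.
    def __init__(self):
        self._queue = queue.Queue()
        self._cancelled = False

    def put(self, item) -> bool:
        # Producer side: stop accepting items once cancelled.
        if self._cancelled:
            return False
        self._queue.put_nowait(item)
        return True

    def cancel(self) -> None:
        self._cancelled = True
        # The exception itself is the sentinel; a None sentinel would
        # force every consumer to special-case it.
        self._queue.put_nowait(StopIteration())

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()
        if isinstance(item, StopIteration):
            raise item  # ends any for-loop over the stream
        return item

stream = SketchOutputStream()
stream.put("hello")
stream.cancel()
for chunk in stream:
    print(chunk)  # prints "hello"; the loop then exits cleanly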

python/setup.py

Lines changed: 2 additions & 1 deletion
@@ -220,6 +220,7 @@ def build_extension(self, ext: CMakeExtension):
     },
     classifiers=[
         "Development Status :: 3 - Alpha",
+        "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",

@@ -234,6 +235,6 @@ def build_extension(self, ext: CMakeExtension):
     package_data={
         "scalellm": scalellm_package_data,
     },
-    python_requires=">=3.9",
+    python_requires=">=3.8",
     install_requires=read_requirements(),
 )

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 torch >= 2.1.0
+fastapi >= 0.110.0
 huggingface_hub
-shortuuid
-fastapi
+shortuuid
