Commit

[python] added more examples and fix requirments version (#199)
guocuimi authored May 20, 2024
1 parent 47e3b2f commit cbf6e81
Showing 8 changed files with 96 additions and 40 deletions.
40 changes: 21 additions & 19 deletions python/scalellm/examples/async_stream_chat.py
@@ -3,7 +3,7 @@

def main():
    # Create an LLM engine.
-    engine = AsyncLLMEngine(model="meta-llama/Meta-Llama-3-8B-Instruct")
+    engine = AsyncLLMEngine(model="google/gemma-1.1-2b-it")
    # start the engine loop
    engine.start()

@@ -25,20 +25,25 @@ def main():

        # append the user message
        messages.append(Message(role="user", content=prompt))

-        output_stream = engine.schedule_chat(
-            messages=messages,
-            sampling_params=sampling_params,
-            stream=True,
-        )
-        assistant_response = ""
-        print("\n[Assistant]: ", end="", flush=True)
-        for output in output_stream:
-            if len(output.outputs) > 0:
-                response = output.outputs[0].text
-                assistant_response += response
-                print(response, end="", flush=True)
-        print()
+        try:
+            output_stream = engine.schedule_chat(
+                messages=messages,
+                sampling_params=sampling_params,
+                stream=True,
+            )
+            assistant_response = ""
+            print("\n[Assistant]: ", end="", flush=True)
+            for output in output_stream:
+                if len(output.outputs) > 0:
+                    response = output.outputs[0].text
+                    assistant_response += response
+                    print(response, end="", flush=True)
+            print()
+        except KeyboardInterrupt:
+            # cancel the request
+            output_stream.cancel()
+            break

        # append the assistant message
        messages.append(Message(role="assistant", content=assistant_response))
@@ -48,7 +53,4 @@ def main():


if __name__ == "__main__":
-    try:
-        main()
-    except KeyboardInterrupt:
-        pass
+    main()
32 changes: 17 additions & 15 deletions python/scalellm/examples/async_stream_complete.py
@@ -7,32 +7,34 @@ def main():
    # start the engine loop
    engine.start()

-    prompt = input("Enter a prompt: ")
+    prompt = input("\n[Prompt]: ")
    while True:
        if prompt == "exit":
            break
        sampling_params = SamplingParams(
            temperature=0, top_p=1.0, max_tokens=100, echo=True
        )
-        output_stream = engine.schedule(
-            prompt=prompt,
-            sampling_params=sampling_params,
-            stream=True,
-        )
-        for output in output_stream:
-            if len(output.outputs) > 0:
-                print(output.outputs[0].text, end="", flush=True)
-        print()
+        try:
+            output_stream = engine.schedule(
+                prompt=prompt,
+                sampling_params=sampling_params,
+                stream=True,
+            )
+            for output in output_stream:
+                if len(output.outputs) > 0:
+                    print(output.outputs[0].text, end="", flush=True)
+            print()
+        except KeyboardInterrupt:
+            # cancel the request
+            output_stream.cancel()
+            break

        # Get the next prompt.
-        prompt = input("Enter a prompt: ")
+        prompt = input("\n[Prompt]: ")

    # stop the engine
    engine.stop()


if __name__ == "__main__":
-    try:
-        main()
-    except KeyboardInterrupt:
-        pass
+    main()
22 changes: 22 additions & 0 deletions python/scalellm/examples/cpu_offline_inference.py
@@ -0,0 +1,22 @@
from scalellm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, echo=True)

# Create an LLM.
llm = LLM(model="gpt2", devices="cpu")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
2 changes: 1 addition & 1 deletion python/scalellm/examples/offline_inference.py
@@ -12,7 +12,7 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, echo=True)

# Create an LLM.
-llm = LLM(model="gpt2")
+llm = LLM(model="gpt2", devices="cuda")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
29 changes: 29 additions & 0 deletions python/scalellm/examples/speculative_decoding.py
@@ -0,0 +1,29 @@
from scalellm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, echo=True)

# Create an LLM.
llm = LLM(
    model="google/gemma-7b",
    devices="cuda",
    draft_model="google/gemma-2b",
    draft_devices="cuda",
    num_speculative_tokens=4,
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for i, output in enumerate(outputs):
    generated_text = output.outputs[0].text
    print(f"Generated text: {generated_text!r}")
4 changes: 2 additions & 2 deletions python/scalellm/llm_engine.py
@@ -43,7 +43,7 @@ def error(self, error: str) -> bool:
    # cancel the stream
    def cancel(self) -> None:
        self._cancelled = True
-        self._queue.put_nowait(None)
+        self._queue.put_nowait(StopIteration())

    def __iter__(self):
        return self
@@ -92,7 +92,7 @@ def error(self, error: str) -> bool:
    # cancel the stream
    def cancel(self) -> None:
        self._cancelled = True
-        self._queue.put_nowait(None)
+        self._queue.put_nowait(StopAsyncIteration())

    def __aiter__(self):
        return self
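
Enqueueing StopIteration() / StopAsyncIteration() instead of a None sentinel lets the consuming loop terminate as soon as cancel() is called, without a separate null check on every item. A minimal sketch of the pattern, assuming a queue-backed stream (class and method names here are illustrative, not the library's actual implementation):

import queue

class MiniOutputStream:
    """Illustrative queue-backed stream; names are hypothetical."""

    def __init__(self):
        self._queue = queue.Queue()
        self._cancelled = False

    def put(self, item) -> bool:
        # Producer side: stop accepting new items once cancelled.
        if self._cancelled:
            return False
        self._queue.put_nowait(item)
        return True

    def cancel(self) -> None:
        self._cancelled = True
        # Enqueue the exception itself; the consumer's __next__ re-raises it,
        # which ends any `for output in stream:` loop cleanly.
        self._queue.put_nowait(StopIteration())

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()
        if isinstance(item, Exception):
            raise item
        return item

With this approach __next__ stays a single re-raise, and the async variant works the same way by pushing StopAsyncIteration() through the queue.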
3 changes: 2 additions & 1 deletion python/setup.py
@@ -220,6 +220,7 @@ def build_extension(self, ext: CMakeExtension):
    },
    classifiers=[
        "Development Status :: 3 - Alpha",
+        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
@@ -234,6 +235,6 @@ def build_extension(self, ext: CMakeExtension):
    package_data={
        "scalellm": scalellm_package_data,
    },
-    python_requires=">=3.9",
+    python_requires=">=3.8",
    install_requires=read_requirements(),
)
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,4 +1,4 @@
torch >= 2.1.0
-fastapi >= 0.110.0
huggingface_hub
-shortuuid
+fastapi
+shortuuid
