test: Unit tests illustrating Llama2 and Dolly2 invocations #584

Open · wants to merge 1 commit into base: main
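This PR adds three integration tests under tests/steamship_tests/integrations/ that exercise the replicate-llm plugin with Dolly2 (12B) and Llama2 (13B and 70B). As a minimal sketch, one of the new helpers could also be invoked outside pytest, assuming Steamship() resolves credentials from the environment or local config (the import path and entry point below are illustrative, not part of the diff):

from steamship import Steamship

from steamship_tests.integrations.test_llama2_13b import use_llama2_13b  # assumes tests/ is on PYTHONPATH

if __name__ == "__main__":
    client = Steamship()  # assumption: default constructor picks up STEAMSHIP_API_KEY / local config
    print(use_llama2_13b("Knock Knock!", client))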
tests/steamship_tests/integrations/test_dolly2_12b.py (34 additions, 0 deletions)
@@ -0,0 +1,34 @@
import pytest

from steamship import Steamship


def use_dolly2_12b(text: str, client: Steamship) -> str:
    """Dolly2 provides text generation."""
    dolly2 = client.use_plugin(
        plugin_handle="replicate-llm",
        config={
            "model_name": "replicate/dolly-v2-12b",  # Optional
            "max_tokens": 256,  # Optional
            "temperature": 0.4,  # Optional
        },
    )

    task = dolly2.generate(
        text=text,
        append_output_to_file=True,  # Persist the output so that it's stored for later
        make_output_public=True,  # Permit anyone to consume the output
    )

    task.wait()  # Wait for the generation to complete.

    output = task.output.blocks[0]  # Get the output block containing the response

    return output.text


@pytest.mark.usefixtures("client")
def test_use_dolly2(client: Steamship):
    response = use_dolly2_12b("Knock Knock!", client)
    assert response
    print(f"The 12B response is: {response}")
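Note that the client fixture referenced by @pytest.mark.usefixtures("client") is not part of this diff; it presumably comes from the existing test suite's conftest. A minimal sketch of what such a fixture could look like, assuming the default Steamship constructor picks up credentials on its own (the real fixture may instead provision a temporary workspace):

import pytest

from steamship import Steamship


@pytest.fixture()
def client() -> Steamship:
    # Assumption: Steamship() reads API credentials from the environment or local config.
    return Steamship()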
tests/steamship_tests/integrations/test_llama2_13b.py (34 additions, 0 deletions)
@@ -0,0 +1,34 @@
import pytest

from steamship import Steamship


def use_llama2_13b(text: str, client: Steamship) -> str:
    """Llama2 provides text generation."""
    llama2 = client.use_plugin(
        plugin_handle="replicate-llm",
        config={
            "model_name": "a16z-infra/llama-2-13b-chat",  # Optional
            "max_tokens": 256,  # Optional
            "temperature": 0.4,  # Optional
        },
    )

    task = llama2.generate(
        text=text,
        append_output_to_file=True,  # Persist the output so that it's stored for later
        make_output_public=True,  # Permit anyone to consume the output
    )

    task.wait()  # Wait for the generation to complete.

    output = task.output.blocks[0]  # Get the output block containing the response

    return output.text


@pytest.mark.usefixtures("client")
def test_use_llama2(client: Steamship):
    response = use_llama2_13b("Knock Knock!", client)
    assert response
    print(f"The 13B response is: {response}")
tests/steamship_tests/integrations/test_llama2_70b.py (34 additions, 0 deletions)
@@ -0,0 +1,34 @@
from steamship import Steamship


def use_llama2_70b(text: str, client: Steamship) -> str:
    """Llama2 provides text generation."""
    llama2 = client.use_plugin(
        plugin_handle="replicate-llm",
        config={
            "model_name": "replicate/llama-2-70b-chat",  # Optional
            "max_tokens": 256,  # Optional
            "temperature": 0.4,  # Optional
        },
    )

    task = llama2.generate(
        text=text,
        append_output_to_file=True,  # Persist the output so that it's stored for later
        make_output_public=True,  # Permit anyone to consume the output
    )

    task.wait()  # Wait for the generation to complete.

    output = task.output.blocks[0]  # Get the output block containing the response

    return output.text


# NOTE: In internal testing, replicate/llama-2-70b-chat consistently failed to return responses.
#
# @pytest.mark.usefixtures("client")
# def test_use_llama2(client: Steamship):
#     response = use_llama2_70b("Knock Knock!", client)
#     assert response
#     print(f"The 70B response is: {response}")
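As an alternative to commenting the 70B test out entirely, it could be kept collected but skipped, so the reason surfaces in the pytest report. A sketch using a standard pytest skip marker (not part of the diff):

import pytest


@pytest.mark.skip(reason="replicate/llama-2-70b-chat consistently failed to return responses in internal testing")
@pytest.mark.usefixtures("client")
def test_use_llama2(client: Steamship):
    response = use_llama2_70b("Knock Knock!", client)
    assert response
    print(f"The 70B response is: {response}")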