From 10e535a98348bc5b9cb024f665b6c2527867fd69 Mon Sep 17 00:00:00 2001
From: Ted Benson
Date: Thu, 12 Oct 2023 11:46:16 -0400
Subject: [PATCH] Llama and Dolly tests

---
 .../integrations/test_dolly2_12b.py | 34 +++++++++++++++++++
 .../integrations/test_llama2_13b.py | 34 +++++++++++++++++++
 .../integrations/test_llama2_70b.py | 34 +++++++++++++++++++
 3 files changed, 102 insertions(+)
 create mode 100644 tests/steamship_tests/integrations/test_dolly2_12b.py
 create mode 100644 tests/steamship_tests/integrations/test_llama2_13b.py
 create mode 100644 tests/steamship_tests/integrations/test_llama2_70b.py

diff --git a/tests/steamship_tests/integrations/test_dolly2_12b.py b/tests/steamship_tests/integrations/test_dolly2_12b.py
new file mode 100644
index 00000000..43fa69f4
--- /dev/null
+++ b/tests/steamship_tests/integrations/test_dolly2_12b.py
@@ -0,0 +1,34 @@
+import pytest
+
+from steamship import Steamship
+
+
+def use_dolly2_12b(text: str, client: Steamship) -> str:
+    """Dolly2 provides text generation."""
+    dolly2 = client.use_plugin(
+        plugin_handle="replicate-llm",
+        config={
+            "model_name": "replicate/dolly-v2-12b",  # Optional
+            "max_tokens": 256,  # Optional
+            "temperature": 0.4,  # Optional
+        },
+    )
+
+    task = dolly2.generate(
+        text=text,
+        append_output_to_file=True,  # Persist the output so it can be retrieved later
+        make_output_public=True,  # Permit anyone to consume the output
+    )
+
+    task.wait()  # Wait for the generation to complete.
+
+    output = task.output.blocks[0]  # Get the output block containing the response
+
+    return output.text
+
+
+@pytest.mark.usefixtures("client")
+def test_use_dolly2(client: Steamship):
+    response = use_dolly2_12b("Knock Knock!", client)
+    assert response
+    print(f"The 12B response is: {response}")
diff --git a/tests/steamship_tests/integrations/test_llama2_13b.py b/tests/steamship_tests/integrations/test_llama2_13b.py
new file mode 100644
index 00000000..1ad6443c
--- /dev/null
+++ b/tests/steamship_tests/integrations/test_llama2_13b.py
@@ -0,0 +1,34 @@
+import pytest
+
+from steamship import Steamship
+
+
+def use_llama2_13b(text: str, client: Steamship) -> str:
+    """Llama2 provides text generation."""
+    llama2 = client.use_plugin(
+        plugin_handle="replicate-llm",
+        config={
+            "model_name": "a16z-infra/llama-2-13b-chat",  # Optional
+            "max_tokens": 256,  # Optional
+            "temperature": 0.4,  # Optional
+        },
+    )
+
+    task = llama2.generate(
+        text=text,
+        append_output_to_file=True,  # Persist the output so it can be retrieved later
+        make_output_public=True,  # Permit anyone to consume the output
+    )
+
+    task.wait()  # Wait for the generation to complete.
+
+    output = task.output.blocks[0]  # Get the output block containing the response
+
+    return output.text
+
+
+@pytest.mark.usefixtures("client")
+def test_use_llama2(client: Steamship):
+    response = use_llama2_13b("Knock Knock!", client)
+    assert response
+    print(f"The 13B response is: {response}")
diff --git a/tests/steamship_tests/integrations/test_llama2_70b.py b/tests/steamship_tests/integrations/test_llama2_70b.py
new file mode 100644
index 00000000..ad9d016d
--- /dev/null
+++ b/tests/steamship_tests/integrations/test_llama2_70b.py
@@ -0,0 +1,34 @@
+from steamship import Steamship
+
+
+def use_llama2_70b(text: str, client: Steamship) -> str:
+    """Llama2 provides text generation."""
+    llama2 = client.use_plugin(
+        plugin_handle="replicate-llm",
+        config={
+            "model_name": "replicate/llama-2-70b-chat",  # Optional
+            "max_tokens": 256,  # Optional
+            "temperature": 0.4,  # Optional
+        },
+    )
+
+    task = llama2.generate(
+        text=text,
+        append_output_to_file=True,  # Persist the output so it can be retrieved later
+        make_output_public=True,  # Permit anyone to consume the output
+    )
+
+    task.wait()  # Wait for the generation to complete.
+
+    output = task.output.blocks[0]  # Get the output block containing the response
+
+    return output.text
+
+
+# NOTE: In internal testing, replicate/llama-2-70b-chat consistently failed to return responses.
+#
+# @pytest.mark.usefixtures("client")
+# def test_use_llama2(client: Steamship):
+#     response = use_llama2_70b("Knock Knock!", client)
+#     assert response
+#     print(f"The 70B response is: {response}")
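All three helpers share the same shape, so they can also be exercised outside pytest. The following standalone driver is a minimal sketch, not part of the patch: it assumes a Steamship API key is already configured in the environment (so the no-argument Steamship() constructor can authenticate) and that the test module is importable from wherever the script runs; adjust the import path to match your checkout.

# Standalone sketch (assumption: STEAMSHIP_API_KEY is configured in the
# environment so that Steamship() can authenticate without arguments).
from steamship import Steamship

# Assumption: run from tests/steamship_tests/integrations/, or adjust
# this import path to match your checkout.
from test_llama2_13b import use_llama2_13b

if __name__ == "__main__":
    client = Steamship()  # Picks up credentials from environment/config
    response = use_llama2_13b("Knock Knock!", client)
    print(response)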