From 730e2b8acba384f44124e2f8c14cc52db18dcf3b Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 31 Jul 2024 15:28:12 -0400 Subject: [PATCH] Add streaming generate content test --- .../streaming-success-code-execution.txt | 16 ++++++ .../GoogleAITests/GenerativeModelTests.swift | 53 +++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt diff --git a/Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt b/Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt new file mode 100644 index 0000000..24c9ef6 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt @@ -0,0 +1,16 @@ +data: {"candidates": [{"content": {"parts": [{"text": "Thoughts"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}} + +data: {"candidates": [{"content": {"parts": [{"text": ": I can use the `print()` function in Python to print strings. "}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 16,"totalTokenCount": 37}} + +data: {"candidates": [{"content": {"parts": [{"text": "\n\n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 16,"totalTokenCount": 37}} + +data: {"candidates": [{"content": {"parts": [{"executableCode": {"language": "PYTHON","code": "\nprint(\"Hello, world!\")\n"}}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 29,"totalTokenCount": 50}} + +data: {"candidates": [{"content": {"parts": [{"codeExecutionResult": {"outcome": "OUTCOME_OK","output": "Hello, world!\n"}}],"role": "model"},"index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 29,"totalTokenCount": 50}} + +data: {"candidates": [{"content": {"parts": [{"text": "OK"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}} + +data: {"candidates": [{"content": {"parts": [{"text": ". I have printed \"Hello, world!\" using the `print()` function in"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 17,"totalTokenCount": 38}} + +data: {"candidates": [{"content": {"parts": [{"text": " Python. \n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 19,"totalTokenCount": 40}} + diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 2822841..3dbe7cd 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -872,6 +872,59 @@ final class GenerativeModelTests: XCTestCase { })) } + func testGenerateContentStream_success_codeExecution() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-code-execution", + withExtension: "txt" + ) + let expectedTexts1 = [ + "Thoughts", + ": I can use the `print()` function in Python to print strings. ", + "\n\n", + ] + let expectedTexts2 = [ + "OK", + ". I have printed \"Hello, world!\" using the `print()` function in", + " Python. \n", + ] + let expectedTexts = Set(expectedTexts1 + expectedTexts2) + let expectedLanguage = "PYTHON" + let expectedCode = "\nprint(\"Hello, world!\")\n" + let expectedOutput = "Hello, world!\n" + + var textValues = [String]() + let stream = model.generateContentStream(testPrompt) + for try await content in stream { + let candidate = try XCTUnwrap(content.candidates.first) + let part = try XCTUnwrap(candidate.content.parts.first) + switch part { + case let .text(textPart): + XCTAssertTrue(expectedTexts.contains(textPart)) + case let .executableCode(executableCode): + XCTAssertEqual(executableCode.language, expectedLanguage) + XCTAssertEqual(executableCode.code, expectedCode) + case let .codeExecutionResult(codeExecutionResult): + XCTAssertEqual(codeExecutionResult.outcome, .ok) + XCTAssertEqual(codeExecutionResult.output, expectedOutput) + default: + XCTFail("Unexpected part type: \(part)") + } + try textValues.append(XCTUnwrap(content.text)) + } + + XCTAssertEqual(textValues.joined(separator: "\n"), """ + \(expectedTexts1.joined(separator: "\n")) + ```\(expectedLanguage.lowercased()) + \(expectedCode) + ``` + ``` + \(expectedOutput) + ``` + \(expectedTexts2.joined(separator: "\n")) + """) + } + func testGenerateContentStream_usageMetadata() async throws { MockURLProtocol .requestHandler = try httpRequestHandler(