From 173cb55947a47d4d680080644fe5b03c78ee5df8 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Tue, 4 Jul 2023 10:41:46 -0700 Subject: [PATCH] [python] Add generated_text key to rolling batch output --- .../src/main/java/ai/djl/python/engine/RollingBatch.java | 5 ++++- .../src/test/java/ai/djl/python/engine/PyEngineTest.java | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java index b33b3bad81..5d387869f8 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java +++ b/engines/python/src/main/java/ai/djl/python/engine/RollingBatch.java @@ -179,7 +179,10 @@ void addResponse(String json, boolean enableStreaming) { } else { nextToken.append(element.get("data").getAsString()); if (last) { - data.appendContent(BytesSupplier.wrap(nextToken.toString()), true); + data.appendContent( + BytesSupplier.wrap( + "{\"generated_text\": [\"" + nextToken.toString() + "\"]}"), + true); } } } diff --git a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java index 445fe492a6..dd834b452e 100644 --- a/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java +++ b/engines/python/src/test/java/ai/djl/python/engine/PyEngineTest.java @@ -427,7 +427,7 @@ public void testRollingBatch() throws TranslateException, IOException, ModelExce Assert.assertNull(cbs.pollChunk()); String ret = cbs.getAsString(); System.out.println(ret); - Assert.assertTrue(ret.startsWith(" token_request4_")); + Assert.assertTrue(ret.startsWith("{\"generated_text\": [\" token_request4_")); } }