Fix PublisherBytesSupplier #905

Merged: 1 commit, Jul 5, 2023
Changed file 1 (the cache engine test):

@@ -13,6 +13,7 @@
 package ai.djl.serving.cache;
 
 import ai.djl.inference.streaming.ChunkedBytesSupplier;
+import ai.djl.inference.streaming.PublisherBytesSupplier;
 import ai.djl.modality.Output;
 
 import cloud.localstack.Localstack;
@@ -105,15 +106,14 @@ public void testS3CacheEngine() throws IOException, ExecutionException, InterruptedException
     // Helper Functions
 
     private void testCacheEngine(CacheEngine engine)
-            throws IOException, ExecutionException, InterruptedException {
+            throws ExecutionException, InterruptedException {
         Assert.assertFalse(engine.isMultiTenant());
 
         testBasic(engine);
         testStream(engine);
     }
 
-    private void testBasic(CacheEngine engine)
-            throws ExecutionException, InterruptedException, IOException {
+    private void testBasic(CacheEngine engine) throws ExecutionException, InterruptedException {
         // Test cache miss
         Output o = engine.get("none-exist-key", Integer.MAX_VALUE);
         Assert.assertNull(o);
@@ -152,21 +152,41 @@ private void testBasic(CacheEngine engine)
         engine.remove(key1);
     }
 
-    private void testStream(CacheEngine engine)
-            throws IOException, ExecutionException, InterruptedException {
+    private void testStream(CacheEngine engine) throws ExecutionException, InterruptedException {
+        // Test ChunkedBytesSupplier streaming
+        testStreamType(engine, true);
+
+        // Test PublisherBytesSupplier streaming
+        testStreamType(engine, false);
+    }
+
+    private void testStreamType(CacheEngine engine, boolean chunkedVsPublisher)
+            throws ExecutionException, InterruptedException {
         String key = engine.create();
         Output output = new Output();
         output.addProperty("x-next-token", key);
         CompletableFuture<Void> future = engine.put(key, output);
         future.get();
 
         Output output2 = new Output();
-        ChunkedBytesSupplier cbs2 = new ChunkedBytesSupplier();
-        output2.add(cbs2);
-        cbs2.appendContent(buf, false);
-        future = engine.put(key, output2);
-        for (int i = 0; i < 21; ++i) {
-            cbs2.appendContent(buf, i == 20);
+        if (chunkedVsPublisher) {
+            // Test ChunkedBytesSupplier streaming
+            ChunkedBytesSupplier cbs2 = new ChunkedBytesSupplier();
+            output2.add(cbs2);
+            cbs2.appendContent(buf, false);
+            future = engine.put(key, output2);
+            for (int i = 0; i < 21; ++i) {
+                cbs2.appendContent(buf, i == 20);
+            }
+        } else {
+            // Test PublisherBytesSupplier streaming
+            PublisherBytesSupplier pbs = new PublisherBytesSupplier();
+            output2.add(pbs);
+            pbs.appendContent(buf, false);
+            future = engine.put(key, output2);
+            for (int i = 0; i < 21; ++i) {
+                pbs.appendContent(buf, i == 20);
+            }
         }
         future.get();
 
@@ -175,6 +195,13 @@ private void testStream(CacheEngine engine)
         if (engine instanceof BaseCacheEngine) {
             // 1 for initial input, 1 for write batch
             expectedBatch = 1 + ((BaseCacheEngine) engine).getWriteBatch();
+
+            if (!chunkedVsPublisher && (!(engine instanceof MemoryCacheEngine))) {
+                // The PublisherBytesSupplier doesn't include data as part of the first item, only
+                // properties
+                // This applies when serializing, so it doesn't affect the MemoryCacheEngine
+                expectedBatch--;
+            }
         }
         Assert.assertEquals(o.getCode(), 200);
         String nextToken = o.getProperty("x-next-token", null);
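The streaming contract the hunks above exercise: the producer pushes data with appendContent(byte[], boolean lastChunk), and, for PublisherBytesSupplier, a subscriber registered via subscribe(...) receives each chunk, with a null delivery signalling completion (that is what the buf == null branch in the engine's put() below handles). A minimal stand-alone sketch of that hand-off follows, assuming the appendContent/subscribe signatures implied by this diff; the class and values are illustrative, not DJL Serving code.

import ai.djl.inference.streaming.PublisherBytesSupplier;

import java.nio.charset.StandardCharsets;

public final class SupplierContractSketch {

    public static void main(String[] args) {
        PublisherBytesSupplier pbs = new PublisherBytesSupplier();

        // Subscriber side: non-null arrays are data chunks; a null signals completion
        // (this mirrors the "if (buf == null)" branch in the engine's put()).
        pbs.subscribe(
                chunk ->
                        System.out.println(
                                chunk == null ? "done" : "chunk of " + chunk.length + " bytes"));

        // Producer side, as in the test above: append chunks, flagging the last one.
        byte[] buf = "hello".getBytes(StandardCharsets.UTF_8);
        pbs.appendContent(buf, false);
        for (int i = 0; i < 21; ++i) {
            pbs.appendContent(buf, i == 20);
        }
    }
}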
Changed file 2 (the cache engine's streaming put() implementation):

@@ -91,17 +91,18 @@ public CompletableFuture<Void> put(String key, Output output) {
             o.setMessage(output.getMessage());
             o.setProperties(output.getProperties());
             PublisherBytesSupplier pub = (PublisherBytesSupplier) supplier;
-            AtomicInteger index = new AtomicInteger();
+            AtomicInteger index = new AtomicInteger(-1);
             List<byte[]> list = new ArrayList<>(writeBatch);
+            putStream(key, o, null, index.incrementAndGet(), false);
             pub.subscribe(
                     buf -> {
                         try {
                             if (buf == null) {
                                 byte[] batch = joinBytes(list);
                                 putStream(
                                         key,
                                         null,
-                                        null,
+                                        batch,
                                         index.incrementAndGet(),
                                         true);
                             } else if (buf.length > 0) {
@@ -116,12 +117,6 @@ public CompletableFuture<Void> put(String key, Output output) {
                                             false);
                                     list.clear();
                                 }
-                            putStream(
-                                    key,
-                                    o,
-                                    buf,
-                                    index.incrementAndGet(),
-                                    false);
                             }
                         } catch (IOException e) {
                             throw new CompletionException(e);
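Read together, the hunks above change the PublisherBytesSupplier path of put() so that record 0 is written up front with the output properties only (the index counter now starts at -1 so that first write is index 0), the completion record is written with the joined bytes still in the buffer rather than with null, and the extra per-chunk putStream call in the non-null branch is dropped. The stand-alone sketch below restates that flow; the putStream/joinBytes stand-ins, the writeBatch value, and the exact batching condition in step 2 are assumptions for illustration rather than DJL Serving's code, and the real callback also wraps its body in a try/catch for IOException.

import ai.djl.inference.streaming.PublisherBytesSupplier;
import ai.djl.modality.Output;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical demo of the fixed streaming flow; the helpers below are stand-ins, not engine code. */
public final class StreamingPutSketch {

    private static final int WRITE_BATCH = 2; // stand-in for the engine's writeBatch

    public static void main(String[] args) {
        Output o = new Output();
        o.addProperty("x-next-token", "demo-key");

        PublisherBytesSupplier pub = new PublisherBytesSupplier();
        AtomicInteger index = new AtomicInteger(-1);
        List<byte[]> list = new ArrayList<>(WRITE_BATCH);

        // 1. First record carries only the properties, no data (hence the test's
        //    expectedBatch-- for serializing engines).
        putStream("demo-key", o, null, index.incrementAndGet(), false);

        pub.subscribe(
                buf -> {
                    if (buf == null) {
                        // 3. Completion: flush whatever is still buffered as the last record.
                        putStream("demo-key", null, joinBytes(list), index.incrementAndGet(), true);
                    } else if (buf.length > 0) {
                        // 2. Buffer chunks; flush a record whenever a full write batch accumulates
                        //    (the exact batching condition is inferred, not shown in the diff).
                        list.add(buf);
                        if (list.size() >= WRITE_BATCH) {
                            putStream("demo-key", null, joinBytes(list), index.incrementAndGet(), false);
                            list.clear();
                        }
                    }
                });

        byte[] chunk = "abc".getBytes(StandardCharsets.UTF_8);
        for (int i = 0; i < 5; ++i) {
            pub.appendContent(chunk, i == 4);
        }
    }

    // Stand-in for the engine's putStream(key, output, bytes, index, lastWrite).
    private static void putStream(String key, Output o, byte[] buf, int index, boolean last) {
        int len = buf == null ? 0 : buf.length;
        System.out.println(key + " record " + index + ": " + len + " bytes" + (last ? " (last)" : ""));
    }

    // Stand-in for the engine's joinBytes(list): concatenate the buffered chunks.
    private static byte[] joinBytes(List<byte[]> list) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        for (byte[] b : list) {
            bos.write(b, 0, b.length);
        }
        return bos.toByteArray();
    }
}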