Non streaming for rolling batch (#881)
sindhuvahinis authored Jun 29, 2023
1 parent 36b750b commit cfb23d8
Showing 3 changed files with 40 additions and 23 deletions.
engines/python/setup/djl_python/huggingface.py (13 additions, 13 deletions)

@@ -146,11 +146,7 @@ def initialize(self, properties: dict):
             properties.get("dtype"))
         self.rolling_batch_type = properties.get("rolling_batch", None)
 
-        if self.enable_streaming:
-            self._init_model_and_tokenizer(model_id_or_path, **kwargs)
-            self.initialized = True
-            return
-        elif self.rolling_batch_type:
+        if self.rolling_batch_type:
             self.rolling_batch_type = self.rolling_batch_type.lower()
             is_mpi = properties.get("engine") != "Python"
             if is_mpi:
@@ -163,6 +159,10 @@ def initialize(self, properties: dict):
 
             self.initialized = True
             return
+        elif self.enable_streaming:
+            self._init_model_and_tokenizer(model_id_or_path, **kwargs)
+            self.initialized = True
+            return
 
         if not task:
             task = self.infer_task_from_model_architecture(model_id_or_path)
@@ -207,7 +207,14 @@ def inference(self, inputs):
 
         outputs = Output()
 
-        if self.enable_streaming:
+        if self.rolling_batch_type:
+            result = self.rolling_batch.inference(input_data, parameters)
+            for i in range(len(batch)):
+                res = result[i]
+                outputs.add_as_json(res, batch_index=i)
+
+            return outputs
+        elif self.enable_streaming:
             outputs.add_property("content-type", "application/jsonlines")
             if self.enable_streaming == "huggingface":
                 outputs.add_stream_content(
@@ -222,13 +229,6 @@ def inference(self, inputs):
                         input_data, self.device,
                         **parameters[0]))
             return outputs
-        elif self.rolling_batch_type:
-            result = self.rolling_batch.inference(input_data, parameters)
-            for i in range(len(batch)):
-                res = result[i]
-                outputs.add_as_json(res, batch_index=i)
-
-            return outputs
 
         prediction = self.hf_pipeline(input_data, **parameters[0])
 
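After this reordering, the rolling-batch branch takes precedence over streaming in both initialize and inference when a model sets both properties. A minimal sketch of the resulting dispatch order, with hypothetical stand-in handlers (not DJL's actual methods):

# Minimal sketch of the dispatch precedence after this commit. The three
# handlers are hypothetical stand-ins; only the branch order mirrors the
# diff above.
def handle_rolling_batch(data):
    return f"rolling_batch({data})"

def handle_streaming(data):
    return f"streaming({data})"

def handle_pipeline(data):
    return f"pipeline({data})"

def dispatch(rolling_batch_type, enable_streaming, data):
    if rolling_batch_type:        # now checked first
        return handle_rolling_batch(data)
    elif enable_streaming:        # previously this branch won
        return handle_streaming(data)
    return handle_pipeline(data)

print(dispatch("auto", "true", "prompt"))  # -> rolling_batch(prompt)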
PyPredictor.java

@@ -51,10 +51,12 @@ public PyPredictor(
         this.process = process;
         this.timeout = timeout;
         isRollingBatch = model.getProperty("rolling_batch") != null;
+        boolean enableStreaming =
+                Boolean.parseBoolean(model.getProperty("enable_streaming", "false"));
         if (isRollingBatch) {
             int maxRollingBatchSize =
                     Integer.parseInt(model.getProperty("max_rolling_batch_size", "3"));
-            rollingBatch = new RollingBatch(process, maxRollingBatchSize, timeout);
+            rollingBatch = new RollingBatch(process, maxRollingBatchSize, timeout, enableStreaming);
         }
     }
 
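PyPredictor now parses enable_streaming (defaulting to false) and threads it into the RollingBatch constructor. A rough Python rendering of that property handling, assuming a plain dict in place of DJL's Model properties:

# Rough Python rendering of the property handling in PyPredictor above,
# assuming a plain dict in place of DJL's Model properties.
def read_batch_config(properties: dict):
    is_rolling_batch = properties.get("rolling_batch") is not None
    # Mirrors Boolean.parseBoolean: anything other than "true" is false.
    enable_streaming = properties.get("enable_streaming", "false").lower() == "true"
    max_rolling_batch_size = int(properties.get("max_rolling_batch_size", "3"))
    return is_rolling_batch, enable_streaming, max_rolling_batch_size

print(read_batch_config({"rolling_batch": "auto"}))
# -> (True, False, 3): streaming stays off unless explicitly enabled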
RollingBatch.java

@@ -48,11 +48,13 @@ class RollingBatch implements Runnable {
     private ReentrantLock lock;
     private Condition canAdd;
     private Condition canRead;
+    private boolean enableStreaming;
 
-    RollingBatch(PyProcess process, int maxRollingBatchSize, int timeout) {
+    RollingBatch(PyProcess process, int maxRollingBatchSize, int timeout, boolean enableStreaming) {
         this.process = process;
         this.maxRollingBatchSize = maxRollingBatchSize;
         this.timeout = timeout;
+        this.enableStreaming = enableStreaming;
         list = new ArrayList<>(3);
         lock = new ReentrantLock(true);
         canAdd = lock.newCondition();
@@ -97,7 +99,7 @@ public void run() {
                 for (int i = 0; i < size; ++i) {
                     Request status = list.get(i);
                     String json = content.get(i).getValue().getAsString();
-                    status.addResponse(json);
+                    status.addResponse(json, enableStreaming);
                 }
                 list.removeIf(status -> status.last);
                 if (list.size() < maxRollingBatchSize) {
@@ -122,7 +124,7 @@ public Output addInput(Input input, int timeout) throws TranslateException {
                     throw new TranslateException("Time out in: " + timeout);
                 }
             }
-            Request req = new Request(input);
+            Request req = new Request(input, enableStreaming);
             list.add(req);
             canRead.signal();
             return req.output;
@@ -144,28 +146,41 @@ private static final class Request {
         Input input;
         ChunkedBytesSupplier data;
         Output output;
-        String nextToken;
+        StringBuilder nextToken; // NOPMD
         boolean last;
 
-        Request(Input input) {
+        Request(Input input, boolean enableStreaming) {
             this.input = input;
             data = new ChunkedBytesSupplier();
             output = new Output();
             output.add(data);
+            if (enableStreaming) {
+                nextToken = new StringBuilder();
+            } else {
+                nextToken = new StringBuilder(1024);
+            }
         }
 
         BytesSupplier getRequest() {
-            if (nextToken != null) {
+            if (nextToken.length() != 0) {
                 return BytesSupplier.wrap("{\"inputs\": [\"\"]}");
             }
             return input.getData();
         }
 
-        void addResponse(String json) {
+        void addResponse(String json, boolean enableStreaming) {
             JsonObject element = JsonUtils.GSON.fromJson(json, JsonObject.class);
             last = element.get("last").getAsBoolean();
-            nextToken = element.get("data").getAsString();
-            data.appendContent(BytesSupplier.wrap(nextToken), last);
+            if (enableStreaming) {
+                nextToken.setLength(0);
+                nextToken.append(element.get("data").getAsString());
+                data.appendContent(BytesSupplier.wrap(nextToken.toString()), last);
+            } else {
+                nextToken.append(element.get("data").getAsString());
+                if (last) {
+                    data.appendContent(BytesSupplier.wrap(nextToken.toString()), true);
+                }
+            }
         }
     }
 }
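This Request class is the heart of the change: with enableStreaming set, each token chunk is flushed to the ChunkedBytesSupplier as it arrives; without it, chunks accumulate in the StringBuilder and the complete response is appended only when the last chunk lands. A behavioral sketch in Python, using a list in place of ChunkedBytesSupplier and the {"data": ..., "last": ...} chunk format shown in the diff:

import json

# Behavioral sketch of Request.addResponse above. `out` is a list standing
# in for ChunkedBytesSupplier; payloads mimic the {"data", "last"} JSON
# chunks the Python engine returns.
class Request:
    def __init__(self, enable_streaming):
        self.enable_streaming = enable_streaming
        self.next_token = []  # plays the role of the StringBuilder
        self.out = []         # chunks visible to the client

    def add_response(self, payload):
        element = json.loads(payload)
        last = element["last"]
        if self.enable_streaming:
            self.out.append(element["data"])         # flush every chunk
        else:
            self.next_token.append(element["data"])  # accumulate
            if last:
                self.out.append("".join(self.next_token))  # emit once

chunks = ['{"data": "Hello", "last": false}',
          '{"data": " world", "last": true}']
streaming, non_streaming = Request(True), Request(False)
for c in chunks:
    streaming.add_response(c)
    non_streaming.add_response(c)
print(streaming.out)      # ['Hello', ' world']  -- two partial chunks
print(non_streaming.out)  # ['Hello world']      -- one complete response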
