Skip to content

Commit

Permalink
Add support for TensorFlow graph with "encoded_image_string_tensor" o…
Browse files Browse the repository at this point in the history
…peration
  • Loading branch information
kinhong committed Mar 12, 2020
1 parent be18615 commit 4f58120
Showing 1 changed file with 113 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
import javafx.application.Platform;
import javafx.beans.property.SimpleStringProperty;
import javafx.beans.property.StringProperty;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.ArrayUtils;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.Tensors;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
Expand All @@ -40,6 +42,7 @@
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.*;
import java.text.MessageFormat;
import java.util.*;
Expand Down Expand Up @@ -128,6 +131,7 @@ private Void update(Path path) {
model = SavedModelBundle.load(savedModelFile.getAbsolutePath(), "serve");
String message = MessageFormat.format(bundle.getString("msg.loadedSavedModel"), savedModelFile);
LOG.info(message);
printSignature(model);
Platform.runLater(() -> statusProperty.set(message));
}
}
Expand All @@ -148,17 +152,31 @@ public List<HintModel> detect(File imageFile) throws IOException {
return Collections.emptyList();
}
List<HintModel> hints = new ArrayList();
//printSignature(model);
List<Tensor<?>> outputs;
Tensor<?> input = null;
String operation = "";
BufferedImage img = ImageIO.read(imageFile);
try (Tensor<UInt8> input = makeImageTensor(img)) {
try {
if (model.graph().operation("image_tensor") != null) {
operation = "image_tensor";
input = makeImageTensor(img);
}
else if (model.graph().operation("encoded_image_string_tensor") != null) {
operation = "encoded_image_string_tensor";
input = makeImageStringTensor(imageFile);
}
outputs = model.session()
.runner()
.feed("image_tensor", input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
.runner()
.feed(operation, input)
.fetch("detection_scores")
.fetch("detection_classes")
.fetch("detection_boxes")
.run();
}
finally {
if (input != null) {
input.close();
}
}
try (Tensor<Float> scoresT = outputs.get(0).expect(Float.class);
Tensor<Float> classesT = outputs.get(1).expect(Float.class);
Expand Down Expand Up @@ -198,27 +216,30 @@ public List<HintModel> detect(File imageFile) throws IOException {
}
}

private static void printSignature(SavedModelBundle model) throws Exception {
MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
LOG.info("MODEL SIGNATURE");
LOG.info("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
LOG.info(String.format("%d of %d: %-20s (Node name in graph: %-20s, type: %s)",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype()));
}
int numOutputs = sig.getOutputsCount();
i = 1;
LOG.info("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
LOG.info(String.format("%d of %d: %-20s (Node name in graph: %-20s, type: %s)",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype()));
private static void printSignature(SavedModelBundle model) {
try {
MetaGraphDef m = MetaGraphDef.parseFrom(model.metaGraphDef());
SignatureDef sig = m.getSignatureDefOrThrow("serving_default");
int numInputs = sig.getInputsCount();
int i = 1;
LOG.info("MODEL SIGNATURE");
LOG.info("Inputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getInputsMap().entrySet()) {
TensorInfo t = entry.getValue();
LOG.info(String.format("%d of %d: %-20s (Node name in graph: %-20s, type: %s)",
i++, numInputs, entry.getKey(), t.getName(), t.getDtype()));
}
int numOutputs = sig.getOutputsCount();
i = 1;
LOG.info("Outputs:");
for (Map.Entry<String, TensorInfo> entry : sig.getOutputsMap().entrySet()) {
TensorInfo t = entry.getValue();
LOG.info(String.format("%d of %d: %-20s (Node name in graph: %-20s, type: %s)",
i++, numOutputs, entry.getKey(), t.getName(), t.getDtype()));
}
LOG.info("-----------------------------------------------");
}
LOG.info("-----------------------------------------------");
catch (Exception ex) {}
}

private static String[] loadLabels() {
Expand Down Expand Up @@ -246,7 +267,7 @@ private static void bgr2rgb(byte[] data) {
}
}

private static Tensor<UInt8> makeImageTensor(BufferedImage img) throws IOException {
private static Tensor<?> makeImageTensor(BufferedImage img) throws IOException {
if (img.getType() == BufferedImage.TYPE_BYTE_INDEXED
|| img.getType() == BufferedImage.TYPE_BYTE_BINARY
|| img.getType() == BufferedImage.TYPE_BYTE_GRAY
Expand All @@ -269,4 +290,67 @@ private static Tensor<UInt8> makeImageTensor(BufferedImage img) throws IOExcepti
long[] shape = new long[]{BATCH_SIZE, img.getHeight(), img.getWidth(), CHANNELS};
return Tensor.create(UInt8.class, shape, ByteBuffer.wrap(data));
}

/**
* See <a href="https://github.com/tensorflow/tensorflow/issues/24331#issuecomment-447523402">GitHub issue</a>
*/
private static Tensor<?> makeImageStringTensor(File imageFile) throws IOException {
var content = FileUtils.readFileToByteArray(imageFile);
byte[][] data = { content };
return Tensors.create(data);
}

/**
* See <a href="https://github.com/tensorflow/tensorflow/issues/8244#issuecomment-477854356">GitHub issue</a>
*/
private static ByteBuffer stringArrayToBuffer(String[] values) throws IOException {
long offsets[] = new long[values.length];
byte[][] data = new byte[values.length][];
int dataSize = 0;

// Convert strings to encoded bytes and calculate required data size, including a varint for each of them
for (int i = 0; i < values.length; ++i) {
byte[] byteValue = values[i].getBytes("UTF-8");
data[i] = byteValue;
int length = byteValue.length + varintLength(byteValue.length);
dataSize += length;
if (i < values.length - 1) {
offsets[i + 1] = offsets[i] + length;
}
}

// Important: buffer must follow native byte order
ByteBuffer buffer = ByteBuffer.allocate(dataSize + (offsets.length * 8)).order(ByteOrder.nativeOrder());

// First, write offsets to each elements in the buffer
for (int i = 0; i < offsets.length; ++i) {
buffer.putLong(offsets[i]);
}

// Second, write strings bytes, each preceded by its length encoded as a varint
for (int i = 0; i < data.length; ++i) {
encodeVarint(buffer, data[i].length);
buffer.put(data[i]);
}

return buffer.rewind();
}

private static void encodeVarint(ByteBuffer buffer, int value) {
int v = value;
while (v >= 0x80) {
buffer.put((byte)((v & 0x7F) | 0x80));
v >>= 7;
}
buffer.put((byte)v);
}

private static int varintLength(int length) {
int len = 1;
while (length >= 0x80) {
length >>= 7;
++len;
}
return len;
}
}

0 comments on commit 4f58120

Please sign in to comment.