Description
The ONNX Runner application works correctly with the non-quantized version of the Qwen2-0.5B-Instruct model but encounters an error when trying to use the quantized version.
Working Code (Non-Quantized Version)
The following code works correctly with the non-quantized model:
```
import ai.onnxruntime.*;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
public class App {
// private static final String MODEL_PATH = "/home/davesoma/llms/OnnxRunner/app/src/main/resources/QwenOnnxMobile/qwen_mobile.onnx";
// NOTE: this path points at the quantized model; the non-quantized .onnx file was used for the working run
private static final String MODEL_PATH = "/home/davesoma/llms/OnnxRunner/app/build/resources/main/Qwen2-0.5B-Instruct/model_q4f16.onnx";
private static final int MAX_LENGTH = 500; // Upper bound on generated tokens; adjusted for potentially longer responses
private static OrtEnvironment env;
private static OrtSession session;
public static void main(String[] args) {
String prompt = "System: You are an AI assistant. Answer the following question concisely.\n" +
"Human: What is the meaning of life?\n" +
"AI:";
try {
env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
session = env.createSession(MODEL_PATH, options);
CompletableFuture<Void> future = CompletableFuture.supplyAsync(() -> tokenizePrompt(prompt))
.thenAccept(inputIds -> generateResponse(inputIds))
.exceptionally(ex -> {
ex.printStackTrace();
return null;
});
System.out.println("AI: ");
future.join(); // Wait for the future to complete
} catch (Exception e) {
e.printStackTrace();
} finally {
cleanupResources();
}
}
private static void generateResponse(long[] inputIds) {
if (inputIds == null || inputIds.length == 0) {
System.err.println("No valid input tokens received.");
return;
}
try {
ArrayList<Long> generatedIds = Arrays.stream(inputIds).boxed()
        .collect(Collectors.toCollection(ArrayList::new));
StringBuilder generatedText = new StringBuilder();
boolean isFirstToken = true;
for (int i = 0; i < MAX_LENGTH; i++) {
Map<String, OnnxTensor> inputs = prepareInputs(generatedIds);
try (OrtSession.Result results = session.run(inputs)) {
float[][][] logits = (float[][][]) results.get(0).getValue();
long nextToken = argmax(logits[0][logits[0].length - 1]);
if (nextToken == 0) {
    break; // Treat token id 0 as end-of-sequence (model-specific assumption)
}
generatedIds.add(nextToken);
String decodedToken = decodeTokens(new long[]{nextToken});
// Heuristic spacing: decoding tokens one at a time can drop leading
// spaces, so insert one between adjacent alphanumeric characters
if (!isFirstToken) {
if (decodedToken.startsWith(" ") || decodedToken.startsWith("\n") ||
generatedText.toString().endsWith(" ") || generatedText.toString().endsWith("\n")) {
// Do nothing, space already exists
} else if (Character.isLetterOrDigit(decodedToken.charAt(0)) &&
Character.isLetterOrDigit(generatedText.charAt(generatedText.length() - 1))) {
System.out.print(" ");
generatedText.append(" ");
}
}
generatedText.append(decodedToken);
System.out.print(decodedToken);
System.out.flush();
isFirstToken = false;
}
}
System.out.println(); // New line at the end of the response
} catch (Exception e) {
e.printStackTrace();
}
}
private static Map<String, OnnxTensor> prepareInputs(ArrayList<Long> inputIds) throws OrtException {
long[] ids = inputIds.stream().mapToLong(l -> l).toArray();
OnnxTensor inputTensor = OnnxTensor.createTensor(env, new long[][]{ids});
long[] attentionMask = new long[ids.length];
Arrays.fill(attentionMask, 1);
OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(env, new long[][]{attentionMask});
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("input_ids", inputTensor);
inputs.put("attention_mask", attentionMaskTensor);
// Qwen2-0.5B-Instruct KV-cache dimensions: 24 layers, 2 key/value heads, head size 64
int numLayers = 24;
int numKeyValueHeads = 2;
int headSize = 64;
int batchSize = 1;
// Feed zero-filled past key/value tensors for every layer
// (no KV cache is carried between generation steps)
for (int i = 0; i < numLayers; i++) {
float[][][][] pastKey = new float[batchSize][numKeyValueHeads][ids.length][headSize];
float[][][][] pastValue = new float[batchSize][numKeyValueHeads][ids.length][headSize];
inputs.put(String.format("past_key_values.%d.key", i), OnnxTensor.createTensor(env, pastKey));
inputs.put(String.format("past_key_values.%d.value", i), OnnxTensor.createTensor(env, pastValue));
}
return inputs;
}
private static long argmax(float[] array) {
int maxIndex = 0;
for (int i = 1; i < array.length; i++) {
if (array[i] > array[maxIndex]) {
maxIndex = i;
}
}
return maxIndex;
}
private static long[] tokenizePrompt(String prompt) {
try {
ProcessBuilder processBuilder = new ProcessBuilder(
"/home/davesoma/llms/OnnxRunner/app/src/main/python/venv/bin/python",
"/home/davesoma/llms/OnnxRunner/app/src/main/python/Tokenizer.py",
"tokenize", prompt);
processBuilder.redirectErrorStream(true);
Process process = processBuilder.start();
StringBuilder output = new StringBuilder();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
    String line;
    while ((line = reader.readLine()) != null) {
        output.append(line);
    }
}
process.waitFor();
String jsonOutput = output.toString().trim();
if (jsonOutput.isEmpty()) {
System.err.println("No output from tokenizer script.");
return null;
}
ObjectMapper mapper = new ObjectMapper();
return mapper.readValue(jsonOutput, long[].class);
} catch (Exception e) {
e.printStackTrace();
System.err.println("Failed to tokenize prompt: " + e.getMessage());
}
return null;
}
private static String decodeTokens(long[] tokenIds) {
try {
String tokenIdsJson = new ObjectMapper().writeValueAsString(tokenIds);
ProcessBuilder processBuilder = new ProcessBuilder(
"/home/davesoma/llms/OnnxRunner/app/src/main/python/venv/bin/python",
"/home/davesoma/llms/OnnxRunner/app/src/main/python/Tokenizer.py",
"decode", tokenIdsJson);
processBuilder.redirectErrorStream(true);
Process process = processBuilder.start();
StringBuilder output = new StringBuilder();
try (BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()))) {
    String line;
    while ((line = reader.readLine()) != null) {
        output.append(line);
    }
}
process.waitFor();
return output.toString().trim();
} catch (Exception e) {
e.printStackTrace();
System.err.println("Failed to decode tokens: " + e.getMessage());
}
return null;
}
private static void cleanupResources() {
if (session != null) {
try {
session.close();
} catch (OrtException e) {
e.printStackTrace();
}
}
if (env != null) {
env.close();
}
}
}
```
Error with Quantized Version
When trying to use the quantized model (model_q4f16.onnx), the following error is encountered:
```
ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Reshape node. Name:'/model/Reshape' Status Message: /onnxruntime_src/include/onnxruntime/core/framework/op_kernel_context.h:42 const T* onnxruntime::OpKernelContext::Input(int) const [with T = onnxruntime::Tensor] Missing Input: position_ids
at ai.onnxruntime.OrtSession.run(Native Method)
at ai.onnxruntime.OrtSession.run(OrtSession.java:395)
at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
at onnxrunner.App.generateResponse(App.java:294)
at onnxrunner.App.lambda$main$1(App.java:265)
at java.base/java.util.concurrent.CompletableFuture$UniAccept.tryFire(CompletableFuture.java:718)
at java.base/java.util.concurrent.CompletableFuture.postComplete(CompletableFuture.java:510)
at java.base/java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1773)
at java.base/java.util.concurrent.CompletableFuture$AsyncSupply.exec(CompletableFuture.java:1760)
at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:387)
at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1312)
at java.base/java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1843)
at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1808)
at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:188)
```
The error suggests that the quantized model declares a "position_ids" input that the current implementation does not provide, while the non-quantized version runs without it. A sketch of the likely fix is shown below.
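If that is the case, a position_ids tensor could be fed alongside input_ids and attention_mask inside prepareInputs. This is a minimal sketch only, assuming the quantized export expects positions shaped like input_ids ([batch, sequence_length]) and numbered from 0; the input name is inferred from the error message, not verified against the model:

```
// Hypothetical addition to prepareInputs(): supply monotonically increasing
// position ids for the current sequence. The name and shape are assumptions
// inferred from the "Missing Input: position_ids" error, not confirmed
// against the actual graph.
long[] positionIds = new long[ids.length];
for (int p = 0; p < ids.length; p++) {
    positionIds[p] = p; // positions 0..seqLen-1, valid when no KV cache is reused
}
inputs.put("position_ids", OnnxTensor.createTensor(env, new long[][]{positionIds}));
```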
@Dave86ch, could you please check whether the quantized model has a "position_ids" input? If so, you need to add "position_ids" to the inputs map passed to session.run.
If you still have the issue, could you please share the non-quantized and quantized models? And how did you generate the quantized model?
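To check the declared inputs from Java, something like the following should work; it is a sketch using the standard ai.onnxruntime session metadata API, relying only on the session object from the code above:

```
// Inside a method that can throw OrtException: list every input the model
// declares, with its type/shape metadata, to confirm whether "position_ids"
// is among them.
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
    System.out.println(entry.getKey() + " -> " + entry.getValue().getInfo());
}
```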