Skip to content

Commit

Permalink
[onnx] Allows to customize onnxruntime optimization level (#2137)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jul 2, 2024
1 parent 8b8a00f commit 4fafd87
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions wlm/src/main/java/ai/djl/serving/wlm/LmiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,11 @@ static void convertOnnxModel(ModelInfo<?, ?> info) throws IOException {
if (modelId == null) {
modelId = repo.toString();
}
info.modelUrl = convertOnnx(modelId).toUri().toURL().toString();
String optimization = info.prop.getProperty("option.optimization");
info.modelUrl = convertOnnx(modelId, optimization).toUri().toURL().toString();
}

private static Path convertOnnx(String modelId) throws IOException {
private static Path convertOnnx(String modelId, String optimization) throws IOException {
logger.info("Converting model to onnx artifacts");
String hash = Utils.hash(modelId);
String download = Utils.getenv("SERVING_DOWNLOAD_DIR", null);
Expand All @@ -174,6 +175,11 @@ private static Path convertOnnx(String modelId) throws IOException {

Engine onnx = Engine.getEngine("OnnxRuntime");
boolean hasCuda = onnx.getGpuCount() > 0;
if (optimization == null || optimization.isBlank()) {
optimization = hasCuda ? "O4" : "O2";
} else if (!optimization.matches("O\\d")) {
throw new IllegalArgumentException("Unsupported optimization level: " + optimization);
}

String[] cmd = {
"djl-convert",
Expand All @@ -184,7 +190,7 @@ private static Path convertOnnx(String modelId) throws IOException {
"-m",
modelId,
"--optimize",
hasCuda ? "O4" : "O2",
optimization,
"--device",
hasCuda ? "cuda" : "cpu"
};
Expand Down

0 comments on commit 4fafd87

Please sign in to comment.