Skip to content

Commit

Permalink
[JAVA_API] Wrapper for ov::PartialShape (openvinotoolkit#883)
Browse files Browse the repository at this point in the history
* Get partial shape from output

* Update PartialShape getDimension to align with C++ API

* Use pointer arithmetic to access Partial Shape dimension

Co-authored-by: Nesterov Alexander <[email protected]>

* Remove redundant delete method

* Revert "Use pointer arithmetic to access Partial Shape dimension"

---------

Co-authored-by: Nesterov Alexander <[email protected]>
Co-authored-by: Anna Likholat <[email protected]>
  • Loading branch information
3 people authored Mar 13, 2024
1 parent ec67eed commit 66557de
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 10 deletions.
6 changes: 0 additions & 6 deletions modules/java_api/src/main/cpp/dimension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,3 @@ JNIEXPORT jint JNICALL Java_org_intel_openvino_Dimension_getLength(JNIEnv *env,
)
return 0;
}

JNIEXPORT void JNICALL Java_org_intel_openvino_Dimension_delete(JNIEnv *, jobject, jlong addr)
{
Dimension *dim = (Dimension *)addr;
delete dim;
}
7 changes: 6 additions & 1 deletion modules/java_api/src/main/cpp/openvino_java.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,19 @@ extern "C"

// ov::Dimension
JNIEXPORT jint JNICALL Java_org_intel_openvino_Dimension_getLength(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_org_intel_openvino_Dimension_delete(JNIEnv *, jobject, jlong);

// ov::Output<ov::Node>
JNIEXPORT jstring JNICALL Java_org_intel_openvino_Output_GetAnyName(JNIEnv *, jobject, jlong);
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Output_GetShape(JNIEnv *, jobject, jlong);
JNIEXPORT jlong JNICALL Java_org_intel_openvino_Output_GetPartialShape(JNIEnv *, jobject, jlong);
JNIEXPORT int JNICALL Java_org_intel_openvino_Output_GetElementType(JNIEnv *, jobject, jlong);
JNIEXPORT void JNICALL Java_org_intel_openvino_Output_delete(JNIEnv *, jobject, jlong);

// ov::PartialShape
JNIEXPORT jlong JNICALL Java_org_intel_openvino_PartialShape_GetDimension(JNIEnv *, jobject, jlong, jint);
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMaxShape(JNIEnv *, jobject, jlong);
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMinShape(JNIEnv *, jobject, jlong);

#ifdef __cplusplus
}
#endif
10 changes: 10 additions & 0 deletions modules/java_api/src/main/cpp/output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ JNIEXPORT int JNICALL Java_org_intel_openvino_Output_GetElementType(JNIEnv *env,
return 0;
}

JNIEXPORT jlong JNICALL Java_org_intel_openvino_Output_GetPartialShape(JNIEnv *env, jobject obj, jlong addr) {
JNI_METHOD("GetPartialShape",
Output<Node> *output = (Output<Node> *)addr;
const PartialShape& partialShape = output->get_partial_shape();

return (jlong) &partialShape;
)
return 0;
}

JNIEXPORT void JNICALL Java_org_intel_openvino_Output_delete(JNIEnv *, jobject, jlong addr)
{
Output<Node> *obj = (Output<Node> *)addr;
Expand Down
56 changes: 56 additions & 0 deletions modules/java_api/src/main/cpp/partial_shape.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (C) 2020-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <jni.h> // JNI header provided by JDK
#include "openvino/openvino.hpp"

#include "openvino_java.hpp"
#include "jni_common.hpp"

using namespace ov;

JNIEXPORT jlong JNICALL Java_org_intel_openvino_PartialShape_GetDimension(JNIEnv *env, jobject obj, jlong addr, jint index) {
JNI_METHOD("GetDimension",
PartialShape* partial_shape = (PartialShape *)addr;
return (jlong) &(*partial_shape)[index];
)
return 0;
}

JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMaxShape(JNIEnv *env, jobject obj, jlong addr) {
JNI_METHOD("GetMaxShape",
PartialShape* partial_shape = (PartialShape *)addr;
Shape max_shape = partial_shape->get_max_shape();

jintArray result = env->NewIntArray(max_shape.size());
if (!result) {
throw std::runtime_error("Out of memory!");
} jint *arr = env->GetIntArrayElements(result, nullptr);

for (int i = 0; i < max_shape.size(); ++i)
arr[i] = max_shape[i];

env->ReleaseIntArrayElements(result, arr, 0);
return result;
)
return 0;
}

JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMinShape(JNIEnv *env, jobject obj, jlong addr) {
JNI_METHOD("GetMinShape",
PartialShape* partial_shape = (PartialShape *)addr;
Shape min_shape = partial_shape->get_min_shape();

jintArray result = env->NewIntArray(min_shape.size());
if (!result) {
throw std::runtime_error("Out of memory!");
} jint *arr = env->GetIntArrayElements(result, nullptr);

for (int i = 0; i < min_shape.size(); ++i)
arr[i] = min_shape[i];

env->ReleaseIntArrayElements(result, arr, 0);
return result;
)
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,4 @@ public int get_length() {

/*----------------------------------- native methods -----------------------------------*/
private static native int getLength(long addr);

@Override
protected native void delete(long nativeObj);
}
7 changes: 7 additions & 0 deletions modules/java_api/src/main/java/org/intel/openvino/Output.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ public int[] get_shape() {
return GetShape(nativeObj);
}

/** Returns the partial shape of the output referred to by this output handle. */
public PartialShape get_partial_shape() {
return new PartialShape(GetPartialShape(nativeObj));
}

/** Returns the element type of the output referred to by this output handle. */
public ElementType get_element_type() {
return ElementType.valueOf(GetElementType(nativeObj));
Expand All @@ -30,6 +35,8 @@ public ElementType get_element_type() {

private static native int[] GetShape(long addr);

private static native long GetPartialShape(long addr);

private static native int GetElementType(long addr);

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2020-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

package org.intel.openvino;

/** This class represents the definitions and operations about partial shape. */
public class PartialShape extends Wrapper {

public PartialShape(long addr) {
super(addr);
}

/**
* Get the dimension at specified index of a partial shape.
*
* @param index The index of dimension.
* @return The particular dimension of partial shape.
*/
public Dimension get_dimension(int index) {
return new Dimension(GetDimension(nativeObj, index));
}

/** Returns the max bounding shape. */
public int[] get_max_shape() {
return GetMaxShape(nativeObj);
}

/** Returns the min bounding shape. */
public int[] get_min_shape() {
return GetMinShape(nativeObj);
}

/*----------------------------------- native methods -----------------------------------*/
private static native long GetDimension(long addr, int index);

private static native int[] GetMaxShape(long addr);

private static native int[] GetMinShape(long addr);
}
14 changes: 14 additions & 0 deletions modules/java_api/src/test/java/org/intel/openvino/ModelTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,20 @@ public void testGetShape() {
assertArrayEquals("Shape", ref, outputs.get(0).get_shape());
}

@Test
public void testGetPartialShape() {
ArrayList<Output> outputs = net.outputs();
int[] ref = new int[] {1, 10};

PartialShape partialShape = outputs.get(0).get_partial_shape();
for (int i = 0; i < ref.length; i++) {
Dimension dim = partialShape.get_dimension(i);
assertEquals(ref[i], dim.get_length());
}
assertArrayEquals("MaxShape", ref, partialShape.get_max_shape());
assertArrayEquals("MinShape", ref, partialShape.get_min_shape());
}

@Test
public void testReshape() {
int[] inpDims = net.input().get_shape();
Expand Down

0 comments on commit 66557de

Please sign in to comment.