Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/rn] Supoort New Architecture #16669

Open
wants to merge 48 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2354db1
TS: define NativeOnnxruntimeSpec
jhen0409 Jul 6, 2023
209b91f
setup codegen
jhen0409 Jul 6, 2023
efce8ed
Android: Split module implementation from oldarch/newarch
jhen0409 Jul 8, 2023
f176211
Bump @types/react-native to 0.71.3
jhen0409 Jul 8, 2023
70daa8d
Ignore unicorn/filename-case for NativeOnnxruntimeSpec
jhen0409 Jul 8, 2023
39dab48
Android: Fix incorrect dir path
jhen0409 Jul 8, 2023
8485286
E2E: Upgrade React Native to v0.71
jhen0409 Jul 10, 2023
1df0795
Android: Fix build
jhen0409 Jul 10, 2023
16cff1e
TS: Fix native spec
jhen0409 Jul 11, 2023
cff8ff2
Android: Still use legacy module for JSIHelper
jhen0409 Jul 11, 2023
e6f5bff
iOS: Support New Architecture
jhen0409 Jul 11, 2023
d99af7f
Ignore more paths
jhen0409 Jul 11, 2023
e0813d3
Format
jhen0409 Jul 11, 2023
0ce934a
Android: Fix path of E2E package name
jhen0409 Jul 11, 2023
b12e78d
Android: Fix OnnxruntimeSpec
jhen0409 Jul 11, 2023
a2f98f6
Android: Use Light style for E2E project
jhen0409 Jul 11, 2023
7dde3a6
Android: Fix path of E2E package name
jhen0409 Jul 11, 2023
e7a1c9d
Android: Always set blob module on check
jhen0409 Jul 11, 2023
c9e1714
Use v0.69 for native unit tests
jhen0409 Jul 11, 2023
61ff7ab
Android: Disable new-arch in E2E project for passed detox test
jhen0409 Jul 11, 2023
e4e389b
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 11, 2023
1d55dc4
TS: Move Binding types to native module spec (options remain {})
jhen0409 Jul 12, 2023
77f54a3
Revert unnecessary changes
jhen0409 Jul 12, 2023
f964488
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 12, 2023
ca38c81
iOS: Revert removed comments
jhen0409 Jul 12, 2023
5ba46e3
Format
jhen0409 Jul 12, 2023
200f14b
E2E: Remove local package links
jhen0409 Jul 12, 2023
7ea8b98
TS: Un-ban {} type only for NativeOnnxruntime spec
jhen0409 Jul 12, 2023
27d36df
Android: Revert rn_edit_text_material
jhen0409 Jul 12, 2023
b02c8e2
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 13, 2023
3babf7e
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Jul 29, 2023
075770f
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Aug 7, 2023
9479f46
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Aug 28, 2023
99817d6
Doc: Remove unnecessary keygen step
jhen0409 Aug 28, 2023
4560820
Android: Move more duplicated code to Onnxruntime class
jhen0409 Aug 28, 2023
b50162a
Android: Use class name as 2nd arg for init ReactModuleInfo
jhen0409 Aug 28, 2023
ba5dc5d
Android: Fix tests
jhen0409 Aug 28, 2023
0b6479e
Android: Remove unnecessary code
jhen0409 Aug 28, 2023
8599948
iOS: Remove unnecessary patch
jhen0409 Aug 28, 2023
4fd5831
Android: Use 0.71 for unit tests & fix gradle build
jhen0409 Aug 29, 2023
93c195b
iOS: Update Podfile
jhen0409 Sep 7, 2023
110b381
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Sep 7, 2023
48e0e73
Android: Fix react-android dep for RN < 0.71
jhen0409 Sep 10, 2023
0a1988a
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Sep 28, 2023
5e08a9e
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Oct 19, 2023
e1a5b8d
Merge branch 'main' into jhen-rn-new-arch
jhen0409 Nov 12, 2023
446b12b
Revert unnecessary deps change
jhen0409 Nov 12, 2023
c409719
Remove dep that added hash in lockfile
jhen0409 Nov 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion js/.eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,22 @@ module.exports = {
}
}, {
files: ['react_native/lib/**/*.ts'], rules: {
'@typescript-eslint/naming-convention': 'off'
'@typescript-eslint/naming-convention': 'off',
}
}, {
files: ['react_native/lib/NativeOnnxruntime.ts'], rules: {
'unicorn/filename-case': 'off',
'@typescript-eslint/ban-types': [
'error',
{
types: {
// NOTE: We got issue like https://github.com/facebook/react-native/issues/36431
// So we have to use `{}` type here.
'{}': false,
},
extendDefaults: true,
}
]
}
}, {
files: ['react_native/scripts/**/*.ts'], rules: {
Expand Down
2 changes: 2 additions & 0 deletions js/react_native/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ DerivedData
*.ipa
*.xcuserstate
project.xcworkspace
xcshareddata

# Android/IJ
#
.idea
.gradle
local.properties
android.iml
.cxx

# Cocoapods
#
Expand Down
24 changes: 24 additions & 0 deletions js/react_native/android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ buildscript {

apply plugin: 'com.android.library'

def isNewArchitectureEnabled() {
return rootProject.hasProperty("newArchEnabled") && rootProject.getProperty("newArchEnabled") == "true"
}

if (isNewArchitectureEnabled()) {
apply plugin: "com.facebook.react"
}

def getExtOrDefault(name) {
return rootProject.ext.has(name) ? rootProject.ext.get(name) : project.properties['OnnxruntimeModule_' + name]
}
Expand Down Expand Up @@ -90,6 +98,7 @@ android {
abiFilters (*reactNativeArchitectures())
}
}
buildConfigField "boolean", "IS_NEW_ARCHITECTURE_ENABLED", isNewArchitectureEnabled().toString()
}

if (rootProject.hasProperty("ndkPath")) {
Expand All @@ -115,6 +124,7 @@ android {
"META-INF",
"META-INF/**",
"**/libjsi.so",
"**/libc++_shared.so",
]
}

Expand All @@ -139,6 +149,12 @@ android {
} else {
java.exclude '**/OnnxruntimeExtensionsEnabled.java'
}

if (isNewArchitectureEnabled()) {
java.srcDirs += ['src/newarch']
} else {
java.srcDirs += ['src/oldarch']
}
}
}
}
Expand Down Expand Up @@ -238,3 +254,11 @@ dependencies {
implementation "com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.integration@aar"
}
}

if (isNewArchitectureEnabled()) {
react {
jsRootDir = file("../lib/")
libraryName = "OnnxruntimeSpec"
codegenJavaPackageName = "ai.onnxruntime.reactnative"
}
}
Copy link
Contributor

@YUNQIUGUO YUNQIUGUO Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

general:
I guess you mentioned that we are still using RN v0.69 for the unit tests, does it cause issue if we are using inconsistent react native versions across the project? Asking as we would really want to enable some tests to cover this new architecture, so wondering if it can cause issue when we running that in CI.

same question for the E2E test project, is it available for testing the new turbo module with the current set up?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously I was not able run the react_native/android unit tests successfully with 0.71.x.

Today I tried again, it should be easy by use react-android for >= 0.71:

  if (REACT_NATIVE_MINOR_VERSION >= 71) {
    // REACT_NATIVE_VERSION >= 0.71.x use react-android (https://mvnrepository.com/artifact/com.facebook.react/react-android)
    // See also https://github.com/facebook/react-native/blob/0.71-stable/android/README.md
    api "com.facebook.react:react-android:" + REACT_NATIVE_VERSION
  } else {
    api "com.facebook.react:react-native:" + REACT_NATIVE_VERSION
  }

For RN project, the react-native-gradle-plugin should always converted react-native to react-android, but it's not used in the unit tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. that's great!

Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public void onnxruntime_module() throws Exception {

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
ortModule.checkBlobModule();
String sessionKey = "";

// test loadModel()
Expand All @@ -82,7 +83,7 @@ public void onnxruntime_module() throws Exception {

JavaOnlyMap options = new JavaOnlyMap();
try {
ReadableMap resultMap = ortModule.loadModel(modelBuffer, options);
ReadableMap resultMap = ortModule.getOnnxruntime().loadModel(modelBuffer, options);
sessionKey = resultMap.getString("key");
ReadableArray inputNames = resultMap.getArray("inputNames");
ReadableArray outputNames = resultMap.getArray("outputNames");
Expand Down Expand Up @@ -132,7 +133,7 @@ public void onnxruntime_module() throws Exception {
options.putBoolean("encodeTensorData", true);

try {
ReadableMap resultMap = ortModule.run(sessionKey, inputDataMap, outputNames, options);
ReadableMap resultMap = ortModule.getOnnxruntime().run(sessionKey, inputDataMap, outputNames, options);

ReadableMap outputMap = resultMap.getMap("output");
for (int i = 0; i < 2; ++i) {
Expand All @@ -151,7 +152,7 @@ public void onnxruntime_module() throws Exception {
}

// test dispose
ortModule.dispose(sessionKey);
ortModule.getOnnxruntime().dispose(sessionKey);
} finally {
mockSession.finishMocking();
}
Expand All @@ -166,6 +167,7 @@ public void onnxruntime_module_append_nnapi() throws Exception {

OnnxruntimeModule ortModule = new OnnxruntimeModule(reactContext);
ortModule.blobModule = blobModule;
ortModule.checkBlobModule();
String sessionKey = "";

// test loadModel() with nnapi ep options
Expand All @@ -182,7 +184,7 @@ public void onnxruntime_module_append_nnapi() throws Exception {
options.putArray("executionProviders", epArray);

try {
ReadableMap resultMap = ortModule.loadModel(modelBuffer, options);
ReadableMap resultMap = ortModule.getOnnxruntime().loadModel(modelBuffer, options);
sessionKey = resultMap.getString("key");
ReadableArray inputNames = resultMap.getArray("inputNames");
ReadableArray outputNames = resultMap.getArray("outputNames");
Expand All @@ -195,7 +197,7 @@ public void onnxruntime_module_append_nnapi() throws Exception {
Assert.fail(e.getMessage());
}
}
ortModule.dispose(sessionKey);
ortModule.getOnnxruntime().dispose(sessionKey);
} finally {
mockSession.finishMocking();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,16 @@
import android.net.Uri;
import android.os.Build;
import android.util.Log;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.LifecycleEventListener;
import com.facebook.react.bridge.Promise;
import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContextBaseJavaModule;
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.bridge.ReadableArray;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.ReadableType;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.modules.blob.BlobModule;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
Expand All @@ -46,7 +40,7 @@
import java.util.stream.Stream;

@RequiresApi(api = Build.VERSION_CODES.N)
public class OnnxruntimeModule extends ReactContextBaseJavaModule implements LifecycleEventListener {
public class Onnxruntime {
private static ReactApplicationContext reactContext;

private static OrtEnvironment ortEnvironment = OrtEnvironment.getEnvironment();
Expand All @@ -59,104 +53,13 @@ private static String getNextSessionKey() {
return key;
}

protected BlobModule blobModule;
private BlobModule blobModule;

public OnnxruntimeModule(ReactApplicationContext context) {
super(context);
reactContext = context;
}
public Onnxruntime(ReactApplicationContext context) { reactContext = context; }

@NonNull
@Override
public String getName() {
return "Onnxruntime";
}
protected void setBlobModule(BlobModule blobModule) { this.blobModule = blobModule; }

public void checkBlobModule() {
if (blobModule == null) {
blobModule = getReactApplicationContext().getNativeModule(BlobModule.class);
if (blobModule == null) {
throw new RuntimeException("BlobModule is not initialized");
}
}
}

/**
* React native binding API to load a model using given uri.
*
* @param uri a model file location
* @param options onnxruntime session options
* @param promise output returning back to react native js
* @note the value provided to `promise` includes a key representing the session.
* when run() is called, the key must be passed into the first parameter.
*/
@ReactMethod
public void loadModel(String uri, ReadableMap options, Promise promise) {
try {
WritableMap resultMap = loadModel(uri, options);
promise.resolve(resultMap);
} catch (Exception e) {
promise.reject("Failed to load model \"" + uri + "\": " + e.getMessage(), e);
}
}

/**
* React native binding API to load a model using blob object that data stored in BlobModule.
*
* @param data the blob object
* @param options onnxruntime session options
* @param promise output returning back to react native js
* @note the value provided to `promise` includes a key representing the session.
* when run() is called, the key must be passed into the first parameter.
*/
@ReactMethod
public void loadModelFromBlob(ReadableMap data, ReadableMap options, Promise promise) {
try {
checkBlobModule();
String blobId = data.getString("blobId");
byte[] bytes = blobModule.resolve(blobId, data.getInt("offset"), data.getInt("size"));
blobModule.remove(blobId);
WritableMap resultMap = loadModel(bytes, options);
promise.resolve(resultMap);
} catch (Exception e) {
promise.reject("Failed to load model from buffer: " + e.getMessage(), e);
}
}

/**
* React native binding API to dispose a session.
*
* @param key session key representing a session given at loadModel()
* @param promise output returning back to react native js
*/
@ReactMethod
public void dispose(String key, Promise promise) {
try {
dispose(key);
promise.resolve(null);
} catch (OrtException e) {
promise.reject("Failed to dispose session: " + e.getMessage(), e);
}
}

/**
* React native binding API to run a model using given uri.
*
* @param key session key representing a session given at loadModel()
* @param input an input tensor
* @param output an output names to be returned
* @param options onnxruntime run options
* @param promise output returning back to react native js
*/
@ReactMethod
public void run(String key, ReadableMap input, ReadableArray output, ReadableMap options, Promise promise) {
try {
WritableMap resultMap = run(key, input, output, options);
promise.resolve(resultMap);
} catch (Exception e) {
promise.reject("Fail to inference: " + e.getMessage(), e);
}
}
protected Map<String, OrtSession> getSessionMap() { return sessionMap; }

/**
* Load a model from raw resource directory.
Expand Down Expand Up @@ -259,8 +162,6 @@ public WritableMap run(String key, ReadableMap input, ReadableArray output, Read

RunOptions runOptions = parseRunOptions(options);

checkBlobModule();

long startTime = System.currentTimeMillis();
Map<String, OnnxTensor> feed = new HashMap<>();
Iterator<String> iterator = ortSession.getInputNames().iterator();
Expand Down Expand Up @@ -436,22 +337,4 @@ private RunOptions parseRunOptions(ReadableMap options) throws OrtException {

return runOptions;
}

@Override
public void onHostResume() {}

@Override
public void onHostPause() {}

@Override
public void onHostDestroy() {
for (String key : sessionMap.keySet()) {
try {
dispose(key);
} catch (Exception e) {
Log.e("onHostDestroy", "Failed to dispose session: " + key, e);
}
}
sessionMap.clear();
}
}
Loading