Skip to content

Commit

Permalink
Merge pull request #34 from EliseChouleur/jni_from_rust_thread
Browse files Browse the repository at this point in the history
Be able to call JNI method from thread
  • Loading branch information
EliseChouleur authored Oct 30, 2023
2 parents 449a4b1 + afe8133 commit 063f848
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 23 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ but you can switch to `#[call_type(unchecked)]` at any time, most likely with fe

You can also force a Java type on input arguments via `#[input_type]` attribute, which can be useful for Android JNI development for example.

### Android specificities

On Android App, to call a Java class from rust the JVM use the callstack to find desired class.
But when in a rust thread, you don't have a call stack anymore.\
So to be able to call a Java class you have to pass the class reference rather than the string class path.

You can find an example of this usage in `robusta-android-example/src/thread_func.rs`

## Code example

You can find an example under `./robusta-example`. To run it you should have `java` and `javac` on your PATH and then execute:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,13 @@ public static String getAppFilesDir(Context context) {
Log.d("ROBUSTA_ANDROID_EXAMPLE", "getAppFilesDir IN");
return context.getFilesDir().toString();
}
public static int threadTestNoClass(String s) {
Log.d("ROBUSTA_ANDROID_EXAMPLE", "threadTestNoClass IN: " + s);
return 10;
}
public static int threadTestWithClass(String s) {
Log.d("ROBUSTA_ANDROID_EXAMPLE", "threadTestWithClass IN: " + s);
return 10;
}

}
58 changes: 57 additions & 1 deletion robusta-android-example/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
pub(crate) mod thread_func;

use ::jni::objects::GlobalRef;
use ::jni::JavaVM;
use robusta_jni::bridge;
use std::sync::OnceLock;

static APP_CONTEXT: OnceLock<(JavaVM, GlobalRef)> = OnceLock::new();

#[bridge]
mod jni {
use crate::APP_CONTEXT;
use android_logger::Config;
use jni::objects::JObject;
use jni::objects::{GlobalRef, JObject, JValue};
use log::info;
use robusta_jni::convert::{IntoJavaValue, Signature, TryFromJavaValue, TryIntoJavaValue};
use robusta_jni::jni::errors::Result as JniResult;
use robusta_jni::jni::objects::AutoLocal;
use robusta_jni::jni::JNIEnv;
use std::thread;

#[derive(Signature, TryIntoJavaValue, IntoJavaValue, TryFromJavaValue)]
#[package(com.example.robustaandroidexample)]
Expand All @@ -24,14 +33,61 @@ mod jni {
.with_min_level(log::Level::Debug)
.with_tag("RUST_ROBUSTA_ANDROID_EXAMPLE"),
);

info!("TEST START");
let java_class = env
.find_class("com/example/robustaandroidexample/RobustaAndroidExample")
.unwrap();
let _ = APP_CONTEXT.set((
env.get_java_vm().unwrap(),
env.new_global_ref(java_class).unwrap(),
));

let app_files_dir = RobustaAndroidExample::getAppFilesDir(env, context).unwrap();
info!("App files dir: {}", app_files_dir);

assert_eq!(
RobustaAndroidExample::threadTestNoClass(env, "test".to_string()).unwrap(),
10
);

let test_string = env.new_string("SUPER TEST").unwrap();
let test_string = JValue::from(test_string);
let met_call = env.call_static_method(
"com/example/robustaandroidexample/RobustaAndroidExample",
"threadTestNoClass",
"(Ljava/lang/String;)I",
&[test_string],
);
assert!(met_call.is_ok());

let thread_handler = thread::Builder::new()
.name("test_thread_fail".to_string())
.spawn(move || crate::thread_func::thread_test_fail());
let join_res = thread_handler.unwrap().join().unwrap();
assert!(join_res.is_err());

let thread_handler = thread::Builder::new()
.name("test_thread_good".to_string())
.spawn(move || crate::thread_func::thread_test_good());
let join_res = thread_handler.unwrap().join().unwrap();
assert!(join_res.is_ok());

info!("TEST END");
}

pub extern "java" fn getAppFilesDir(
env: &JNIEnv,
#[input_type("Landroid/content/Context;")] context: JObject,
) -> JniResult<String> {
}

pub extern "java" fn threadTestNoClass(env: &JNIEnv, s: String) -> JniResult<i32> {}
pub extern "java" fn threadTestWithClass(
env: &JNIEnv,
class_ref: &GlobalRef,
s: String,
) -> JniResult<i32> {
}
}
}
90 changes: 90 additions & 0 deletions robusta-android-example/src/thread_func.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use crate::jni::RobustaAndroidExample;
use jni::objects::JValue;
use log::{debug, error};

pub(crate) fn thread_test_fail() -> Result<(), String> {
debug!("TEST_THREAD_FAIL: start...");

let (app_vm, _) = crate::APP_CONTEXT
.get()
.ok_or_else(|| "Couldn't get APP_CONTEXT".to_string())?;
let env = app_vm
.attach_current_thread_permanently()
.map_err(|_| "Couldn't attach to current thread".to_string())?;

debug!("TEST_THREAD_FAIL: via JNI");
let test_string = env.new_string("SUPER TEST").unwrap();
let test_string = JValue::from(test_string);
if let Err(e) = env.call_static_method(
"com/example/robustaandroidexample/RobustaAndroidExample",
"threadTestNoClass",
"(Ljava/lang/String;)I",
&[test_string],
) {
error!("Couldn't call method via classic JNI: {}", e);
if env.exception_check().unwrap_or(false) {
let _ = env.exception_clear();
}
}

debug!("TEST_THREAD_FAIL: via Robusta");

/* Call methode */
if let Err(e) = RobustaAndroidExample::threadTestNoClass(&env, "test".to_string()) {
let msg = format!("Couldn't call method via Robusta: {}", e);
error!("{}", msg);
if env.exception_check().unwrap_or(false) {
let _ = env.exception_clear();
}
return Err(msg);
}
Ok(())
}

pub(crate) fn thread_test_good() -> Result<(), String> {
debug!("TEST_THREAD_GOOD: start...");

let (app_vm, class_ref) = crate::APP_CONTEXT
.get()
.ok_or_else(|| "Couldn't get APP_CONTEXT".to_string())?;
let env = app_vm
.attach_current_thread_permanently()
.map_err(|_| "Couldn't attach to current thread".to_string())?;

debug!("TEST_THREAD_GOOD: via JNI");
let test_string = env.new_string("SUPER TEST").unwrap();
let test_string = JValue::from(test_string);
if let Err(e) = env.call_static_method(
class_ref,
"threadTestNoClass",
"(Ljava/lang/String;)I",
&[test_string],
) {
error!("Couldn't call method via classic JNI: {}", e);
if env.exception_check().unwrap_or(false) {
let ex = env.exception_occurred().unwrap();
let _ = env.exception_clear();
let res = env
.call_method(ex, "toString", "()Ljava/lang/String;", &[])
.unwrap()
.l()
.unwrap();
let ex_msg: String = env.get_string(res.into()).unwrap().into();
error!("check_jni_error: {}", ex_msg);
}
}

debug!("TEST_THREAD_GOOD: via Robusta");

/* Call methode */
if let Err(e) = RobustaAndroidExample::threadTestWithClass(&env, class_ref, "test".to_string())
{
let msg = format!("Couldn't call method via Robusta: {}", e);
error!("{}", msg);
if env.exception_check().unwrap_or(false) {
let _ = env.exception_clear();
}
return Err(msg);
}
Ok(())
}
91 changes: 69 additions & 22 deletions robusta-codegen/src/transformation/imported.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use syn::{FnArg, ImplItemMethod, Lit, Pat, PatIdent, ReturnType, Signature};
use crate::transformation::context::StructContext;
use crate::transformation::utils::get_call_type;
use crate::transformation::{CallType, CallTypeAttribute, SafeParams};
use crate::utils::{get_abi, get_env_arg, is_self_method};
use crate::utils::{get_abi, get_class_arg_if_any, get_env_arg, is_self_method};
use std::collections::HashSet;

pub struct ImportedMethodTransformer<'ctx> {
Expand Down Expand Up @@ -49,7 +49,8 @@ impl<'ctx> Fold for ImportedMethodTransformer<'ctx> {

let mut original_signature = node.sig.clone();
let self_method = is_self_method(&node.sig);
let (mut signature, env_arg) = get_env_arg(node.sig.clone());
let (signature, env_arg) = get_env_arg(node.sig.clone());
let (mut signature, class_ref_arg) = get_class_arg_if_any(signature.clone());

let impl_item_attributes: Vec<_> = {
let discarded_known_attributes: HashSet<&str> = {
Expand Down Expand Up @@ -321,6 +322,20 @@ impl<'ctx> Fold for ImportedMethodTransformer<'ctx> {
h
};

let class_arg_ident = if let Some(class_ref_arg) = class_ref_arg {
match class_ref_arg {
FnArg::Typed(t) => {
match *t.pat {
Pat::Ident(PatIdent { ident, .. }) => Some(ident),
_ => panic!("non-ident pat in FnArg")
}
},
_ => panic!("Bug -- please report to library author. Expected env parameter, found receiver")
}
} else {
None
};

original_signature.inputs.iter_mut().for_each(|i| match i {
FnArg::Typed(t) => match &*t.pat {
Pat::Ident(PatIdent { ident, .. }) if ident == "self" => {}
Expand Down Expand Up @@ -373,32 +388,64 @@ impl<'ctx> Fold for ImportedMethodTransformer<'ctx> {
match call_type {
CallType::Safe(_) => {
if is_constructor {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.new_object(#java_class_path, #java_signature, &[#input_conversions]);
#return_expr
}}
if let Some(class_arg_ident) = class_arg_ident {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.new_object(#class_arg_ident, #java_signature, &[#input_conversions]);
#return_expr
}}
} else {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.new_object(#java_class_path, #java_signature, &[#input_conversions]);
#return_expr
}}
}
} else {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.call_static_method(#java_class_path, #java_method_name, #java_signature, &[#input_conversions]);
#return_expr
}}
if let Some(class_arg_ident) = class_arg_ident {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.call_static_method(#class_arg_ident, #java_method_name, #java_signature, &[#input_conversions]);
#return_expr
}}
} else {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.call_static_method(#java_class_path, #java_method_name, #java_signature, &[#input_conversions]);
#return_expr
}}
}
}
}
CallType::Unchecked(_) => {
if is_constructor {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.new_object(#java_class_path, #java_signature, &[#input_conversions]).unwrap();
#return_expr
}}
if let Some(class_arg_ident) = class_arg_ident {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.new_object(#class_arg_ident, #java_signature, &[#input_conversions]).unwrap();
#return_expr
}}
} else {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.new_object(#java_class_path, #java_signature, &[#input_conversions]).unwrap();
#return_expr
}}
}
} else {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.call_static_method(#java_class_path, #java_method_name, #java_signature, &[#input_conversions]).unwrap();
#return_expr
}}
if let Some(class_arg_ident) = class_arg_ident {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.call_static_method(#class_arg_ident, #java_method_name, #java_signature, &[#input_conversions]).unwrap();
#return_expr
}}
} else {
parse_quote! {{
let env: &'_ ::robusta_jni::jni::JNIEnv<'_> = #env_ident;
let res = env.call_static_method(#java_class_path, #java_method_name, #java_signature, &[#input_conversions]).unwrap();
#return_expr
}}
}
}
}
}
Expand Down
44 changes: 44 additions & 0 deletions robusta-codegen/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,50 @@ pub fn get_env_arg(signature: Signature) -> (Signature, Option<FnArg>) {
(transformed_signature, env_arg)
}

pub fn get_class_arg_if_any(signature: Signature) -> (Signature, Option<FnArg>) {
let has_explicit_class_ref_arg = if let Some(FnArg::Typed(PatType { ty, .. })) = signature.inputs.iter().next() {
if let Type::Reference(TypeReference { elem, .. }) = &**ty {
if let Type::Path(t) = &**elem {
let full_path: Path = parse_quote! { ::robusta_jni::jni::objects::GlobalRef };
let imported_path: Path = parse_quote! { GlobalRef };
let canonicalized_type_path = canonicalize_path(&t.path);

canonicalized_type_path == imported_path || canonicalized_type_path == full_path
} else {
false
}
} else if let Type::Path(t) = &**ty {
/* If the user has input `class_ref: GlobalRef` instead of `class_ref: &GlobalRef`, we let them know. */
let full_path: Path = parse_quote! { ::robusta_jni::jni::objects::GlobalRef };
let imported_path: Path = parse_quote! { GlobalRef };
let canonicalized_type_path = canonicalize_path(&t.path);

if canonicalized_type_path == imported_path || canonicalized_type_path == full_path {
emit_error!(t, "explicit environment parameter must be of type `&GlobalRef`");
}

false
} else {
false
}
} else {
false
};

if has_explicit_class_ref_arg {
let mut inner_signature = signature;

let mut iter = inner_signature.inputs.into_iter();
let class_arg = iter.next();

inner_signature.inputs = iter.collect();
(inner_signature, class_arg)

} else {
(signature, None)
}
}

pub fn get_abi(sig: &Signature) -> Option<String> {
sig.abi
.as_ref()
Expand Down

0 comments on commit 063f848

Please sign in to comment.