From 2ff6aed113f4c67db00b485d5cb1d1432d797e38 Mon Sep 17 00:00:00 2001 From: findolor Date: Wed, 26 Feb 2025 15:44:37 +0300 Subject: [PATCH] update macro and write tests --- crates/js_api/src/gui/new.rs | 91 +++++-- crates/macros/src/lib.rs | 286 +++++++++++++++++++-- packages/orderbook/test/js_api/new.test.ts | 42 +++ 3 files changed, 375 insertions(+), 44 deletions(-) create mode 100644 packages/orderbook/test/js_api/new.test.ts diff --git a/crates/js_api/src/gui/new.rs b/crates/js_api/src/gui/new.rs index 626482756..eca2474b6 100644 --- a/crates/js_api/src/gui/new.rs +++ b/crates/js_api/src/gui/new.rs @@ -1,42 +1,91 @@ use serde::{Deserialize, Serialize}; -use wasm_bindgen_utils::prelude::*; -use wasm_function_macro::print_fn_names; +use wasm_bindgen_utils::{impl_wasm_traits, prelude::*}; +use wasm_function_macro::impl_wasm_exports; use super::GuiError; +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Tsify)] +#[serde(rename_all = "camelCase")] +pub struct CustomError { + msg: String, + readable_msg: String, +} +impl_wasm_traits!(CustomError); + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Tsify)] +pub struct CustomResult { + data: Option, + error: Option, +} +impl_wasm_traits!(CustomResult); + +impl From> for CustomResult { + fn from(result: Result) -> Self { + match result { + Ok(data) => CustomResult { + data: Some(data), + error: None, + }, + Err(err) => CustomResult { + data: None, + error: Some(CustomError { + msg: err.to_string(), + readable_msg: err.to_string(), + }), + }, + } + } +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[wasm_bindgen] pub struct TestStruct { field: String, } -#[print_fn_names] +#[impl_wasm_exports] impl TestStruct { - pub fn result_function() -> Result { - Ok("Hello, world!".to_string()) + pub fn new(value: String) -> Self { + Self { field: value } } - pub async fn async_function() -> Result { - // Simulate an asynchronous operation with a small delay - wasm_bindgen_futures::JsFuture::from(js_sys::Promise::new(&mut |resolve, _| { - let window = web_sys::window().expect("should have a window in this context"); - let _ = window.set_timeout_with_callback_and_timeout_and_arguments_0( - &resolve, 100, // 100ms delay - ); - })) - .await - .map_err(|_| GuiError::JsError("Failed to sleep".to_string()))?; + pub fn simple_function() -> Result { Ok("Hello, world!".to_string()) } - pub fn normal_function() -> String { - "Hello, world!".to_string() + pub fn err_function() -> Result { + Err(GuiError::JsError("some error".to_string())) } - pub fn self_function(&self) -> String { - self.field.clone() + pub fn simple_function_with_self(&self) -> Result { + Ok(format!("Hello, {}!", self.field)) } - fn some_private_function() -> String { - "Hello, world!".to_string() + pub fn err_function_with_self(&self) -> Result { + Err(GuiError::JsError("some error".to_string())) } + + // pub async fn async_function() -> Result { + // // Simulate an asynchronous operation with a small delay + // wasm_bindgen_futures::JsFuture::from(js_sys::Promise::new(&mut |resolve, _| { + // let window = web_sys::window().expect("should have a window in this context"); + // let _ = window.set_timeout_with_callback_and_timeout_and_arguments_0( + // &resolve, 100, // 100ms delay + // ); + // })) + // .await + // .map_err(|_| GuiError::JsError("Failed to sleep".to_string()))?; + // Ok("Hello, world!".to_string()) + // } + + // pub fn normal_function() -> String { + // "Hello, world!".to_string() + // } + + // pub fn self_function(&self) -> String { + // self.field.clone() + // } + + // fn some_private_function() -> String { + // "Hello, world!".to_string() + // } } diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index c1863d595..e90be812a 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -1,35 +1,85 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{parse_macro_input, ImplItem, ItemImpl}; +use syn::{parse_macro_input, ImplItem, ItemImpl, Path, PathSegment, ReturnType, Type, TypePath}; #[proc_macro_attribute] -pub fn print_fn_names(_attr: TokenStream, item: TokenStream) -> TokenStream { +pub fn impl_wasm_exports(_attr: TokenStream, item: TokenStream) -> TokenStream { // Parse the input as an impl block let mut input = parse_macro_input!(item as ItemImpl); - // Transform each method to add the wasm_function attribute - input.items = input - .items - .into_iter() - .map(|item| { - if let ImplItem::Fn(mut method) = item { - // Only process public functions - if let syn::Visibility::Public(_) = method.vis { - let fn_name = &method.sig.ident; - let camel_case_name = to_camel_case(&fn_name.to_string()); - - // Add the wasm_function attribute with the camelCase name - method - .attrs - .push(syn::parse_quote!(#[wasm_bindgen(js_name = #camel_case_name)])); + // Transform each method to add a wasm export version + let mut new_items = Vec::new(); + + for item in input.items { + if let ImplItem::Fn(mut method) = item { + // Only process public functions that return Result + if let syn::Visibility::Public(_) = method.vis { + if let ReturnType::Type(_, return_type) = &method.sig.output { + // Try to extract Result inner type (will skip non-Result functions) + if let Some(inner_type) = try_extract_result_inner_type(return_type) { + let fn_name = &method.sig.ident; + let is_async = method.sig.asyncness.is_some(); + let args = collect_function_arguments(&method.sig.inputs); + + // New function logic + { + let export_fn_name = syn::Ident::new( + &format!("{}__{}", fn_name, "wasm_export"), + fn_name.span(), + ); + let camel_case_name = to_camel_case(&fn_name.to_string()); + + // Create a new function with __wasm_export suffix + let mut export_method = method.clone(); + export_method.sig.ident = export_fn_name; + + add_attributes_to_new_function(&mut export_method, &camel_case_name); + + // Create a new return type wrapped in CustomResult with the inner type + let new_return_type = if is_async { + syn::parse_quote!(-> std::pin::Pin, wasm_bindgen::JsValue>>>>) + } else { + syn::parse_quote!(-> Result, wasm_bindgen::JsValue>) + }; + + export_method.sig.output = new_return_type; + + let call_expr = create_new_function_call(&fn_name, &args); + + if is_async { + export_method.block = syn::parse_quote!({ + Box::pin(async move { + let result: CustomResult<_> = #call_expr.await.into(); + Ok(result) + }) + }); + } else { + export_method.block = syn::parse_quote!({ + let result: CustomResult<_> = #call_expr.into(); + Ok(result) + }); + } + + new_items.push(ImplItem::Fn(export_method)); + } + + // Add the skip_typescript attribute to the original method + method + .attrs + .push(syn::parse_quote!(#[wasm_bindgen(skip_typescript)])); + } } - ImplItem::Fn(method) - } else { - // Return non-method items unchanged - item } - }) - .collect(); + + // Keep the original item + new_items.push(ImplItem::Fn(method)); + } else { + // Keep the original item + new_items.push(item.clone()); + } + } + + input.items = new_items; // Generate the output with wasm_bindgen applied to the impl block let output = quote! { @@ -37,6 +87,8 @@ pub fn print_fn_names(_attr: TokenStream, item: TokenStream) -> TokenStream { #input }; + println!("{}", output); + output.into() } @@ -57,3 +109,191 @@ fn to_camel_case(name: &str) -> String { result } + +// Try to extract the inner type from a Result type, returning None if not a Result +fn try_extract_result_inner_type(return_type: &Box) -> Option<&Type> { + if let Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) = &**return_type + { + if let Some(PathSegment { + ident, arguments, .. + }) = segments.first() + { + if ident.to_string() == "Result" { + if let syn::PathArguments::AngleBracketed(args) = arguments { + if let Some(syn::GenericArgument::Type(t)) = args.args.first() { + return Some(t); + } + } + } + } + } + None +} + +fn collect_function_arguments( + inputs: &syn::punctuated::Punctuated, +) -> Vec { + inputs + .iter() + .filter_map(|arg| { + match arg { + syn::FnArg::Receiver(receiver) => { + // Handle self parameter + if receiver.reference.is_some() { + if receiver.mutability.is_some() { + Some(quote::quote! { &mut self }) + } else { + Some(quote::quote! { &self }) + } + } else { + Some(quote::quote! { self }) + } + } + syn::FnArg::Typed(pat_type) => { + // Handle named parameters + if let syn::Pat::Ident(pat_ident) = &*pat_type.pat { + Some(quote::quote! { #pat_ident }) + } else { + None + } + } + } + }) + .collect() +} + +fn add_attributes_to_new_function(method: &mut syn::ImplItemFn, camel_case_name: &str) { + // Add the allow attribute to suppress the warning + method + .attrs + .push(syn::parse_quote!(#[allow(non_snake_case)])); + + // Add the wasm_function attribute with the camelCase name + method + .attrs + .push(syn::parse_quote!(#[wasm_bindgen(js_name = #camel_case_name)])); + + // Extract the inner type from the Result return type + if let ReturnType::Type(_, return_type) = &method.sig.output { + if let Some(inner_type) = try_extract_result_inner_type(return_type) { + let ts_type = rust_type_to_ts_type(inner_type); + let return_type = format!("CustomResult<{}>", ts_type); + method.attrs.push(syn::parse_quote!( + #[wasm_bindgen(unchecked_return_type = #return_type)] + )); + } + } +} + +/// Converts a Rust type to its TypeScript equivalent for wasm_bindgen +fn rust_type_to_ts_type(rust_type: &Type) -> String { + match rust_type { + Type::Path(type_path) => { + if let Some(segment) = type_path.path.segments.last() { + let type_name = segment.ident.to_string(); + + // Handle primitive types + match type_name.as_str() { + "String" | "str" => "string".to_string(), + "bool" => "boolean".to_string(), + "u8" | "u16" | "u32" | "i8" | "i16" | "i32" | "f32" | "f64" => { + "number".to_string() + } + "u64" | "u128" | "i64" | "i128" => "bigint".to_string(), + "Vec" => { + // Handle Vec -> Array + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() + { + let inner_ts_type = rust_type_to_ts_type(inner_type); + return format!("Array<{}>", inner_ts_type); + } + } + "Array".to_string() + } + "Option" => { + // Handle Option -> T | null + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if let Some(syn::GenericArgument::Type(inner_type)) = args.args.first() + { + let inner_ts_type = rust_type_to_ts_type(inner_type); + return format!("{} | undefined", inner_ts_type); + } + } + "any | undefined".to_string() + } + "HashMap" | "BTreeMap" => { + // Handle maps -> Record + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if args.args.len() >= 2 { + if let ( + Some(syn::GenericArgument::Type(key_type)), + Some(syn::GenericArgument::Type(value_type)), + ) = (args.args.first(), args.args.get(1)) + { + let key_ts_type = rust_type_to_ts_type(key_type); + let value_ts_type = rust_type_to_ts_type(value_type); + return format!("Record<{}, {}>", key_ts_type, value_ts_type); + } + } + } + "Record".to_string() + } + // For custom types, use the type name directly + _ => type_name, + } + } else { + "any".to_string() + } + } + Type::Reference(type_ref) => { + // Handle references like &str + rust_type_to_ts_type(&type_ref.elem) + } + Type::Tuple(tuple) => { + // Handle tuples -> [T, U, ...] + if tuple.elems.is_empty() { + "null".to_string() + } else { + let ts_types: Vec = tuple + .elems + .iter() + .map(|elem| rust_type_to_ts_type(elem)) + .collect(); + format!("[{}]", ts_types.join(", ")) + } + } + Type::Array(array) => { + // Handle arrays -> Array + let inner_ts_type = rust_type_to_ts_type(&array.elem); + format!("Array<{}>", inner_ts_type) + } + // For other types, default to "any" + _ => "any".to_string(), + } +} + +fn create_new_function_call( + fn_name: &syn::Ident, + args: &[proc_macro2::TokenStream], +) -> proc_macro2::TokenStream { + if let Some(first_arg) = args.first() { + // Check if the first argument is self (indicating an instance method) + if first_arg.to_string().contains("self") { + if args.len() > 1 { + quote::quote! { self.#fn_name(#(#args),*) } + } else { + quote::quote! { self.#fn_name() } + } + } else { + // Static method call (no self) + quote::quote! { Self::#fn_name(#(#args),*) } + } + } else { + // No arguments at all, must be a static method + quote::quote! { Self::#fn_name() } + } +} diff --git a/packages/orderbook/test/js_api/new.test.ts b/packages/orderbook/test/js_api/new.test.ts new file mode 100644 index 000000000..9f10d8bc6 --- /dev/null +++ b/packages/orderbook/test/js_api/new.test.ts @@ -0,0 +1,42 @@ +import assert from 'assert'; +import { describe, it } from 'vitest'; +import { TestStruct } from '../../dist/cjs/js_api.js'; +import { CustomError } from '../../dist/types/js_api'; + +describe('TestStruct', () => { + it('should be able to call simpleFunction', () => { + const result = TestStruct.simpleFunction(); + assert.equal(result.data, 'Hello, world!'); + }); + + it('should be able to call errFunction', () => { + let result = TestStruct.errFunction(); + if (result.data) { + assert.fail('result.data should be undefined'); + } + let error = { + msg: 'JavaScript error: some error', + readableMsg: 'JavaScript error: some error' + } as CustomError; + assert.deepEqual(result.error, error); + }); + + it('should be able to call simpleFunctionWithSelf', () => { + let testStruct = TestStruct.new('beef'); + const result = testStruct.simpleFunctionWithSelf(); + assert.equal(result.data, 'Hello, beef!'); + }); + + it('should be able to call errFunctionWithSelf', () => { + let testStruct = TestStruct.new('beef'); + const result = testStruct.errFunctionWithSelf(); + if (result.data) { + assert.fail('result.data should be undefined'); + } + let error = { + msg: 'JavaScript error: some error', + readableMsg: 'JavaScript error: some error' + } as CustomError; + assert.deepEqual(result.error, error); + }); +});