diff --git a/Source/buildimplementationrust.go b/Source/buildimplementationrust.go index 412a64b1..dbe055f0 100644 --- a/Source/buildimplementationrust.go +++ b/Source/buildimplementationrust.go @@ -76,6 +76,38 @@ func BuildImplementationRust(component ComponentDefinition, outputFolder string, return err } + IntfWrapperFileName := BaseName + "_interface_wrapper.rs" + IntfWrapperFilePath := path.Join(outputFolder, IntfWrapperFileName) + modfiles = append(modfiles, IntfWrapperFilePath) + log.Printf("Creating \"%s\"", IntfWrapperFilePath) + IntfWrapperRSFile, err := CreateLanguageFile(IntfWrapperFilePath, indentString) + if err != nil { + return err + } + IntfWrapperRSFile.WriteCLicenseHeader(component, + fmt.Sprintf("This is an autogenerated Rust implementation file in order to allow easy\ndevelopment of %s. The functions in this file need to be implemented. It needs to be generated only once.", LibraryName), + true) + err = buildRustWrapper(component, IntfWrapperRSFile, InterfaceMod) + if err != nil { + return err + } + + IntfHandleFileName := BaseName + "_interface_handle.rs" + IntfHandleFilePath := path.Join(outputFolder, IntfHandleFileName) + modfiles = append(modfiles, IntfHandleFilePath) + log.Printf("Creating \"%s\"", IntfHandleFilePath) + IntfHandleRSFile, err := CreateLanguageFile(IntfHandleFilePath, indentString) + if err != nil { + return err + } + IntfHandleRSFile.WriteCLicenseHeader(component, + fmt.Sprintf("This is an autogenerated Rust implementation file in order to allow easy\ndevelopment of %s. The functions in this file need to be implemented. It needs to be generated only once.", LibraryName), + true) + err = buildRustHandle(component, IntfHandleRSFile, InterfaceMod) + if err != nil { + return err + } + IntfWrapperStubName := path.Join(stubOutputFolder, BaseName+stubIdentifier+".rs") modfiles = append(modfiles, IntfWrapperStubName) if forceRebuild || !FileExists(IntfWrapperStubName) { @@ -344,7 +376,7 @@ func buildRustGlobalStubFile(component ComponentDefinition, w LanguageWriter, In w.Writeln("use %s::*;", InterfaceMod) w.Writeln("") w.Writeln("// Wrapper struct to implement the wrapper trait for global methods") - w.Writeln("struct CWrapper;") + w.Writeln("pub struct CWrapper;") w.Writeln("") w.Writeln("impl Wrapper for CWrapper {") w.Writeln("") @@ -474,3 +506,162 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC w.Writeln("") return nil } + +func buildRustWrapper(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error { + // Imports + ModName := strings.ToLower(component.NameSpace) + w.Writeln("") + w.Writeln("// Calls from the C-Interface to the Rust traits via the CWrapper") + w.Writeln("// These are the symbols exposed in the shared object interface") + w.Writeln("") + w.Writeln("use %s::*;", InterfaceMod) + w.Writeln("use %s::CWrapper;", ModName) + w.Writeln("use std::ffi::{c_char, CStr};") + w.Writeln("") + cprefix := ModName + "_" + // Build the global methods + err := writeGlobalRustWrapper(component, w, cprefix) + if err != nil { + return err + } + return nil +} + +func buildRustHandle(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error { + w.Writeln("") + w.Writeln("// Handle passed through interface define the casting maps needed to extract") + w.Writeln("") + w.Writeln("use %s::*;", InterfaceMod) + w.Writeln("") + w.Writeln("impl HandleImpl {") + w.AddIndentationLevel(1) + for i := 0; i < len(component.Classes); i++ { + class := component.Classes[i] + writeRustHandleAs(component, w, class, false) + writeRustHandleAs(component, w, class, true) + w.Writeln("") + } + w.AddIndentationLevel(-1) + w.Writeln("}") + return nil +} + +func writeRustHandleAs(component ComponentDefinition, w LanguageWriter, class ComponentDefinitionClass, mut bool) error { + //parents, err := getParentList(component, class) + //if err != nil { + // return err + //} + Name := class.ClassName + if !mut { + w.Writeln("pub fn as_%s(&self) -> Option<&dyn %s> {", toSnakeCase(Name), Name) + } else { + w.Writeln("pub fn as_mut_%s(&mut self) -> Option<&mut dyn %s> {", toSnakeCase(Name), Name) + } + w.AddIndentationLevel(1) + w.Writeln("None") + w.AddIndentationLevel(-1) + w.Writeln("}") + return nil +} + +func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cprefix string) error { + methods := component.Global.Methods + for i := 0; i < len(methods); i++ { + method := methods[i] + err := writeRustMethodWrapper(method, w, cprefix) + if err != nil { + return err + } + w.Writeln("") + } + return nil +} + +func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter, cprefix string) error { + // Build up the parameter strings + parameterString := "" + returnName := "" + for k := 0; k < len(method.Params); k++ { + param := method.Params[k] + RustParams, err := generateRustParameters(param, true) + if err != nil { + return err + } + for i := 0; i < len(RustParams); i++ { + RustParam := RustParams[i] + if parameterString == "" { + parameterString += fmt.Sprintf("%s : %s", RustParam.ParamName, RustParam.ParamType) + } else { + parameterString += fmt.Sprintf(", %s : %s", RustParam.ParamName, RustParam.ParamType) + } + } + } + w.Writeln("pub fn %s%s(%s) -> i32 {", cprefix, strings.ToLower(method.MethodName), parameterString) + w.AddIndentationLevel(1) + argsString := "" + for k := 0; k < len(method.Params); k++ { + param := method.Params[k] + OName, err := writeRustParameterConversionArg(param, w) + if err != nil { + return err + } + if OName != "" { + if argsString == "" { + argsString = OName + } else { + argsString += fmt.Sprintf(", %s", OName) + } + } + } + if returnName != "" { + w.Writeln("let %s = CWrapper::%s(%s);", returnName, toSnakeCase(method.MethodName), argsString) + } else { + w.Writeln("CWrapper::%s(%s);", toSnakeCase(method.MethodName), argsString) + } + w.Writeln("// All ok") + w.Writeln("0") + w.AddIndentationLevel(-1) + w.Writeln("}") + return nil +} + +func writeRustParameterConversionArg(param ComponentDefinitionParam, w LanguageWriter) (string, error) { + if param.ParamPass == "return" { + return "", nil + } + IName := toSnakeCase(param.ParamName) + OName := "_" + IName + switch param.ParamType { + case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double": + if param.ParamPass == "in" { + w.Writeln("let %s = %s;", OName, IName) + } else { + w.Writeln("let %s = unsafe {&mut *%s};", OName, IName) + } + case "class", "optionalclass": + if param.ParamPass == "in" { + HName := "_Handle_" + IName + w.Writeln("let %s = unsafe {&*%s};", HName, IName) + w.Writeln("let %s = %s.as_%s().unwrap();", OName, HName, toSnakeCase(param.ParamClass)) + } else { + HName := "_Handle_" + IName + w.Writeln("let %s = unsafe {&mut *%s};", HName, IName) + w.Writeln("let %s = %s.as_mut_%s().unwrap();", OName, HName, toSnakeCase(param.ParamClass)) + } + case "string": + if param.ParamPass == "in" { + SName := "_Str_" + IName + w.Writeln("let %s = unsafe{ CStr::from_ptr(%s) };", SName, IName) + w.Writeln("let %s = %s.to_str().unwrap();", OName, SName) + } else { + SName := "_String_" + IName + w.Writeln("let mut %s = String::new();", SName) + w.Writeln("let %s = &mut %s;", OName, SName) + } + case "bool", "pointer", "struct", "basicarray", "structarray": + //return fmt.Errorf("Conversion of type %s for parameter %s not supported", param.ParamType, IName) + default: + return "", fmt.Errorf("Conversion of type %s for parameter %s not supported as is unknown", param.ParamType, IName) + } + return OName, nil +} diff --git a/Source/languagerust.go b/Source/languagerust.go index 6d9a0303..d4a5ef86 100644 --- a/Source/languagerust.go +++ b/Source/languagerust.go @@ -47,7 +47,7 @@ func toSnakeCase(BaseType string) string { func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w LanguageWriter, NameSpace string, BaseName string) error { w.Writeln("#[allow(unused_imports)]") - w.Writeln("use std::ffi;") + w.Writeln("use std::ffi::c_void;") w.Writeln("") w.Writeln("/*************************************************************************************************************************") w.Writeln(" Version definition for %s", NameSpace) @@ -63,10 +63,28 @@ func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w Lan w.Writeln("") w.Writeln("/*************************************************************************************************************************") - w.Writeln(" Basic pointers definition for %s", NameSpace) + w.Writeln(" Handle definiton for %s", NameSpace) w.Writeln("**************************************************************************************************************************/") w.Writeln("") - w.Writeln("type Handle = std::ffi::c_void;") + w.Writeln("// Enum of all traits - this acts as a handle as we pass trait pointers through the interface") + w.Writeln("pub enum HandleImpl {") + w.AddIndentationLevel(1) + for i := 0; i < len(componentdefinition.Classes); i++ { + class := componentdefinition.Classes[i] + if i != len(componentdefinition.Classes)-1 { + w.Writeln("T%s(Box<dyn %s>),", class.ClassName, class.ClassName) + } else { + w.Writeln("T%s(Box<dyn %s>)", class.ClassName, class.ClassName) + } + } + w.AddIndentationLevel(-1) + w.Writeln("}") + w.Writeln("") + w.Writeln("pub type Handle = *mut HandleImpl;") + for i := 0; i < len(componentdefinition.Classes); i++ { + class := componentdefinition.Classes[i] + w.Writeln("pub type %sHandle = *mut HandleImpl;", class.ClassName) + } if len(componentdefinition.Enums) > 0 { w.Writeln("/*************************************************************************************************************************") @@ -211,6 +229,25 @@ func generateRustParameters(param ComponentDefinitionParam, isPlain bool) ([]Rus } if isPlain { + if param.ParamType == "string" { + if param.ParamPass == "out" { + Params = make([]RustParameter, 3) + Params[0].ParamType = "u64" + Params[0].ParamName = toSnakeCase(param.ParamName) + "_buffer_size" + Params[0].ParamComment = fmt.Sprintf("* @param[in] %s - size of the buffer (including trailing 0)", Params[0].ParamName) + + Params[1].ParamType = "*mut u64" + Params[1].ParamName = toSnakeCase(param.ParamName) + "_needed_chars" + Params[1].ParamComment = fmt.Sprintf("* @param[out] %s - will be filled with the count of the written bytes, or needed buffer size.", Params[1].ParamName) + + Params[2].ParamType = "*mut c_char" + Params[2].ParamName = toSnakeCase(param.ParamName) + "_buffer" + Params[2].ParamComment = fmt.Sprintf("* @param[out] %s - %s buffer of %s, may be NULL", Params[2].ParamName, param.ParamClass, param.ParamDescription) + + return Params, nil + } + } + if param.ParamType == "basicarray" { return nil, fmt.Errorf("Not yet handled") } @@ -231,50 +268,51 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st RustParamTypeName := "" ParamTypeName := param.ParamType ParamClass := param.ParamClass + BasicType := false switch ParamTypeName { case "uint8": RustParamTypeName = "u8" - + BasicType = true case "uint16": RustParamTypeName = "u16" - + BasicType = true case "uint32": RustParamTypeName = "u32" - + BasicType = true case "uint64": RustParamTypeName = "u64" - + BasicType = true case "int8": RustParamTypeName = "i8" - + BasicType = true case "int16": RustParamTypeName = "i16" - + BasicType = true case "int32": RustParamTypeName = "i32" - + BasicType = true case "int64": RustParamTypeName = "i64" - + BasicType = true case "bool": if isPlain { RustParamTypeName = "u8" } else { RustParamTypeName = "bool" } - + BasicType = true case "single": RustParamTypeName = "f32" - + BasicType = true case "double": RustParamTypeName = "f64" - + BasicType = true case "pointer": RustParamTypeName = "c_void" - + BasicType = true case "string": if isPlain { - RustParamTypeName = "*mut char" + RustParamTypeName = "*const c_char" } else { switch param.ParamPass { case "out": @@ -290,17 +328,12 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st if isPlain { RustParamTypeName = fmt.Sprintf("u16") } else { - switch param.ParamPass { - case "out": - RustParamTypeName = fmt.Sprintf("&mut %s", ParamClass) - case "in", "return": - RustParamTypeName = fmt.Sprintf("%s", ParamClass) - } + RustParamTypeName = ParamClass } - + BasicType = true case "functiontype": RustParamTypeName = fmt.Sprintf("%s", ParamClass) - + BasicType = true case "struct": if isPlain { RustParamTypeName = fmt.Sprintf("*mut %s", ParamClass) @@ -353,13 +386,14 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st case "class", "optionalclass": if isPlain { - RustParamTypeName = fmt.Sprintf("Handle") + RustParamTypeName = fmt.Sprintf("%sHandle", ParamClass) + BasicType = true } else { switch param.ParamPass { case "out": - RustParamTypeName = fmt.Sprintf("&mut impl %s", ParamClass) + RustParamTypeName = fmt.Sprintf("&mut dyn %s", ParamClass) case "in": - RustParamTypeName = fmt.Sprintf("& impl %s", ParamClass) + RustParamTypeName = fmt.Sprintf("& dyn %s", ParamClass) case "return": RustParamTypeName = fmt.Sprintf("Box<dyn %s>", ParamClass) } @@ -368,6 +402,16 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st default: return "", fmt.Errorf("invalid parameter type \"%s\" for Rust parameter", ParamTypeName) } - + if BasicType { + if param.ParamPass == "out" { + if isPlain { + RustParamOutTypeName := fmt.Sprintf("*mut %s", RustParamTypeName) + return RustParamOutTypeName, nil + } else { + RustParamOutTypeName := fmt.Sprintf("&mut %s", RustParamTypeName) + return RustParamOutTypeName, nil + } + } + } return RustParamTypeName, nil }