From fc311f7763337cda28259e3bdba96e73a2f828be Mon Sep 17 00:00:00 2001
From: Robert Goss <goss.robert@gmail.com>
Date: Mon, 25 Sep 2023 22:48:47 +0100
Subject: [PATCH] Basic conversion of parameters for global wrappers

---
 Source/buildimplementationrust.go | 193 +++++++++++++++++++++++++++++-
 Source/languagerust.go            | 100 +++++++++++-----
 2 files changed, 264 insertions(+), 29 deletions(-)

diff --git a/Source/buildimplementationrust.go b/Source/buildimplementationrust.go
index 412a64b1..dbe055f0 100644
--- a/Source/buildimplementationrust.go
+++ b/Source/buildimplementationrust.go
@@ -76,6 +76,38 @@ func BuildImplementationRust(component ComponentDefinition, outputFolder string,
 		return err
 	}
 
+	IntfWrapperFileName := BaseName + "_interface_wrapper.rs"
+	IntfWrapperFilePath := path.Join(outputFolder, IntfWrapperFileName)
+	modfiles = append(modfiles, IntfWrapperFilePath)
+	log.Printf("Creating \"%s\"", IntfWrapperFilePath)
+	IntfWrapperRSFile, err := CreateLanguageFile(IntfWrapperFilePath, indentString)
+	if err != nil {
+		return err
+	}
+	IntfWrapperRSFile.WriteCLicenseHeader(component,
+		fmt.Sprintf("This is an autogenerated Rust implementation file in order to allow easy\ndevelopment of %s. The functions in this file need to be implemented. It needs to be generated only once.", LibraryName),
+		true)
+	err = buildRustWrapper(component, IntfWrapperRSFile, InterfaceMod)
+	if err != nil {
+		return err
+	}
+
+	IntfHandleFileName := BaseName + "_interface_handle.rs"
+	IntfHandleFilePath := path.Join(outputFolder, IntfHandleFileName)
+	modfiles = append(modfiles, IntfHandleFilePath)
+	log.Printf("Creating \"%s\"", IntfHandleFilePath)
+	IntfHandleRSFile, err := CreateLanguageFile(IntfHandleFilePath, indentString)
+	if err != nil {
+		return err
+	}
+	IntfHandleRSFile.WriteCLicenseHeader(component,
+		fmt.Sprintf("This is an autogenerated Rust implementation file in order to allow easy\ndevelopment of %s. The functions in this file need to be implemented. It needs to be generated only once.", LibraryName),
+		true)
+	err = buildRustHandle(component, IntfHandleRSFile, InterfaceMod)
+	if err != nil {
+		return err
+	}
+
 	IntfWrapperStubName := path.Join(stubOutputFolder, BaseName+stubIdentifier+".rs")
 	modfiles = append(modfiles, IntfWrapperStubName)
 	if forceRebuild || !FileExists(IntfWrapperStubName) {
@@ -344,7 +376,7 @@ func buildRustGlobalStubFile(component ComponentDefinition, w LanguageWriter, In
 	w.Writeln("use %s::*;", InterfaceMod)
 	w.Writeln("")
 	w.Writeln("// Wrapper struct to implement the wrapper trait for global methods")
-	w.Writeln("struct CWrapper;")
+	w.Writeln("pub struct CWrapper;")
 	w.Writeln("")
 	w.Writeln("impl Wrapper for CWrapper {")
 	w.Writeln("")
@@ -474,3 +506,162 @@ func buildRustStubFile(component ComponentDefinition, class ComponentDefinitionC
 	w.Writeln("")
 	return nil
 }
+
+func buildRustWrapper(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error {
+	// Imports
+	ModName := strings.ToLower(component.NameSpace)
+	w.Writeln("")
+	w.Writeln("// Calls from the C-Interface to the Rust traits via the CWrapper")
+	w.Writeln("// These are the symbols exposed in the shared object interface")
+	w.Writeln("")
+	w.Writeln("use %s::*;", InterfaceMod)
+	w.Writeln("use %s::CWrapper;", ModName)
+	w.Writeln("use std::ffi::{c_char, CStr};")
+	w.Writeln("")
+	cprefix := ModName + "_"
+	// Build the global methods
+	err := writeGlobalRustWrapper(component, w, cprefix)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func buildRustHandle(component ComponentDefinition, w LanguageWriter, InterfaceMod string) error {
+	w.Writeln("")
+	w.Writeln("// Handle passed through interface define the casting maps needed to extract")
+	w.Writeln("")
+	w.Writeln("use %s::*;", InterfaceMod)
+	w.Writeln("")
+	w.Writeln("impl HandleImpl {")
+	w.AddIndentationLevel(1)
+	for i := 0; i < len(component.Classes); i++ {
+		class := component.Classes[i]
+		writeRustHandleAs(component, w, class, false)
+		writeRustHandleAs(component, w, class, true)
+		w.Writeln("")
+	}
+	w.AddIndentationLevel(-1)
+	w.Writeln("}")
+	return nil
+}
+
+func writeRustHandleAs(component ComponentDefinition, w LanguageWriter, class ComponentDefinitionClass, mut bool) error {
+	//parents, err := getParentList(component, class)
+	//if err != nil {
+	//	return err
+	//}
+	Name := class.ClassName
+	if !mut {
+		w.Writeln("pub fn as_%s(&self) -> Option<&dyn %s> {", toSnakeCase(Name), Name)
+	} else {
+		w.Writeln("pub fn as_mut_%s(&mut self) -> Option<&mut dyn %s> {", toSnakeCase(Name), Name)
+	}
+	w.AddIndentationLevel(1)
+	w.Writeln("None")
+	w.AddIndentationLevel(-1)
+	w.Writeln("}")
+	return nil
+}
+
+func writeGlobalRustWrapper(component ComponentDefinition, w LanguageWriter, cprefix string) error {
+	methods := component.Global.Methods
+	for i := 0; i < len(methods); i++ {
+		method := methods[i]
+		err := writeRustMethodWrapper(method, w, cprefix)
+		if err != nil {
+			return err
+		}
+		w.Writeln("")
+	}
+	return nil
+}
+
+func writeRustMethodWrapper(method ComponentDefinitionMethod, w LanguageWriter, cprefix string) error {
+	// Build up the parameter strings
+	parameterString := ""
+	returnName := ""
+	for k := 0; k < len(method.Params); k++ {
+		param := method.Params[k]
+		RustParams, err := generateRustParameters(param, true)
+		if err != nil {
+			return err
+		}
+		for i := 0; i < len(RustParams); i++ {
+			RustParam := RustParams[i]
+			if parameterString == "" {
+				parameterString += fmt.Sprintf("%s : %s", RustParam.ParamName, RustParam.ParamType)
+			} else {
+				parameterString += fmt.Sprintf(", %s : %s", RustParam.ParamName, RustParam.ParamType)
+			}
+		}
+	}
+	w.Writeln("pub fn %s%s(%s) -> i32 {", cprefix, strings.ToLower(method.MethodName), parameterString)
+	w.AddIndentationLevel(1)
+	argsString := ""
+	for k := 0; k < len(method.Params); k++ {
+		param := method.Params[k]
+		OName, err := writeRustParameterConversionArg(param, w)
+		if err != nil {
+			return err
+		}
+		if OName != "" {
+			if argsString == "" {
+				argsString = OName
+			} else {
+				argsString += fmt.Sprintf(", %s", OName)
+			}
+		}
+	}
+	if returnName != "" {
+		w.Writeln("let %s = CWrapper::%s(%s);", returnName, toSnakeCase(method.MethodName), argsString)
+	} else {
+		w.Writeln("CWrapper::%s(%s);", toSnakeCase(method.MethodName), argsString)
+	}
+	w.Writeln("// All ok")
+	w.Writeln("0")
+	w.AddIndentationLevel(-1)
+	w.Writeln("}")
+	return nil
+}
+
+func writeRustParameterConversionArg(param ComponentDefinitionParam, w LanguageWriter) (string, error) {
+	if param.ParamPass == "return" {
+		return "", nil
+	}
+	IName := toSnakeCase(param.ParamName)
+	OName := "_" + IName
+	switch param.ParamType {
+	case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "single", "double":
+		if param.ParamPass == "in" {
+			w.Writeln("let %s = %s;", OName, IName)
+		} else {
+			w.Writeln("let %s = unsafe {&mut *%s};", OName, IName)
+		}
+	case "class", "optionalclass":
+		if param.ParamPass == "in" {
+			HName := "_Handle_" + IName
+			w.Writeln("let %s = unsafe {&*%s};", HName, IName)
+			w.Writeln("let %s = %s.as_%s().unwrap();", OName, HName, toSnakeCase(param.ParamClass))
+		} else {
+			HName := "_Handle_" + IName
+			w.Writeln("let %s = unsafe {&mut *%s};", HName, IName)
+			w.Writeln("let %s = %s.as_mut_%s().unwrap();", OName, HName, toSnakeCase(param.ParamClass))
+		}
+	case "string":
+		if param.ParamPass == "in" {
+			SName := "_Str_" + IName
+			w.Writeln("let %s = unsafe{ CStr::from_ptr(%s) };", SName, IName)
+			w.Writeln("let %s = %s.to_str().unwrap();", OName, SName)
+		} else {
+			SName := "_String_" + IName
+			w.Writeln("let mut %s = String::new();", SName)
+			w.Writeln("let %s = &mut %s;", OName, SName)
+		}
+	case "bool", "pointer", "struct", "basicarray", "structarray":
+		//return fmt.Errorf("Conversion of type %s for parameter %s not supported", param.ParamType, IName)
+	default:
+		return "", fmt.Errorf("Conversion of type %s for parameter %s not supported as is unknown", param.ParamType, IName)
+	}
+	return OName, nil
+}
diff --git a/Source/languagerust.go b/Source/languagerust.go
index 6d9a0303..d4a5ef86 100644
--- a/Source/languagerust.go
+++ b/Source/languagerust.go
@@ -47,7 +47,7 @@ func toSnakeCase(BaseType string) string {
 
 func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w LanguageWriter, NameSpace string, BaseName string) error {
 	w.Writeln("#[allow(unused_imports)]")
-	w.Writeln("use std::ffi;")
+	w.Writeln("use std::ffi::c_void;")
 	w.Writeln("")
 	w.Writeln("/*************************************************************************************************************************")
 	w.Writeln(" Version definition for %s", NameSpace)
@@ -63,10 +63,28 @@ func writeRustBaseTypeDefinitions(componentdefinition ComponentDefinition, w Lan
 	w.Writeln("")
 
 	w.Writeln("/*************************************************************************************************************************")
-	w.Writeln(" Basic pointers definition for %s", NameSpace)
+	w.Writeln(" Handle definiton for %s", NameSpace)
 	w.Writeln("**************************************************************************************************************************/")
 	w.Writeln("")
-	w.Writeln("type Handle = std::ffi::c_void;")
+	w.Writeln("// Enum of all traits - this acts as a handle as we pass trait pointers through the interface")
+	w.Writeln("pub enum HandleImpl {")
+	w.AddIndentationLevel(1)
+	for i := 0; i < len(componentdefinition.Classes); i++ {
+		class := componentdefinition.Classes[i]
+		if i != len(componentdefinition.Classes)-1 {
+			w.Writeln("T%s(Box<dyn %s>),", class.ClassName, class.ClassName)
+		} else {
+			w.Writeln("T%s(Box<dyn %s>)", class.ClassName, class.ClassName)
+		}
+	}
+	w.AddIndentationLevel(-1)
+	w.Writeln("}")
+	w.Writeln("")
+	w.Writeln("pub type Handle = *mut HandleImpl;")
+	for i := 0; i < len(componentdefinition.Classes); i++ {
+		class := componentdefinition.Classes[i]
+		w.Writeln("pub type %sHandle = *mut HandleImpl;", class.ClassName)
+	}
 
 	if len(componentdefinition.Enums) > 0 {
 		w.Writeln("/*************************************************************************************************************************")
@@ -211,6 +229,25 @@ func generateRustParameters(param ComponentDefinitionParam, isPlain bool) ([]Rus
 	}
 
 	if isPlain {
+		if param.ParamType == "string" {
+			if param.ParamPass == "out" {
+				Params = make([]RustParameter, 3)
+				Params[0].ParamType = "u64"
+				Params[0].ParamName = toSnakeCase(param.ParamName) + "_buffer_size"
+				Params[0].ParamComment = fmt.Sprintf("* @param[in] %s - size of the buffer (including trailing 0)", Params[0].ParamName)
+
+				Params[1].ParamType = "*mut u64"
+				Params[1].ParamName = toSnakeCase(param.ParamName) + "_needed_chars"
+				Params[1].ParamComment = fmt.Sprintf("* @param[out] %s - will be filled with the count of the written bytes, or needed buffer size.", Params[1].ParamName)
+
+				Params[2].ParamType = "*mut c_char"
+				Params[2].ParamName = toSnakeCase(param.ParamName) + "_buffer"
+				Params[2].ParamComment = fmt.Sprintf("* @param[out] %s - %s buffer of %s, may be NULL", Params[2].ParamName, param.ParamClass, param.ParamDescription)
+
+				return Params, nil
+			}
+		}
+
 		if param.ParamType == "basicarray" {
 			return nil, fmt.Errorf("Not yet handled")
 		}
@@ -231,50 +268,51 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
 	RustParamTypeName := ""
 	ParamTypeName := param.ParamType
 	ParamClass := param.ParamClass
+	BasicType := false
 	switch ParamTypeName {
 	case "uint8":
 		RustParamTypeName = "u8"
-
+		BasicType = true
 	case "uint16":
 		RustParamTypeName = "u16"
-
+		BasicType = true
 	case "uint32":
 		RustParamTypeName = "u32"
-
+		BasicType = true
 	case "uint64":
 		RustParamTypeName = "u64"
-
+		BasicType = true
 	case "int8":
 		RustParamTypeName = "i8"
-
+		BasicType = true
 	case "int16":
 		RustParamTypeName = "i16"
-
+		BasicType = true
 	case "int32":
 		RustParamTypeName = "i32"
-
+		BasicType = true
 	case "int64":
 		RustParamTypeName = "i64"
-
+		BasicType = true
 	case "bool":
 		if isPlain {
 			RustParamTypeName = "u8"
 		} else {
 			RustParamTypeName = "bool"
 		}
-
+		BasicType = true
 	case "single":
 		RustParamTypeName = "f32"
-
+		BasicType = true
 	case "double":
 		RustParamTypeName = "f64"
-
+		BasicType = true
 	case "pointer":
 		RustParamTypeName = "c_void"
-
+		BasicType = true
 	case "string":
 		if isPlain {
-			RustParamTypeName = "*mut char"
+			RustParamTypeName = "*const c_char"
 		} else {
 			switch param.ParamPass {
 			case "out":
@@ -290,17 +328,12 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
 		if isPlain {
 			RustParamTypeName = fmt.Sprintf("u16")
 		} else {
-			switch param.ParamPass {
-			case "out":
-				RustParamTypeName = fmt.Sprintf("&mut %s", ParamClass)
-			case "in", "return":
-				RustParamTypeName = fmt.Sprintf("%s", ParamClass)
-			}
+			RustParamTypeName = ParamClass
 		}
-
+		BasicType = true
 	case "functiontype":
 		RustParamTypeName = fmt.Sprintf("%s", ParamClass)
-
+		BasicType = true
 	case "struct":
 		if isPlain {
 			RustParamTypeName = fmt.Sprintf("*mut %s", ParamClass)
@@ -353,13 +386,14 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
 
 	case "class", "optionalclass":
 		if isPlain {
-			RustParamTypeName = fmt.Sprintf("Handle")
+			RustParamTypeName = fmt.Sprintf("%sHandle", ParamClass)
+			BasicType = true
 		} else {
 			switch param.ParamPass {
 			case "out":
-				RustParamTypeName = fmt.Sprintf("&mut impl %s", ParamClass)
+				RustParamTypeName = fmt.Sprintf("&mut dyn %s", ParamClass)
 			case "in":
-				RustParamTypeName = fmt.Sprintf("& impl %s", ParamClass)
+				RustParamTypeName = fmt.Sprintf("& dyn %s", ParamClass)
 			case "return":
 				RustParamTypeName = fmt.Sprintf("Box<dyn %s>", ParamClass)
 			}
@@ -368,6 +402,16 @@ func generateRustParameterType(param ComponentDefinitionParam, isPlain bool) (st
 	default:
 		return "", fmt.Errorf("invalid parameter type \"%s\" for Rust parameter", ParamTypeName)
 	}
-
+	if BasicType {
+		if param.ParamPass == "out" {
+			if isPlain {
+				RustParamOutTypeName := fmt.Sprintf("*mut %s", RustParamTypeName)
+				return RustParamOutTypeName, nil
+			} else {
+				RustParamOutTypeName := fmt.Sprintf("&mut %s", RustParamTypeName)
+				return RustParamOutTypeName, nil
+			}
+		}
+	}
 	return RustParamTypeName, nil
 }