From 85c257c85302bca781cdf4324be9411beb0737da Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Tue, 16 Apr 2024 04:07:17 +0100 Subject: [PATCH] use GCPtr Char instead of String to handle nulls in strings --- backend/src/ffi.cpp | 14 +++++- backend/src/ffi.h | 1 + backend/src/xla/client/xla_builder.cpp | 9 ++++ backend/src/xla/client/xla_computation.cpp | 9 +++- backend/src/xla/client/xla_computation.h | 3 +- backend/src/xla/pjrt/c/pjrt_c_api.cpp | 11 ++++- backend/src/xla/pjrt/c/pjrt_c_api.h | 2 +- backend/src/xla/pjrt/pjrt_executable.cpp | 10 ++--- backend/src/xla/service/BUILD | 1 + backend/src/xla/service/hlo.pb.cpp | 14 ++---- backend/src/xla/service/hlo.pb.h | 5 +-- backend/src/xla/shape.cpp | 2 +- backend/src/xla/shape_util.cpp | 4 +- src/Compiler/Eval.idr | 5 ++- src/Compiler/FFI.idr | 12 ++++-- src/Compiler/Xla/Client/XlaComputation.idr | 10 +++-- src/Compiler/Xla/PJRT/C/PJRT_C_API.idr | 50 ++++++++++++---------- 17 files changed, 102 insertions(+), 60 deletions(-) diff --git a/backend/src/ffi.cpp b/backend/src/ffi.cpp index 8de0251b4..adebb4869 100644 --- a/backend/src/ffi.cpp +++ b/backend/src/ffi.cpp @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include "ffi.h" @@ -34,10 +35,17 @@ extern "C" { } char* string_c_str(string* s) { +// std::cout << "string_c_str ..." << std::endl; auto str = reinterpret_cast(s); +// std::cout << "... s" << std::endl; +// std::cout << *str << std::endl; +// std::cout << "... s length: " << str->length() << std::endl; auto len = str->length(); auto res = (char *) malloc(len); - strncpy(res, str->c_str(), len); + std::copy(str->begin(), str->end(), res); +// std::cout << "... res" << std::endl; +// fwrite(res, sizeof(char), len, stdout); +// std::cout << std::endl; return res; } @@ -60,4 +68,8 @@ extern "C" { void set_array_int(int* arr, int idx, int value) { arr[idx] = value; } + + void set_array_ptr(void** arr, int idx, void* value) { + arr[idx] = value; + } } diff --git a/backend/src/ffi.h b/backend/src/ffi.h index 25a4e93cd..4e3a7f3a5 100644 --- a/backend/src/ffi.h +++ b/backend/src/ffi.h @@ -27,4 +27,5 @@ extern "C" { size_t size(char*); void* index(int idx, void** ptr); void set_array_int(int* arr, int idx, int value); + void set_array_ptr(void** arr, int idx, void* value); } diff --git a/backend/src/xla/client/xla_builder.cpp b/backend/src/xla/client/xla_builder.cpp index 231d21f7f..ab1944568 100644 --- a/backend/src/xla/client/xla_builder.cpp +++ b/backend/src/xla/client/xla_builder.cpp @@ -28,6 +28,15 @@ limitations under the License. #include "xla_builder.h" #include "xla_computation.h" +const char* c_string_copy(std::string str) { + char *res = NULL; + auto len = str.length(); + res = (char *) malloc(len + 1); + strncpy(res, str.c_str(), len); + res[len] = '\0'; + return res; +} + extern "C" { int sizeof_XlaOp() { return sizeof(xla::XlaOp); diff --git a/backend/src/xla/client/xla_computation.cpp b/backend/src/xla/client/xla_computation.cpp index 2d12e73fc..1ceeaab3f 100644 --- a/backend/src/xla/client/xla_computation.cpp +++ b/backend/src/xla/client/xla_computation.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "xla/client/xla_computation.h" #include "../xla_data.pb.h" +#include "../../ffi.h" #include "xla_computation.h" extern "C" { @@ -38,9 +39,13 @@ extern "C" { // } // until I work out how to handle memory of HloModuleProto - const char* XlaComputation_SerializeAsString(XlaComputation* s) { + string* XlaComputation_SerializeAsString(XlaComputation* s) { std::cout << "XlaComputation_SerializeAsString ..." << std::endl; auto s_ = reinterpret_cast(s); - return c_string_copy(s_->proto().SerializeAsString()); + auto serialized = s_->proto().SerializeAsString(); +// std::cout << "... serialized" << std::endl; +// fwrite(serialized.c_str(), sizeof(char), serialized.length(), stdout); +// std::cout << std::endl; + return reinterpret_cast(new std::string(serialized)); } } diff --git a/backend/src/xla/client/xla_computation.h b/backend/src/xla/client/xla_computation.h index 964df242e..d3aa7d978 100644 --- a/backend/src/xla/client/xla_computation.h +++ b/backend/src/xla/client/xla_computation.h @@ -16,11 +16,12 @@ limitations under the License. // we have included this as it appears to be the source of HloModuleProto, but // can't find it, so we'll rely on a transitive BUILD target #include "../service/hlo.pb.h" +#include "../../ffi.h" extern "C" { struct XlaComputation; // void XlaComputation_delete(XlaComputation* s); // const HloModuleProto& XlaComputation_proto(XlaComputation* s); - const char* XlaComputation_SerializeAsString(XlaComputation* s); + string* XlaComputation_SerializeAsString(XlaComputation* s); } diff --git a/backend/src/xla/pjrt/c/pjrt_c_api.cpp b/backend/src/xla/pjrt/c/pjrt_c_api.cpp index a916e9a32..cce723f85 100644 --- a/backend/src/xla/pjrt/c/pjrt_c_api.cpp +++ b/backend/src/xla/pjrt/c/pjrt_c_api.cpp @@ -122,14 +122,14 @@ extern "C" { return api->PJRT_Client_Create(args); } - PJRT_Program* PJRT_Program_new(char* code) { + PJRT_Program* PJRT_Program_new(char* code, size_t code_size) { std::cout << "PJRT_Program_new ..." << std::endl; auto format = pjrt::kHloFormat; return new PJRT_Program{ .struct_size = PJRT_Program_STRUCT_SIZE, .extension_start = nullptr, .code = code, - .code_size = strlen(code), + .code_size = code_size, .format = format.data(), .format_size = format.length(), }; @@ -139,6 +139,9 @@ extern "C" { PJRT_Client* client, PJRT_Program* program, char* compile_options, size_t compile_options_size ) { std::cout << "PJRT_Client_Compile_Args_new ..." << std::endl; +// std::cout << "... code" << std::endl; +// fwrite(program->code, sizeof(char), program->code_size, stdout); +// std::cout << std::endl; return new PJRT_Client_Compile_Args{ .struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE, .extension_start = nullptr, @@ -156,6 +159,10 @@ extern "C" { PJRT_Error* pjrt_client_compile(PJRT_Api* api, PJRT_Client_Compile_Args* args) { std::cout << "pjrt_client_compile ..." << std::endl; +// std::cout << "... compile_options_size " << args->compile_options_size << std::endl; +// std::cout << "... compile_options" << std::endl; +// fwrite(args->compile_options, sizeof(char), args->compile_options_size, stdout); +// std::cout << std::endl; return api->PJRT_Client_Compile(args); } diff --git a/backend/src/xla/pjrt/c/pjrt_c_api.h b/backend/src/xla/pjrt/c/pjrt_c_api.h index ec5f67111..6e059b366 100644 --- a/backend/src/xla/pjrt/c/pjrt_c_api.h +++ b/backend/src/xla/pjrt/c/pjrt_c_api.h @@ -34,7 +34,7 @@ extern "C" { PJRT_Client_Destroy_Args* PJRT_Client_Destroy_Args_new(PJRT_Client* client); PJRT_Error* pjrt_client_destroy(PJRT_Api* api, PJRT_Client_Destroy_Args* args); - PJRT_Program* PJRT_Program_new(char* code); + PJRT_Program* PJRT_Program_new(char* code, size_t code_size); PJRT_Client_Compile_Args* PJRT_Client_Compile_Args_new( PJRT_Client* client, PJRT_Program* program, char* compile_options, size_t compile_options_size ); diff --git a/backend/src/xla/pjrt/pjrt_executable.cpp b/backend/src/xla/pjrt/pjrt_executable.cpp index 316827670..396a6f41d 100644 --- a/backend/src/xla/pjrt/pjrt_executable.cpp +++ b/backend/src/xla/pjrt/pjrt_executable.cpp @@ -41,9 +41,9 @@ extern "C" { .env_option_overrides = {}, .target_config = std::nullopt, }; - xla::CompileOptions::FromProto(*(options->ToProto())); - std::cout << "... serialized options " << std::endl; - std::cout << options->ToProto()->SerializeAsString() << std::endl; +// xla::CompileOptions::FromProto(*(options->ToProto())); +// std::cout << "... serialized options " << std::endl; +// std::cout << options->ToProto()->SerializeAsString() << std::endl; return reinterpret_cast(options); } @@ -51,8 +51,8 @@ extern "C" { std::cout << "CompileOptions_SerializeAsString ..." << std::endl; auto s_ = reinterpret_cast(s); auto res = s_->ToProto()->SerializeAsString(); - std::cout << "... serialized result: " << std::endl; - std::cout << res << std::endl; +// std::cout << "... serialized result: " << std::endl; +// std::cout << res << std::endl; return reinterpret_cast(new std::string(res)); } } diff --git a/backend/src/xla/service/BUILD b/backend/src/xla/service/BUILD index 3f9b7c824..1bff36840 100644 --- a/backend/src/xla/service/BUILD +++ b/backend/src/xla/service/BUILD @@ -6,6 +6,7 @@ cc_library( hdrs = glob(["*.h"]), deps = [ "@xla//xla/service:platform_util", + "//src:src", ], visibility = ["//visibility:public"], ) diff --git a/backend/src/xla/service/hlo.pb.cpp b/backend/src/xla/service/hlo.pb.cpp index 8d2c1ca3e..8dca5101b 100644 --- a/backend/src/xla/service/hlo.pb.cpp +++ b/backend/src/xla/service/hlo.pb.cpp @@ -18,19 +18,11 @@ limitations under the License. #include "hlo.pb.h" -const char* c_string_copy(std::string str) { - char *res = NULL; - auto len = str.length(); - res = (char *) malloc(len + 1); - strncpy(res, str.c_str(), len); - res[len] = '\0'; - return res; -} - extern "C" { - const char* SerializeAsString(HloModuleProto* s) { + string* SerializeAsString(HloModuleProto* s) { std::cout << "SerializeAsString ..." << std::endl; auto s_ = reinterpret_cast(s); - return c_string_copy(s_->SerializeAsString()); + auto serialized = s_->SerializeAsString(); + return reinterpret_cast(new std::string(serialized)); } } diff --git a/backend/src/xla/service/hlo.pb.h b/backend/src/xla/service/hlo.pb.h index 918b50a1a..02305d1dc 100644 --- a/backend/src/xla/service/hlo.pb.h +++ b/backend/src/xla/service/hlo.pb.h @@ -16,11 +16,10 @@ limitations under the License. #include #include "xla/service/hlo.pb.h" - -const char* c_string_copy(std::string str); +#include "../../ffi.h" extern "C" { struct HloModuleProto; - const char* SerializeAsString(HloModuleProto* s); + string* SerializeAsString(HloModuleProto* s); } diff --git a/backend/src/xla/shape.cpp b/backend/src/xla/shape.cpp index 3981d137c..692608f78 100644 --- a/backend/src/xla/shape.cpp +++ b/backend/src/xla/shape.cpp @@ -31,7 +31,7 @@ extern "C" { void set_array_Shape(Shape* arr, int idx, Shape* shape) { std::cout << "set_array_Shape ..." << std::endl; - std::cout << "... shape " << shape << std::endl; +// std::cout << "... shape " << shape << std::endl; reinterpret_cast(arr)[idx] = *reinterpret_cast(shape); } } diff --git a/backend/src/xla/shape_util.cpp b/backend/src/xla/shape_util.cpp index 6d3514f65..5842d31c0 100644 --- a/backend/src/xla/shape_util.cpp +++ b/backend/src/xla/shape_util.cpp @@ -48,6 +48,7 @@ extern "C" { } Shape* MakeShape(int primitive_type, int* shape, int rank) { + std::cout << "MakeShape ..." << std::endl; int64_t shape64[rank]; std::copy(shape, shape + rank, shape64); @@ -56,8 +57,7 @@ extern "C" { (xla::PrimitiveType) primitive_type, absl::Span(shape64, rank) ); - std::cout << "MakeShape ..." << std::endl; - std::cout << "... xla_shape " << xla_shape << std::endl; +// std::cout << "... xla_shape " << xla_shape << std::endl; return reinterpret_cast(xla_shape); } } diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr index 945615c5c..333e5e56c 100644 --- a/src/Compiler/Eval.idr +++ b/src/Compiler/Eval.idr @@ -243,9 +243,10 @@ execute f shape = do api <- getPjrtApi -- need a gpu version client <- pjrtClientCreate api code <- serializeAsString computation - program <- mkPjrtProgram code + program <- mkPjrtProgram !(cstr code) (size code) compileOptionsStr <- serializeAsString !mkCompileOptions - loadedExec <- pjrtClientCompile api client program !(cstr compileOptionsStr) (size compileOptionsStr) + loadedExec <- pjrtClientCompile + api client program !(cstr compileOptionsStr) (size compileOptionsStr) buffer <- pjrtLoadedExecutableExecute api loadedExec literal <- allocLiteral shape pjrtBufferToHostBuffer api buffer literal diff --git a/src/Compiler/FFI.idr b/src/Compiler/FFI.idr index da9f0a737..d29d6cee4 100644 --- a/src/Compiler/FFI.idr +++ b/src/Compiler/FFI.idr @@ -34,11 +34,13 @@ namespace CppString delete = primIO . prim__stringDelete %foreign (libxla "string_c_str") -prim__stringCStr : GCAnyPtr -> PrimIO String +prim__stringCStr : GCAnyPtr -> PrimIO $ Ptr Char export -cstr : HasIO io => CppString -> io String -cstr (MkCppString str) = primIO $ prim__stringCStr str +cstr : HasIO io => CppString -> io $ GCPtr Char +cstr (MkCppString str) = do + cstr <- primIO $ prim__stringCStr str + onCollect cstr (free . prim__forgetPtr) %foreign (libxla "string_size") prim__stringSize : GCAnyPtr -> Int @@ -94,3 +96,7 @@ mkIntArray xs = do traverse_ (\(idx, x) => primIO $ prim__setArrayInt ptr (cast idx) (cast x)) (enumerate xs) ptr <- onCollect ptr (free . prim__forgetPtr) pure (MkIntArray ptr) + +export +%foreign (libxla "set_array_ptr") +prim__setArrayPtr : AnyPtr -> Int -> AnyPtr -> PrimIO () diff --git a/src/Compiler/Xla/Client/XlaComputation.idr b/src/Compiler/Xla/Client/XlaComputation.idr index a4f8630ca..d36ba44fc 100644 --- a/src/Compiler/Xla/Client/XlaComputation.idr +++ b/src/Compiler/Xla/Client/XlaComputation.idr @@ -48,9 +48,11 @@ prim__hloModuleProtoSerializeAsString : AnyPtr -> PrimIO String export %foreign (libxla "XlaComputation_SerializeAsString") -prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO String +prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO AnyPtr export -serializeAsString : HasIO io => XlaComputation -> io String -serializeAsString (MkXlaComputation computation) = - primIO $ prim__xlaComputationSerializeAsString computation +serializeAsString : HasIO io => XlaComputation -> io CppString +serializeAsString (MkXlaComputation computation) = do + str <- primIO $ prim__xlaComputationSerializeAsString computation + str <- onCollectAny str CppString.delete + pure (MkCppString str) diff --git a/src/Compiler/Xla/PJRT/C/PJRT_C_API.idr b/src/Compiler/Xla/PJRT/C/PJRT_C_API.idr index 82494b903..0d8544b3f 100644 --- a/src/Compiler/Xla/PJRT/C/PJRT_C_API.idr +++ b/src/Compiler/Xla/PJRT/C/PJRT_C_API.idr @@ -202,17 +202,17 @@ export data PjrtProgram = MkPjrtProgram GCAnyPtr %foreign (libxla "PJRT_Program_new") -prim__mkPjrtProgram : String -> PrimIO AnyPtr +prim__mkPjrtProgram : GCPtr Char -> Int -> PrimIO AnyPtr export -mkPjrtProgram : HasIO io => String -> io PjrtProgram -mkPjrtProgram code = do - ptr <- primIO $ prim__mkPjrtProgram code +mkPjrtProgram : HasIO io => GCPtr Char -> Int -> io PjrtProgram +mkPjrtProgram code codeSize = do + ptr <- primIO $ prim__mkPjrtProgram code codeSize ptr <- onCollectAny ptr free pure (MkPjrtProgram ptr) %foreign (libxla "PJRT_Client_Compile_Args_new") -prim__mkPjrtClientCompileArgs : GCAnyPtr -> GCAnyPtr -> String -> Int -> PrimIO AnyPtr +prim__mkPjrtClientCompileArgs : GCAnyPtr -> GCAnyPtr -> GCPtr Char -> Int -> PrimIO AnyPtr %foreign (libxla "PJRT_Client_Compile_Args_executable") prim__pjrtClientCompileArgsExecutable : AnyPtr -> AnyPtr @@ -234,27 +234,32 @@ pjrtClientCompile : PjrtApi -> PjrtClient -> PjrtProgram -> - String -> + GCPtr Char -> Int -> ErrIO PjrtError PjrtLoadedExecutable -pjrtClientCompile (MkPjrtApi api) (MkPjrtClient client) (MkPjrtProgram program) compileOptions compileOptionsSize = do - putStrLn "pjrtClientCompile ..." - args <- primIO $ prim__mkPjrtClientCompileArgs client program compileOptions compileOptionsSize - err <- primIO $ prim__pjrtClientCompile api args - let executable = prim__pjrtClientCompileArgsExecutable args - free args - try api err =<< do - executable <- onCollectAny executable destroyExecutable - pure $ MkPjrtLoadedExecutable executable +pjrtClientCompile + (MkPjrtApi api) + (MkPjrtClient client) + (MkPjrtProgram program) + compileOptions + compileOptionsSize = do + putStrLn "pjrtClientCompile ..." + args <- primIO $ prim__mkPjrtClientCompileArgs client program compileOptions compileOptionsSize + err <- primIO $ prim__pjrtClientCompile api args + let executable = prim__pjrtClientCompileArgsExecutable args + free args + try api err =<< do + executable <- onCollectAny executable destroyExecutable + pure $ MkPjrtLoadedExecutable executable - where + where - destroyExecutable : AnyPtr -> IO () - destroyExecutable executable = do - args <- primIO $ prim__mkPjrtLoadedExecutableDestroyArgs executable - err <- primIO $ prim__pjrtLoadedExecutableDestroy api args - free args - handleErrOnDestroy api err "PJRT_LoadedExecutable" + destroyExecutable : AnyPtr -> IO () + destroyExecutable executable = do + args <- primIO $ prim__mkPjrtLoadedExecutableDestroyArgs executable + err <- primIO $ prim__pjrtLoadedExecutableDestroy api args + free args + handleErrOnDestroy api err "PJRT_LoadedExecutable" %foreign (libxla "PJRT_ExecuteOptions_new") prim__mkPjrtExecuteOptions : PrimIO AnyPtr @@ -274,6 +279,7 @@ pjrtLoadedExecutableExecute (MkPjrtApi api) (MkPjrtLoadedExecutable executable) putStrLn "pjrtLoadedExecutableExecute ..." outputListsInner <- malloc sizeofPtr outputLists <- malloc sizeofPtr + primIO $ prim__setArrayPtr outputLists 0 outputListsInner options <- primIO prim__mkPjrtExecuteOptions args <- primIO $ prim__mkPjrtLoadedExecutableExecuteArgs executable options outputLists err <- primIO $ prim__pjrtLoadedExecutableExecute api args