Skip to content

Commit

Permalink
use GCPtr Char instead of String to handle nulls in strings
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Apr 16, 2024
1 parent 1cbce53 commit 85c257c
Show file tree
Hide file tree
Showing 17 changed files with 102 additions and 60 deletions.
14 changes: 13 additions & 1 deletion backend/src/ffi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include <cstddef>
#include <string>
#include <cstring>
#include <iostream>

#include "ffi.h"

Expand All @@ -34,10 +35,17 @@ extern "C" {
}

char* string_c_str(string* s) {
// std::cout << "string_c_str ..." << std::endl;
auto str = reinterpret_cast<std::string*>(s);
// std::cout << "... s" << std::endl;
// std::cout << *str << std::endl;
// std::cout << "... s length: " << str->length() << std::endl;
auto len = str->length();
auto res = (char *) malloc(len);
strncpy(res, str->c_str(), len);
std::copy(str->begin(), str->end(), res);
// std::cout << "... res" << std::endl;
// fwrite(res, sizeof(char), len, stdout);
// std::cout << std::endl;
return res;
}

Expand All @@ -60,4 +68,8 @@ extern "C" {
void set_array_int(int* arr, int idx, int value) {
arr[idx] = value;
}

void set_array_ptr(void** arr, int idx, void* value) {
arr[idx] = value;
}
}
1 change: 1 addition & 0 deletions backend/src/ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ extern "C" {
size_t size(char*);
void* index(int idx, void** ptr);
void set_array_int(int* arr, int idx, int value);
void set_array_ptr(void** arr, int idx, void* value);
}
9 changes: 9 additions & 0 deletions backend/src/xla/client/xla_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ limitations under the License.
#include "xla_builder.h"
#include "xla_computation.h"

const char* c_string_copy(std::string str) {
char *res = NULL;
auto len = str.length();
res = (char *) malloc(len + 1);
strncpy(res, str.c_str(), len);
res[len] = '\0';
return res;
}

extern "C" {
int sizeof_XlaOp() {
return sizeof(xla::XlaOp);
Expand Down
9 changes: 7 additions & 2 deletions backend/src/xla/client/xla_computation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "xla/client/xla_computation.h"

#include "../xla_data.pb.h"
#include "../../ffi.h"
#include "xla_computation.h"

extern "C" {
Expand All @@ -38,9 +39,13 @@ extern "C" {
// }

// until I work out how to handle memory of HloModuleProto
const char* XlaComputation_SerializeAsString(XlaComputation* s) {
string* XlaComputation_SerializeAsString(XlaComputation* s) {
std::cout << "XlaComputation_SerializeAsString ..." << std::endl;
auto s_ = reinterpret_cast<xla::XlaComputation*>(s);
return c_string_copy(s_->proto().SerializeAsString());
auto serialized = s_->proto().SerializeAsString();
// std::cout << "... serialized" << std::endl;
// fwrite(serialized.c_str(), sizeof(char), serialized.length(), stdout);
// std::cout << std::endl;
return reinterpret_cast<string*>(new std::string(serialized));
}
}
3 changes: 2 additions & 1 deletion backend/src/xla/client/xla_computation.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ limitations under the License.
// we have included this as it appears to be the source of HloModuleProto, but
// can't find it, so we'll rely on a transitive BUILD target
#include "../service/hlo.pb.h"
#include "../../ffi.h"

extern "C" {
struct XlaComputation;

// void XlaComputation_delete(XlaComputation* s);
// const HloModuleProto& XlaComputation_proto(XlaComputation* s);
const char* XlaComputation_SerializeAsString(XlaComputation* s);
string* XlaComputation_SerializeAsString(XlaComputation* s);
}
11 changes: 9 additions & 2 deletions backend/src/xla/pjrt/c/pjrt_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,14 @@ extern "C" {
return api->PJRT_Client_Create(args);
}

PJRT_Program* PJRT_Program_new(char* code) {
PJRT_Program* PJRT_Program_new(char* code, size_t code_size) {
std::cout << "PJRT_Program_new ..." << std::endl;
auto format = pjrt::kHloFormat;
return new PJRT_Program{
.struct_size = PJRT_Program_STRUCT_SIZE,
.extension_start = nullptr,
.code = code,
.code_size = strlen(code),
.code_size = code_size,
.format = format.data(),
.format_size = format.length(),
};
Expand All @@ -139,6 +139,9 @@ extern "C" {
PJRT_Client* client, PJRT_Program* program, char* compile_options, size_t compile_options_size
) {
std::cout << "PJRT_Client_Compile_Args_new ..." << std::endl;
// std::cout << "... code" << std::endl;
// fwrite(program->code, sizeof(char), program->code_size, stdout);
// std::cout << std::endl;
return new PJRT_Client_Compile_Args{
.struct_size = PJRT_Client_Compile_Args_STRUCT_SIZE,
.extension_start = nullptr,
Expand All @@ -156,6 +159,10 @@ extern "C" {

PJRT_Error* pjrt_client_compile(PJRT_Api* api, PJRT_Client_Compile_Args* args) {
std::cout << "pjrt_client_compile ..." << std::endl;
// std::cout << "... compile_options_size " << args->compile_options_size << std::endl;
// std::cout << "... compile_options" << std::endl;
// fwrite(args->compile_options, sizeof(char), args->compile_options_size, stdout);
// std::cout << std::endl;
return api->PJRT_Client_Compile(args);
}

Expand Down
2 changes: 1 addition & 1 deletion backend/src/xla/pjrt/c/pjrt_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ extern "C" {
PJRT_Client_Destroy_Args* PJRT_Client_Destroy_Args_new(PJRT_Client* client);
PJRT_Error* pjrt_client_destroy(PJRT_Api* api, PJRT_Client_Destroy_Args* args);

PJRT_Program* PJRT_Program_new(char* code);
PJRT_Program* PJRT_Program_new(char* code, size_t code_size);
PJRT_Client_Compile_Args* PJRT_Client_Compile_Args_new(
PJRT_Client* client, PJRT_Program* program, char* compile_options, size_t compile_options_size
);
Expand Down
10 changes: 5 additions & 5 deletions backend/src/xla/pjrt/pjrt_executable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ extern "C" {
.env_option_overrides = {},
.target_config = std::nullopt,
};
xla::CompileOptions::FromProto(*(options->ToProto()));
std::cout << "... serialized options " << std::endl;
std::cout << options->ToProto()->SerializeAsString() << std::endl;
// xla::CompileOptions::FromProto(*(options->ToProto()));
// std::cout << "... serialized options " << std::endl;
// std::cout << options->ToProto()->SerializeAsString() << std::endl;
return reinterpret_cast<CompileOptions*>(options);
}

string* CompileOptions_SerializeAsString(CompileOptions* s) {
std::cout << "CompileOptions_SerializeAsString ..." << std::endl;
auto s_ = reinterpret_cast<xla::CompileOptions*>(s);
auto res = s_->ToProto()->SerializeAsString();
std::cout << "... serialized result: " << std::endl;
std::cout << res << std::endl;
// std::cout << "... serialized result: " << std::endl;
// std::cout << res << std::endl;
return reinterpret_cast<string*>(new std::string(res));
}
}
1 change: 1 addition & 0 deletions backend/src/xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ cc_library(
hdrs = glob(["*.h"]),
deps = [
"@xla//xla/service:platform_util",
"//src:src",
],
visibility = ["//visibility:public"],
)
14 changes: 3 additions & 11 deletions backend/src/xla/service/hlo.pb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,11 @@ limitations under the License.

#include "hlo.pb.h"

const char* c_string_copy(std::string str) {
char *res = NULL;
auto len = str.length();
res = (char *) malloc(len + 1);
strncpy(res, str.c_str(), len);
res[len] = '\0';
return res;
}

extern "C" {
const char* SerializeAsString(HloModuleProto* s) {
string* SerializeAsString(HloModuleProto* s) {
std::cout << "SerializeAsString ..." << std::endl;
auto s_ = reinterpret_cast<xla::HloModuleProto*>(s);
return c_string_copy(s_->SerializeAsString());
auto serialized = s_->SerializeAsString();
return reinterpret_cast<string*>(new std::string(serialized));
}
}
5 changes: 2 additions & 3 deletions backend/src/xla/service/hlo.pb.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ limitations under the License.
#include <string>

#include "xla/service/hlo.pb.h"

const char* c_string_copy(std::string str);
#include "../../ffi.h"

extern "C" {
struct HloModuleProto;

const char* SerializeAsString(HloModuleProto* s);
string* SerializeAsString(HloModuleProto* s);
}
2 changes: 1 addition & 1 deletion backend/src/xla/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ extern "C" {

void set_array_Shape(Shape* arr, int idx, Shape* shape) {
std::cout << "set_array_Shape ..." << std::endl;
std::cout << "... shape " << shape << std::endl;
// std::cout << "... shape " << shape << std::endl;
reinterpret_cast<xla::Shape*>(arr)[idx] = *reinterpret_cast<xla::Shape*>(shape);
}
}
4 changes: 2 additions & 2 deletions backend/src/xla/shape_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ extern "C" {
}

Shape* MakeShape(int primitive_type, int* shape, int rank) {
std::cout << "MakeShape ..." << std::endl;
int64_t shape64[rank];
std::copy(shape, shape + rank, shape64);

Expand All @@ -56,8 +57,7 @@ extern "C" {
(xla::PrimitiveType) primitive_type,
absl::Span<const int64_t>(shape64, rank)
);
std::cout << "MakeShape ..." << std::endl;
std::cout << "... xla_shape " << xla_shape << std::endl;
// std::cout << "... xla_shape " << xla_shape << std::endl;
return reinterpret_cast<Shape*>(xla_shape);
}
}
5 changes: 3 additions & 2 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,10 @@ execute f shape = do
api <- getPjrtApi -- need a gpu version
client <- pjrtClientCreate api
code <- serializeAsString computation
program <- mkPjrtProgram code
program <- mkPjrtProgram !(cstr code) (size code)
compileOptionsStr <- serializeAsString !mkCompileOptions
loadedExec <- pjrtClientCompile api client program !(cstr compileOptionsStr) (size compileOptionsStr)
loadedExec <- pjrtClientCompile
api client program !(cstr compileOptionsStr) (size compileOptionsStr)
buffer <- pjrtLoadedExecutableExecute api loadedExec
literal <- allocLiteral shape
pjrtBufferToHostBuffer api buffer literal
Expand Down
12 changes: 9 additions & 3 deletions src/Compiler/FFI.idr
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ namespace CppString
delete = primIO . prim__stringDelete

%foreign (libxla "string_c_str")
prim__stringCStr : GCAnyPtr -> PrimIO String
prim__stringCStr : GCAnyPtr -> PrimIO $ Ptr Char

export
cstr : HasIO io => CppString -> io String
cstr (MkCppString str) = primIO $ prim__stringCStr str
cstr : HasIO io => CppString -> io $ GCPtr Char
cstr (MkCppString str) = do
cstr <- primIO $ prim__stringCStr str
onCollect cstr (free . prim__forgetPtr)

%foreign (libxla "string_size")
prim__stringSize : GCAnyPtr -> Int
Expand Down Expand Up @@ -94,3 +96,7 @@ mkIntArray xs = do
traverse_ (\(idx, x) => primIO $ prim__setArrayInt ptr (cast idx) (cast x)) (enumerate xs)
ptr <- onCollect ptr (free . prim__forgetPtr)
pure (MkIntArray ptr)

export
%foreign (libxla "set_array_ptr")
prim__setArrayPtr : AnyPtr -> Int -> AnyPtr -> PrimIO ()
10 changes: 6 additions & 4 deletions src/Compiler/Xla/Client/XlaComputation.idr
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ prim__hloModuleProtoSerializeAsString : AnyPtr -> PrimIO String

export
%foreign (libxla "XlaComputation_SerializeAsString")
prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO String
prim__xlaComputationSerializeAsString : GCAnyPtr -> PrimIO AnyPtr

export
serializeAsString : HasIO io => XlaComputation -> io String
serializeAsString (MkXlaComputation computation) =
primIO $ prim__xlaComputationSerializeAsString computation
serializeAsString : HasIO io => XlaComputation -> io CppString
serializeAsString (MkXlaComputation computation) = do
str <- primIO $ prim__xlaComputationSerializeAsString computation
str <- onCollectAny str CppString.delete
pure (MkCppString str)
50 changes: 28 additions & 22 deletions src/Compiler/Xla/PJRT/C/PJRT_C_API.idr
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,17 @@ export
data PjrtProgram = MkPjrtProgram GCAnyPtr

%foreign (libxla "PJRT_Program_new")
prim__mkPjrtProgram : String -> PrimIO AnyPtr
prim__mkPjrtProgram : GCPtr Char -> Int -> PrimIO AnyPtr

export
mkPjrtProgram : HasIO io => String -> io PjrtProgram
mkPjrtProgram code = do
ptr <- primIO $ prim__mkPjrtProgram code
mkPjrtProgram : HasIO io => GCPtr Char -> Int -> io PjrtProgram
mkPjrtProgram code codeSize = do
ptr <- primIO $ prim__mkPjrtProgram code codeSize
ptr <- onCollectAny ptr free
pure (MkPjrtProgram ptr)

%foreign (libxla "PJRT_Client_Compile_Args_new")
prim__mkPjrtClientCompileArgs : GCAnyPtr -> GCAnyPtr -> String -> Int -> PrimIO AnyPtr
prim__mkPjrtClientCompileArgs : GCAnyPtr -> GCAnyPtr -> GCPtr Char -> Int -> PrimIO AnyPtr

%foreign (libxla "PJRT_Client_Compile_Args_executable")
prim__pjrtClientCompileArgsExecutable : AnyPtr -> AnyPtr
Expand All @@ -234,27 +234,32 @@ pjrtClientCompile :
PjrtApi ->
PjrtClient ->
PjrtProgram ->
String ->
GCPtr Char ->
Int ->
ErrIO PjrtError PjrtLoadedExecutable
pjrtClientCompile (MkPjrtApi api) (MkPjrtClient client) (MkPjrtProgram program) compileOptions compileOptionsSize = do
putStrLn "pjrtClientCompile ..."
args <- primIO $ prim__mkPjrtClientCompileArgs client program compileOptions compileOptionsSize
err <- primIO $ prim__pjrtClientCompile api args
let executable = prim__pjrtClientCompileArgsExecutable args
free args
try api err =<< do
executable <- onCollectAny executable destroyExecutable
pure $ MkPjrtLoadedExecutable executable
pjrtClientCompile
(MkPjrtApi api)
(MkPjrtClient client)
(MkPjrtProgram program)
compileOptions
compileOptionsSize = do
putStrLn "pjrtClientCompile ..."
args <- primIO $ prim__mkPjrtClientCompileArgs client program compileOptions compileOptionsSize
err <- primIO $ prim__pjrtClientCompile api args
let executable = prim__pjrtClientCompileArgsExecutable args
free args
try api err =<< do
executable <- onCollectAny executable destroyExecutable
pure $ MkPjrtLoadedExecutable executable

where
where

destroyExecutable : AnyPtr -> IO ()
destroyExecutable executable = do
args <- primIO $ prim__mkPjrtLoadedExecutableDestroyArgs executable
err <- primIO $ prim__pjrtLoadedExecutableDestroy api args
free args
handleErrOnDestroy api err "PJRT_LoadedExecutable"
destroyExecutable : AnyPtr -> IO ()
destroyExecutable executable = do
args <- primIO $ prim__mkPjrtLoadedExecutableDestroyArgs executable
err <- primIO $ prim__pjrtLoadedExecutableDestroy api args
free args
handleErrOnDestroy api err "PJRT_LoadedExecutable"

%foreign (libxla "PJRT_ExecuteOptions_new")
prim__mkPjrtExecuteOptions : PrimIO AnyPtr
Expand All @@ -274,6 +279,7 @@ pjrtLoadedExecutableExecute (MkPjrtApi api) (MkPjrtLoadedExecutable executable)
putStrLn "pjrtLoadedExecutableExecute ..."
outputListsInner <- malloc sizeofPtr
outputLists <- malloc sizeofPtr
primIO $ prim__setArrayPtr outputLists 0 outputListsInner
options <- primIO prim__mkPjrtExecuteOptions
args <- primIO $ prim__mkPjrtLoadedExecutableExecuteArgs executable options outputLists
err <- primIO $ prim__pjrtLoadedExecutableExecute api args
Expand Down

0 comments on commit 85c257c

Please sign in to comment.