diff --git a/Project.toml b/Project.toml
index 5297df2915..1d9c67f3e7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -73,6 +74,7 @@ Functors = "0.5"
 GPUArraysCore = "0.2"
 GPUCompiler = "1.3"
 HTTP = "1.10.15"
+JLD2 = "0.5.12"
 KernelAbstractions = "0.9.30"
 LLVM = "9.1"
 LLVMOpenMP_jll = "18.1.7"
diff --git a/src/Reactant.jl b/src/Reactant.jl
index 06ae61f871..5ce8dd0428 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -181,6 +181,8 @@ include("Compiler.jl")
 
 include("Overlay.jl")
 
+include("Serialization.jl")
+
 function Enzyme.make_zero(
     ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
 )::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
diff --git a/src/Serialization.jl b/src/Serialization.jl
new file mode 100644
index 0000000000..3a03f120df
--- /dev/null
+++ b/src/Serialization.jl
@@ -0,0 +1,185 @@
+module Serialization
+
+# TODO: move these deps into an extension
+
+using JLD2
+using Reactant: Reactant, MLIR, XLA
+
+"""
+    SerializedThunk
+
+Plain-data snapshot of a compiled `Reactant.Compiler.Thunk`, stored under the key
+`"thunk"` of the JLD2 file written by [`serialize`](@ref) and read back by
+[`deserialize`](@ref).
+
+# Fields
+- `f`: the `f` field of the original thunk.
+- `body`: the `Reactant.Compiler.__thunk_body_cache` entry for the thunk's `tag`.
+- `argTypes`: the `ArgTypes` type parameter of the original thunk.
+- `IsClosure`: the `IsClosure` type parameter of the original thunk.
+- `num_parameters`: `num_parameters` of the original thunk's executable.
+- `num_results`: `num_outputs` of the original thunk's executable.
+- `is_device_present`: whether the original thunk's `device` was not `nothing`.
+- `num_devices`: `1` if a device was present, otherwise the number of global device ids.
+- `module_string`: the original thunk's `module_string`, parsed as an MLIR module on load.
+"""
+struct SerializedThunk
+    f
+    body::Expr
+    argTypes
+    IsClosure::Bool
+    num_parameters::Int
+    num_results::Int
+    is_device_present::Bool
+    num_devices::Int
+    module_string::String
+end
+
+"""
+    serialize(filename::String, thunk::Reactant.Compiler.Thunk)
+
+Save `thunk` to `filename` with JLD2, as a [`SerializedThunk`](@ref) stored under the key
+`"thunk"`.
+
+# Throws
+- `ArgumentError` if `thunk.module_string` is empty.
+"""
+function serialize(
+    filename::String, thunk::Reactant.Compiler.Thunk{FTy,tag,IsClosure,ArgTypes}
+) where {FTy,tag,IsClosure,ArgTypes}
+    if isempty(thunk.module_string)
+        # Throw a proper exception type; a bare `String` is not an `Exception`.
+        throw(
+            ArgumentError(
+                "To serialize a compiled thunk, ensure it is called with " *
+                "`serializable=true`",
+            ),
+        )
+    end
+
+    has_device = thunk.device !== nothing
+    serializable_thunk = SerializedThunk(
+        thunk.f,
+        Reactant.Compiler.__thunk_body_cache[tag],
+        ArgTypes,
+        IsClosure,
+        thunk.exec.num_parameters,
+        thunk.exec.num_outputs,
+        has_device,
+        has_device ? 1 : length(thunk.global_device_ids),
+        thunk.module_string,
+    )
+
+    return JLD2.jldsave(filename; thunk=serializable_thunk)
+end
+
+"""
+    deserialize(f, filename::String; client, device, global_device_ids)
+
+Load the [`SerializedThunk`](@ref) stored in `filename`, compile its MLIR module with
+`XLA.compile`, and return a `Reactant.Compiler.Thunk` built from `f` and the new
+executable. The `f` given here, not the `f` stored in the file, is placed in the thunk.
+
+The stored body is registered in `Reactant.Compiler.__thunk_body_cache` under a freshly
+generated tag.
+
+# Throws
+- `ErrorException` if `filename` does not exist.
+- `ArgumentError` if the thunk was serialized without a device and
+  `length(global_device_ids)` differs from the stored `num_devices`.
+"""
+function deserialize(f, filename::String; client, device, global_device_ids)
+    if !isfile(filename)
+        error("File $(filename) does not exist")
+    end
+
+    serialized_thunk = JLD2.jldopen(filename, "r") do file
+        file["thunk"]
+    end
+
+    mod = MLIR.IR.with_context() do ctx
+        parse(MLIR.IR.Module, serialized_thunk.module_string)
+    end
+    modop = MLIR.IR.Operation(mod)
+
+    # We always insert these attributes
+    num_replicas = Int(MLIR.IR.attr(modop, "mhlo.num_replicas"))
+    num_partitions = Int(MLIR.IR.attr(modop, "mhlo.num_partitions"))
+    is_sharded = num_replicas * num_partitions > 1
+    use_shardy_partitioner = false
+
+    # Validate caller input explicitly: `@assert` is meant for internal invariants and
+    # may be disabled, so it must not guard user-supplied arguments.
+    if !serialized_thunk.is_device_present &&
+        serialized_thunk.num_devices != length(global_device_ids)
+        throw(
+            ArgumentError(
+                "Thunk was serialized for $(serialized_thunk.num_devices) devices, but " *
+                "$(length(global_device_ids)) global device ids were provided",
+            ),
+        )
+    end
+
+    exec = XLA.compile(
+        client,
+        device,
+        mod;
+        num_outputs=serialized_thunk.num_results,
+        num_parameters=serialized_thunk.num_parameters,
+        is_sharded,
+        global_device_ids,
+        num_replicas,
+        num_partitions,
+        use_shardy_partitioner,
+    )
+
+    fname = gensym(Symbol(Symbol(f), :_reactant))
+    Reactant.Compiler.__thunk_body_cache[fname] = serialized_thunk.body
+    thunk = thunk_from_serialized_thunk(
+        f,
+        serialized_thunk,
+        exec,
+        fname,
+        client,
+        global_device_ids,
+        device,
+        serialized_thunk.module_string,
+    )
+
+    return thunk
+end
+
+"""
+    thunk_from_serialized_thunk(
+        f, serialized_thunk, exec, tag, client, global_device_ids, device, module_string
+    )
+
+Construct a `Reactant.Compiler.Thunk` whose type parameters are taken from `f`, `tag`,
+the stored `IsClosure`/`argTypes`, and the types of `exec`, `device`, `client` and
+`global_device_ids`.
+"""
+function thunk_from_serialized_thunk(
+    f::F,
+    serialized_thunk::SerializedThunk,
+    exec,
+    tag,
+    client,
+    global_device_ids,
+    device,
+    module_string,
+) where {F}
+    return Reactant.Compiler.Thunk{
+        F,
+        tag,
+        serialized_thunk.IsClosure,
+        serialized_thunk.argTypes,
+        typeof(exec),
+        typeof(device),
+        typeof(client),
+        typeof(global_device_ids),
+    }(
+        f, exec, device, module_string, client, global_device_ids
+    )
+end
+
+end