Skip to content

Commit d2aebc5

Browse files
committed
feat: serialize/deserialize pipeline
1 parent 6d25d7f commit d2aebc5

File tree

4 files changed

+129
-61
lines changed

4 files changed

+129
-61
lines changed

src/Compiler.jl

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function create_result(
190190
error("TODO: Not yet Implemented. Use IFRT for this.")
191191
end
192192
sharding = pop!(path_to_shard_info, path)
193-
return :(ConcretePJRTNumber{$T,length($(restore))}(($(restore)...,), $sharding))
193+
return :(ConcretePJRTNumber{$T}(($(restore)...,), $sharding))
194194
else
195195
return :(ConcretePJRTNumber{$T}($restore))
196196
end
@@ -199,9 +199,7 @@ function create_result(
199199
# We will set the data for this later
200200
if path_to_shard_info !== nothing && haskey(path_to_shard_info, path)
201201
sharding = pop!(path_to_shard_info, path)
202-
return :(ConcretePJRTNumber{$T,length($(tocopy.data))}(
203-
($(tocopy.data...,)), $sharding
204-
))
202+
return :(ConcretePJRTNumber{$T}(($(tocopy.data...,)), $sharding))
205203
end
206204
return :(ConcretePJRTNumber{$T}($(tocopy.data)))
207205
end
@@ -2141,13 +2139,11 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
21412139

21422140
# XLA.compile mutates the module, for serialization we need to keep a copy
21432141
if serializable
2144-
# XXX: Double free??
2145-
# mod_pre_xla = MLIR.IR.Module(
2146-
# MLIR.API.mlirModuleFromOperation(copy(MLIR.IR.Operation(mod)))
2147-
# )
2148-
error("TODO")
2142+
iobuffer = IOBuffer()
2143+
show(iobuffer, mod)
2144+
module_string = String(take!(iobuffer))
21492145
else
2150-
mod_pre_xla = mod
2146+
module_string = ""
21512147
end
21522148

21532149
exec = XLA.compile(
@@ -2163,7 +2159,7 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
21632159
mlir_fn_res.use_shardy_partitioner,
21642160
)
21652161

2166-
return mod_pre_xla, exec, mlir_fn_res, device, client
2162+
return mod, exec, mlir_fn_res, device, client, module_string
21672163
finally
21682164
MLIR.IR.deactivate!(ctx)
21692165
end
@@ -2172,8 +2168,8 @@ function compile_xla(f, args; client=nothing, serializable::Bool=false, kwargs..
21722168
return results
21732169
end
21742170

2175-
function compile(f, args; sync=false, serializable=false, kwargs...)
2176-
_, exec, mlir_fn_res, device, client = compile_xla(f, args; kwargs...)
2171+
function compile(f, args; sync=false, kwargs...)
2172+
_, exec, mlir_fn_res, device, client, str = compile_xla(f, args; kwargs...)
21772173
(; linear_args, seen_args, linear_results, preserved_args, concrete_result) =
21782174
mlir_fn_res
21792175

@@ -2294,7 +2290,7 @@ function compile(f, args; sync=false, serializable=false, kwargs...)
22942290
mlir_fn_res.fnwrapped,
22952291
exec,
22962292
mlir_fn_res.is_sharded ? nothing : device,
2297-
serializable ? mod : nothing,
2293+
str,
22982294
client,
22992295
mlir_fn_res.global_device_ids,
23002296
)
@@ -2303,13 +2299,11 @@ end
23032299
# inspired by RuntimeGeneratedFunction.jl
23042300
const __thunk_body_cache = Dict{Symbol,Expr}()
23052301

2306-
struct Thunk{
2307-
FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,M<:Union{Nothing,MLIR.IR.Module},ClientTy,GD
2308-
}
2302+
struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD}
23092303
f::FTy
23102304
exec::ExecTy
23112305
device::DeviceTy
2312-
mod::M
2306+
module_string::String
23132307
client::ClientTy
23142308
global_device_ids::GD
23152309
end
@@ -2331,7 +2325,7 @@ struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end
23312325
function Base.showerror(
23322326
io::IO,
23332327
::MisMatchedThunkTypeError{
2334-
Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy,ClientTy,GD},FoundTypes
2328+
<:Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD},FoundTypes
23352329
},
23362330
) where {FTy,tag,ArgTypes,FoundTypes,IsClosure,ExecTy,DeviceTy,ClientTy,GD}
23372331
print(
@@ -2354,15 +2348,15 @@ function Base.showerror(
23542348
)
23552349
end
23562350

2357-
@generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy})(
2351+
@generated function (thunk::Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy})(
23582352
args...
2359-
) where {FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy}
2353+
) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy}
23602354
FoundTypes = Tuple{args...}
23612355
if ArgTypes != FoundTypes
23622356
return quote
23632357
throw(
23642358
$(MisMatchedThunkTypeError{
2365-
Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes
2359+
Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy},FoundTypes
23662360
}()),
23672361
)
23682362
end
@@ -2386,23 +2380,22 @@ function register_thunk(
23862380
isclosure::Bool,
23872381
exec,
23882382
device,
2389-
mod,
2383+
module_string,
23902384
client,
23912385
global_device_ids,
23922386
)
23932387
__thunk_body_cache[tag] = body
23942388
return Thunk{
23952389
Core.Typeof(f),
23962390
tag,
2397-
argtys,
23982391
isclosure,
2392+
argtys,
23992393
Core.Typeof(exec),
24002394
Core.Typeof(device),
2401-
Core.Typeof(mod),
24022395
Core.Typeof(client),
24032396
Core.Typeof(global_device_ids),
24042397
}(
2405-
f, exec, device, mod, client, global_device_ids
2398+
f, exec, device, module_string, client, global_device_ids
24062399
)
24072400
end
24082401

src/Serialization.jl

Lines changed: 101 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,119 @@
11
module Serialization
22

33
# TODO: move these deps into an extension
4-
# TODO: Deal with sharding/global devices
54

65
using JLD2
7-
using Reactant: Reactant, MLIR
6+
using Reactant: Reactant, MLIR, XLA
87

9-
struct SerializedThunk{FTy,tag,ArgTypes,IsClosure}
10-
f::FTy
8+
struct SerializedThunk
9+
f
1110
body::Expr
11+
argTypes
12+
IsClosure::Bool
13+
num_parameters::Int
14+
num_results::Int
15+
is_device_present::Bool
16+
num_devices::Int
17+
module_string::String
1218
end
1319

14-
# function JLD2.writeas(
15-
# ::Type{<:Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure}}
16-
# ) where {FTy,tag,ArgTypes,IsClosure}
17-
# return SerializedThunk{FTy,tag,ArgTypes,IsClosure}
18-
# end
19-
20-
# function JLD2.wconvert(
21-
# ::Type{SerializedThunk{FTy,tag,ArgTypes,IsClosure}},
22-
# thunk::Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure},
23-
# ) where {FTy,tag,ArgTypes,IsClosure}
24-
# if thunk.mod === nothing
25-
# throw("To serialize a compiled thunk, ensure it is called with `serializable=true`")
26-
# end
27-
28-
# return error("TODO")
29-
# end
30-
31-
# function JLD2.rconvert(
32-
# ::Type{Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure}},
33-
# serialized::SerializedThunk{FTy,tag,ArgTypes,IsClosure},
34-
# ) where {FTy,tag,ArgTypes,IsClosure}
35-
# return error("TODO")
36-
# end
37-
3820
function serialize(
39-
thunk::Reactant.Compiler.Thunk{FTy,tag,ArgTypes,IsClosure}
40-
) where {FTy,tag,ArgTypes,IsClosure}
41-
if thunk.mod === nothing
21+
filename::String, thunk::Reactant.Compiler.Thunk{FTy,tag,IsClosure,ArgTypes}
22+
) where {FTy,tag,IsClosure,ArgTypes}
23+
if isempty(thunk.module_string)
4224
throw("To serialize a compiled thunk, ensure it is called with `serializable=true`")
4325
end
4426

45-
return serializable_thunk = SerializedThunk{FTy,tag,ArgTypes,IsClosure}(
46-
thunk.f, Reactant.Compiler.__thunk_body_cache[tag]
27+
serializable_thunk = SerializedThunk(
28+
thunk.f,
29+
Reactant.Compiler.__thunk_body_cache[tag],
30+
ArgTypes,
31+
IsClosure,
32+
thunk.exec.num_parameters,
33+
thunk.exec.num_outputs,
34+
thunk.device !== nothing,
35+
thunk.device !== nothing ? 1 : length(thunk.global_device_ids),
36+
thunk.module_string,
4737
)
38+
39+
return JLD2.jldsave(filename; thunk=serializable_thunk)
4840
end
4941

50-
function deserialize() end
42+
function deserialize(f, filename::String; client, device, global_device_ids)
43+
if !isfile(filename)
44+
error("File $(filename) does not exist")
45+
end
46+
47+
serialized_thunk = JLD2.jldopen(filename, "r") do file
48+
file["thunk"]
49+
end
50+
51+
mod = MLIR.IR.with_context() do ctx
52+
parse(MLIR.IR.Module, serialized_thunk.module_string)
53+
end
54+
modop = MLIR.IR.Operation(mod)
55+
56+
# We always insert these attributes
57+
num_replicas = Int(MLIR.IR.attr(modop, "mhlo.num_replicas"))
58+
num_partitions = Int(MLIR.IR.attr(modop, "mhlo.num_partitions"))
59+
is_sharded = num_replicas * num_partitions > 1
60+
use_shardy_partitioner = false
61+
62+
if !serialized_thunk.is_device_present
63+
@assert serialized_thunk.num_devices == length(global_device_ids)
64+
end
65+
66+
exec = XLA.compile(
67+
client,
68+
device,
69+
mod;
70+
num_outputs=serialized_thunk.num_results,
71+
num_parameters=serialized_thunk.num_parameters,
72+
is_sharded,
73+
global_device_ids,
74+
num_replicas,
75+
num_partitions,
76+
use_shardy_partitioner,
77+
)
78+
79+
fname = gensym(Symbol(Symbol(f), :_reactant))
80+
Reactant.Compiler.__thunk_body_cache[fname] = serialized_thunk.body
81+
thunk = thunk_from_serialized_thunk(
82+
f,
83+
serialized_thunk,
84+
exec,
85+
fname,
86+
client,
87+
global_device_ids,
88+
device,
89+
serialized_thunk.module_string,
90+
)
91+
92+
return thunk
93+
end
94+
95+
function thunk_from_serialized_thunk(
96+
f::F,
97+
serialized_thunk::SerializedThunk,
98+
exec,
99+
tag,
100+
client,
101+
global_device_ids,
102+
device,
103+
module_string,
104+
) where {F}
105+
return Reactant.Compiler.Thunk{
106+
F,
107+
tag,
108+
serialized_thunk.IsClosure,
109+
serialized_thunk.argTypes,
110+
typeof(exec),
111+
typeof(device),
112+
typeof(client),
113+
typeof(global_device_ids),
114+
}(
115+
f, exec, device, module_string, client, global_device_ids
116+
)
117+
end
51118

52119
end

src/Sharding.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ function sharding_to_array_slices(
714714
else
715715
Reactant.ConcreteRArray(ones(Float32, size_x...); kws...)
716716
end
717-
_, exec, _, _, _ = Reactant.Compiler.compile_xla(
717+
_, exec, _, _, _, _ = Reactant.Compiler.compile_xla(
718718
Reactant.Ops.negate, (tmp,); input_shardings=IdDict(tmp => sharding)
719719
)
720720

src/Types.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function ConcretePJRTNumber{T}(data::Tuple{XLA.PJRT.AsyncBuffer}) where {T}
7272
return ConcretePJRTNumber{T,1,Sharding.NoShardInfo}(data, Sharding.NoShardInfo())
7373
end
7474

75+
function ConcretePJRTNumber{T}(data::NTuple{D,XLA.PJRT.AsyncBuffer}, sharding) where {T,D}
76+
return ConcretePJRTNumber{T,D,typeof(sharding)}(data, sharding)
77+
end
78+
7579
@leaf ConcretePJRTNumber
7680

7781
function ConcretePJRTNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number}
@@ -212,6 +216,10 @@ function ConcreteIFRTNumber{T}(data::XLA.IFRT.AsyncArray) where {T}
212216
return ConcreteIFRTNumber{T,Sharding.NoShardInfo}(data, Sharding.NoShardInfo())
213217
end
214218

219+
function ConcreteIFRTNumber{T}(data::XLA.IFRT.AsyncArray, sharding) where {T}
220+
return ConcreteIFRTNumber{T,typeof(sharding)}(data, sharding)
221+
end
222+
215223
@leaf ConcreteIFRTNumber
216224

217225
function ConcreteIFRTNumber{T}(data::T2; kwargs...) where {T<:Number,T2<:Number}

0 commit comments

Comments
 (0)