Skip to content

Commit a71f2ce

Browse files
committed
feat: serialization
refactor: remove all runtime info from compiled function body perf: optimize mesh codegen fix: pjrt codegen fix: hlosharding codegen feat: serialize/deserialize pipeline
1 parent bbcbcbf commit a71f2ce

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1313
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1414
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1515
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
16+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1617
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
1718
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1819
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -73,6 +74,7 @@ Functors = "0.5"
7374
GPUArraysCore = "0.2"
7475
GPUCompiler = "1.3"
7576
HTTP = "1.10.15"
77+
JLD2 = "0.5.12"
7678
KernelAbstractions = "0.9.30"
7779
LLVM = "9.1"
7880
LLVMOpenMP_jll = "18.1.7"

src/Reactant.jl

+2
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ include("Compiler.jl")
181181

182182
include("Overlay.jl")
183183

184+
include("Serialization.jl")
185+
184186
function Enzyme.make_zero(
185187
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
186188
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}

src/Serialization.jl

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
module Serialization
2+
3+
# TODO: move these deps into an extension
4+
5+
using JLD2
6+
using Reactant: Reactant, MLIR, XLA
7+
8+
struct SerializedThunk
9+
f
10+
body::Expr
11+
argTypes
12+
IsClosure::Bool
13+
num_parameters::Int
14+
num_results::Int
15+
is_device_present::Bool
16+
num_devices::Int
17+
module_string::String
18+
end
19+
20+
function serialize(
21+
filename::String, thunk::Reactant.Compiler.Thunk{FTy,tag,IsClosure,ArgTypes}
22+
) where {FTy,tag,IsClosure,ArgTypes}
23+
if isempty(thunk.module_string)
24+
throw("To serialize a compiled thunk, ensure it is called with `serializable=true`")
25+
end
26+
27+
serializable_thunk = SerializedThunk(
28+
thunk.f,
29+
Reactant.Compiler.__thunk_body_cache[tag],
30+
ArgTypes,
31+
IsClosure,
32+
thunk.exec.num_parameters,
33+
thunk.exec.num_outputs,
34+
thunk.device !== nothing,
35+
thunk.device !== nothing ? 1 : length(thunk.global_device_ids),
36+
thunk.module_string,
37+
)
38+
39+
return JLD2.jldsave(filename; thunk=serializable_thunk)
40+
end
41+
42+
function deserialize(f, filename::String; client, device, global_device_ids)
43+
if !isfile(filename)
44+
error("File $(filename) does not exist")
45+
end
46+
47+
serialized_thunk = JLD2.jldopen(filename, "r") do file
48+
file["thunk"]
49+
end
50+
51+
mod = MLIR.IR.with_context() do ctx
52+
parse(MLIR.IR.Module, serialized_thunk.module_string)
53+
end
54+
modop = MLIR.IR.Operation(mod)
55+
56+
# We always insert these attributes
57+
num_replicas = Int(MLIR.IR.attr(modop, "mhlo.num_replicas"))
58+
num_partitions = Int(MLIR.IR.attr(modop, "mhlo.num_partitions"))
59+
is_sharded = num_replicas * num_partitions > 1
60+
use_shardy_partitioner = false
61+
62+
if !serialized_thunk.is_device_present
63+
@assert serialized_thunk.num_devices == length(global_device_ids)
64+
end
65+
66+
exec = XLA.compile(
67+
client,
68+
device,
69+
mod;
70+
num_outputs=serialized_thunk.num_results,
71+
num_parameters=serialized_thunk.num_parameters,
72+
is_sharded,
73+
global_device_ids,
74+
num_replicas,
75+
num_partitions,
76+
use_shardy_partitioner,
77+
)
78+
79+
fname = gensym(Symbol(Symbol(f), :_reactant))
80+
Reactant.Compiler.__thunk_body_cache[fname] = serialized_thunk.body
81+
thunk = thunk_from_serialized_thunk(
82+
f,
83+
serialized_thunk,
84+
exec,
85+
fname,
86+
client,
87+
global_device_ids,
88+
device,
89+
serialized_thunk.module_string,
90+
)
91+
92+
return thunk
93+
end
94+
95+
function thunk_from_serialized_thunk(
96+
f::F,
97+
serialized_thunk::SerializedThunk,
98+
exec,
99+
tag,
100+
client,
101+
global_device_ids,
102+
device,
103+
module_string,
104+
) where {F}
105+
return Reactant.Compiler.Thunk{
106+
F,
107+
tag,
108+
serialized_thunk.IsClosure,
109+
serialized_thunk.argTypes,
110+
typeof(exec),
111+
typeof(device),
112+
typeof(client),
113+
typeof(global_device_ids),
114+
}(
115+
f, exec, device, module_string, client, global_device_ids
116+
)
117+
end
118+
119+
end

0 commit comments

Comments
 (0)