Skip to content

Commit dd9b467

Browse files
authored
feat: run shardy passes on julia end (#926)
* feat: run shardy passes on julia end fix: conditionally apply the passes fix: expose the passes as an option * feat: defaults based on entry-point * chore: bump jll version
1 parent 14a8d82 commit dd9b467

File tree

6 files changed

+109
-26
lines changed

6 files changed

+109
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ PythonCall = "0.9"
8686
Random = "1.10"
8787
Random123 = "1.7"
8888
ReactantCore = "0.1.5"
89-
Reactant_jll = "0.0.89"
89+
Reactant_jll = "0.0.90"
9090
Scratch = "1.2"
9191
Sockets = "1.10"
9292
SpecialFunctions = "2.4"

src/Compiler.jl

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,8 @@ function compile_mlir!(
722722
}
723723
}();
724724
optimize::Union{Bool,Symbol}=true,
725+
# default refers to letting XLA handle the shardy import/propagation/export
726+
shardy_passes::Symbol=:to_mhlo_shardings, # [:default, :to_mhlo_shardings]
725727
no_nan::Bool=false,
726728
backend="gpu",
727729
fn_kwargs=(),
@@ -977,6 +979,34 @@ function compile_mlir!(
977979
error("Invalid optimize option: $(Meta.quot(optimize))")
978980
end
979981

982+
# shardy passes
983+
use_shardy_partitioner = false
984+
if is_sharded
985+
if shardy_passes == :default
986+
# If `:default` is passed in, we will run a pass to export the sharding
987+
# inside the corresponding compile function for IFRT/PJRT. This keeps the
988+
# sharding readable.
989+
use_shardy_partitioner = true
990+
elseif shardy_passes == :to_mhlo_shardings
991+
# Convert all shardy ops to corresponding mhlo attrs/ops that can be consumed by
992+
# XLA (note we need to set `use_shardy_partitioner` to `false` in the options)
993+
# TODO: Use https://github.com/openxla/shardy/blob/01d3205086132d1bdf0867e911c05f489918431d/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc#L28 to pass in the options
994+
run_pass_pipeline!(
995+
mod,
996+
join(
997+
["sdy-propagation-pipeline", "xla-sdy-stablehlo-export-pipeline"], ','
998+
),
999+
)
1000+
1001+
# Run our optimization passes here -- we need to be careful to not apply folding
1002+
# here since that violates the semantics of `sdy.constant` which was converted to
1003+
# `stablehlo.constant` by the previous pass.
1004+
run_pass_pipeline!(mod, join(["canonicalize", "cse"], ','))
1005+
else
1006+
error("Invalid shardy_passes option: $(Meta.quot(shardy_passes))")
1007+
end
1008+
end
1009+
9801010
preserved_args = Tuple{TracedType,Int}[]
9811011
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
9821012
nresults = MLIR.IR.Value[]
@@ -1040,6 +1070,7 @@ function compile_mlir!(
10401070
concrete_result,
10411071
mlir_fn_res.sharding_mesh,
10421072
mlir_fn_res.mutated_args,
1073+
use_shardy_partitioner,
10431074
)
10441075
end
10451076

@@ -1050,7 +1081,11 @@ See also [`@code_xla`](@ref), [`@code_mhlo`](@ref).
10501081
"""
10511082
macro code_hlo(args...)
10521083
default_options = Dict{Symbol,Any}(
1053-
:optimize => true, :no_nan => false, :client => nothing, :raise => false
1084+
:optimize => true,
1085+
:no_nan => false,
1086+
:client => nothing,
1087+
:raise => false,
1088+
:shardy_passes => :(:default),
10541089
)
10551090
compile_expr, (; compiled) = compile_call_expr(
10561091
__module__, compile_mlir, default_options, args...
@@ -1074,7 +1109,11 @@ See also [`@code_xla`](@ref), [`@code_hlo`](@ref).
10741109
"""
10751110
macro code_mhlo(args...)
10761111
default_options = Dict{Symbol,Any}(
1077-
:optimize => true, :no_nan => false, :client => nothing, :raise => false
1112+
:optimize => true,
1113+
:no_nan => false,
1114+
:client => nothing,
1115+
:raise => false,
1116+
:shardy_passes => :(:default),
10781117
)
10791118
compile_expr, (; compiled) = compile_call_expr(
10801119
__module__, compile_xla, default_options, args...
@@ -1098,7 +1137,11 @@ See also [`@code_mhlo`](@ref), [`@code_hlo`](@ref).
10981137
"""
10991138
macro code_xla(args...)
11001139
default_options = Dict{Symbol,Any}(
1101-
:optimize => true, :no_nan => false, :client => nothing, :raise => false
1140+
:optimize => true,
1141+
:no_nan => false,
1142+
:client => nothing,
1143+
:raise => false,
1144+
:shardy_passes => :(:to_mhlo_shardings),
11021145
)
11031146
compile_expr, (; compiled) = compile_call_expr(
11041147
__module__, compile_xla, default_options, args...
@@ -1125,6 +1168,7 @@ macro compile(args...)
11251168
:no_nan => false,
11261169
:client => nothing,
11271170
:raise => false,
1171+
:shardy_passes => :(:to_mhlo_shardings),
11281172
)
11291173
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
11301174
end
@@ -1141,6 +1185,7 @@ macro jit(args...)
11411185
:no_nan => false,
11421186
:client => nothing,
11431187
:raise => false,
1188+
:shardy_passes => :(:to_mhlo_shardings),
11441189
)
11451190
compile_expr, (; compiled, args) = compile_call_expr(
11461191
__module__, compile, default_options, args...
@@ -1166,6 +1211,7 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
11661211
options[option_name] = option.args[2]
11671212
end
11681213
end
1214+
11691215
call = only(args)
11701216
f_symbol = gensym(:f)
11711217
args_symbol = gensym(:args)
@@ -1207,18 +1253,20 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
12071253
error("Invalid function call: $(call)")
12081254
end
12091255

1210-
return quote
1211-
$(f_symbol) = $(fname)
1212-
$(args_symbol) = $(args_rhs)
1213-
$(kwargs_symbol) = (; $(kwargs_rhs...))
1214-
$(compiled_symbol) = $(compiler)(
1215-
$(f_symbol),
1216-
$(args_symbol);
1217-
fn_kwargs=$(kwargs_symbol),
1218-
$(Expr.(:kw, keys(options), values(options))...),
1219-
)
1220-
end,
1221-
(; compiled=compiled_symbol, args=args_symbol)
1256+
return (
1257+
quote
1258+
$(f_symbol) = $(fname)
1259+
$(args_symbol) = $(args_rhs)
1260+
$(kwargs_symbol) = (; $(kwargs_rhs...))
1261+
$(compiled_symbol) = $(compiler)(
1262+
$(f_symbol),
1263+
$(args_symbol);
1264+
fn_kwargs=$(kwargs_symbol),
1265+
$(Expr.(:kw, keys(options), values(options))...),
1266+
)
1267+
end,
1268+
(; compiled=compiled_symbol, args=args_symbol),
1269+
)
12221270
end
12231271

12241272
"""
@@ -1765,6 +1813,7 @@ function compile_xla(f, args; client=nothing, kwargs...)
17651813
global_device_ids,
17661814
mlir_fn_res.num_replicas,
17671815
mlir_fn_res.num_partitions,
1816+
mlir_fn_res.use_shardy_partitioner,
17681817
)
17691818

17701819
return mod, exec, mlir_fn_res, device, client

src/TracedUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ mutable struct CompiledMlirFnResult{
148148
concrete_result::CR
149149
sharding_mesh::M
150150
mutated_args::MA
151+
use_shardy_partitioner::Bool
151152
end
152153

153154
function make_mlir_fn(
@@ -434,6 +435,7 @@ function make_mlir_fn(
434435
nothing,
435436
sharding_mesh,
436437
mutated_args,
438+
true,
437439
)
438440
end
439441

src/xla/IFRT/LoadedExecutable.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ function XLA.compile(
7979
num_parameters::Int64,
8080
num_replicas::Int64,
8181
num_partitions::Int64,
82+
use_shardy_partitioner::Bool,
8283
)
8384
device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device))
8485
GC.@preserve client mod begin
@@ -90,6 +91,7 @@ function XLA.compile(
9091
global_device_ids::Ptr{Clong},
9192
length(global_device_ids)::Clong,
9293
XLA.CUDA_DATA_DIR[]::Cstring,
94+
use_shardy_partitioner::Bool,
9395
)::Ptr{Cvoid}
9496
end
9597
return LoadedExecutable(

src/xla/PJRT/LoadedExecutable.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ function XLA.compile(
7373
num_parameters::Int64,
7474
num_replicas::Int64,
7575
num_partitions::Int64,
76+
use_shardy_partitioner::Bool,
7677
)
7778
device_id = is_sharded ? Int64(-1) : Int64(XLA.device_ordinal(device))
7879
GC.@preserve client mod begin
@@ -84,6 +85,7 @@ function XLA.compile(
8485
global_device_ids::Ptr{Clong},
8586
length(global_device_ids)::Clong,
8687
XLA.CUDA_DATA_DIR[]::Cstring,
88+
use_shardy_partitioner::Bool,
8789
)::Ptr{Cvoid}
8890
end
8991
return LoadedExecutable(

test/sharding.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# Currently an extremely simple test
21
using Reactant, Test
32

43
const addressable_devices = Reactant.addressable_devices()
@@ -113,6 +112,9 @@ end
113112
)
114113
@test Array(@jit fn_test2(x_ra)) fn_test2(x)
115114
@test Reactant.to_number(@jit sum(x_ra)) sum(x)
115+
116+
@test Array(@jit shardy_passes = :to_mhlo_shardings fn_test3(x_ra)) fn_test3(x)
117+
@test Reactant.to_number(@jit shardy_passes = :to_mhlo_shardings sum(x_ra)) sum(x)
116118
else
117119
@warn "Not enough addressable devices to run sharding tests"
118120
end
@@ -125,8 +127,11 @@ end
125127
x_ra = Reactant.to_rarray(
126128
x; sharding=Sharding.NamedSharding(mesh, (("data", "model"), nothing))
127129
)
128-
@test Array(@jit fn_test2(x_ra)) fn_test2(x)
129-
@test Reactant.to_number(@jit sum(x_ra)) sum(x)
130+
@test Array(@jit shardy_passes = :default fn_test2(x_ra)) fn_test2(x)
131+
@test Reactant.to_number(@jit shardy_passes = :default sum(x_ra)) sum(x)
132+
133+
@test Array(@jit shardy_passes = :to_mhlo_shardings fn_test3(x_ra)) fn_test3(x)
134+
@test Reactant.to_number(@jit shardy_passes = :to_mhlo_shardings sum(x_ra)) sum(x)
130135
else
131136
@warn "Not enough addressable devices to run sharding tests"
132137
end
@@ -142,8 +147,12 @@ end
142147
mesh, ("model", nothing); is_closed=(false, false)
143148
),
144149
)
145-
@test Array(@jit fn_test2(x_ra)) fn_test2(x)
146-
@test Reactant.to_number(@jit sum(x_ra)) sum(x)
150+
151+
@test Array(@jit shardy_passes = :default fn_test2(x_ra)) fn_test2(x)
152+
@test Reactant.to_number(@jit shardy_passes = :default sum(x_ra)) sum(x)
153+
154+
@test Array(@jit shardy_passes = :to_mhlo_shardings fn_test3(x_ra)) fn_test3(x)
155+
@test Reactant.to_number(@jit shardy_passes = :to_mhlo_shardings sum(x_ra)) sum(x)
147156
else
148157
@warn "Not enough addressable devices to run sharding tests"
149158
end
@@ -193,6 +202,9 @@ end
193202

194203
hlo = @code_hlo fn_with_constraint(x_ra)
195204
@test contains(repr(hlo), "sharding_constraint")
205+
hlo = @code_hlo shardy_passes = :to_mhlo_shardings fn_with_constraint(x_ra)
206+
@test !contains(repr(hlo), "sharding_constraint")
207+
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 3
196208

197209
z = Reactant.to_rarray(x; sharding=constraint)
198210
res = @jit fn_with_constraint(x_ra)
@@ -208,6 +220,11 @@ end
208220

209221
hlo = @code_hlo fn_with_constraint(x_ra_no_sharding)
210222
@test contains(repr(hlo), "sharding_constraint")
223+
hlo = @code_hlo shardy_passes = :to_mhlo_shardings fn_with_constraint(
224+
x_ra_no_sharding
225+
)
226+
@test !contains(repr(hlo), "sharding_constraint")
227+
@test length(collect(eachmatch(r"mhlo.sharding", repr(hlo)))) == 3
211228

212229
res = @jit fn_with_constraint(x_ra_no_sharding)
213230
@test x .+ x Array(res)
@@ -228,14 +245,17 @@ end
228245
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
229246
)
230247

231-
@test Array(@jit sum(x_ra; dims=2)) sum(x; dims=2)
248+
@test Array(@jit shardy_passes = :default sum(x_ra; dims=2)) sum(x; dims=2)
249+
@test Array(@jit shardy_passes = :to_mhlo_shardings sum(x_ra; dims=2))
250+
sum(x; dims=2)
232251

233252
x = reshape(collect(Float32, 1:25), 5, 5)
234253
x_ra = Reactant.to_rarray(
235254
x; sharding=Sharding.NamedSharding(mesh, ("data", "model"))
236255
)
237256

238-
@test Array(@jit fn_test2(x_ra)) fn_test2(x)
257+
@test Array(@jit shardy_passes = :default fn_test2(x_ra)) fn_test2(x)
258+
@test Array(@jit shardy_passes = :to_mhlo_shardings fn_test2(x_ra)) fn_test2(x)
239259
else
240260
@warn "Not enough addressable devices to run sharding tests"
241261
end
@@ -271,24 +291,32 @@ end
271291
randn(Float32, 4, 5); sharding=Sharding.NamedSharding(mesh, ((:x, :y), :z))
272292
)
273293

274-
y_ra = Reactant.to_rarray(randn(Float32, 5, 4); sharding=Sharding.NoSharding())
294+
y_ra_arr = randn(Float32, 5, 4)
295+
y_ra = Reactant.to_rarray(y_ra_arr; sharding=Sharding.NoSharding())
296+
y_ra_2 = Reactant.to_rarray(y_ra_arr; sharding=Sharding.NoSharding())
275297

276298
function fn(x, y)
277299
z = x * y
278300
y[1:2, 1:2] .= 1
279301
return z
280302
end
281303

282-
y_ra_arr = Array(y_ra)
283304
x_ra_arr = Array(x_ra)
284305
z_ra_arr = fn(x_ra_arr, y_ra_arr)
285306

286-
z_ra = @jit fn(x_ra, y_ra)
307+
z_ra = @jit shardy_passes = :default fn(x_ra, y_ra)
287308
y_ra_final = Array(y_ra)
288309

289310
@test z_ra_arr Array(z_ra)
290311
@test y_ra_final[1:2, 1:2] y_ra_arr[1:2, 1:2]
291312
@test all(y_ra_final[1:2, 1:2] .== 1)
313+
314+
z_ra2 = @jit shardy_passes = :to_mhlo_shardings fn(x_ra, y_ra_2)
315+
y_ra_final2 = Array(y_ra_2)
316+
317+
@test z_ra_arr Array(z_ra2)
318+
@test y_ra_final[1:2, 1:2] y_ra_arr[1:2, 1:2]
319+
@test all(y_ra_final[1:2, 1:2] .== 1)
292320
else
293321
@warn "Not enough addressable devices to run sharding tests"
294322
end

0 commit comments

Comments
 (0)