Skip to content

Commit 51a1f46

Browse files
avik-palwsmoses
andauthored
feat: multi GPU support (#587)
* feat: create PJRT client on all GPUs * fix: remove return * fix: indexing * fix: cuda_visible_devices handling * revert: to old state * feat: add devices api * feat: expose more build options * fix: remove function * fix: build_local.jl * feat: correct device handling for concreterarray * docs: add to docs * feat: correctly set build options * fix: empty device list * fix: precompilation??? * Update WORKSPACE * chore: bump version * fix: precompilation --------- Co-authored-by: William Moses <[email protected]>
1 parent 4849c6b commit 51a1f46

File tree

13 files changed

+249
-87
lines changed

13 files changed

+249
-87
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ PythonCall = "0.9"
7070
Random = "1.10"
7171
Random123 = "1.7"
7272
ReactantCore = "0.1.4"
73-
Reactant_jll = "0.0.51"
73+
Reactant_jll = "0.0.52"
7474
Scratch = "1.2"
7575
Sockets = "1.10"
7676
SpecialFunctions = "2.4"

deps/Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
[deps]
2+
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
3+
BinaryBuilderBase = "7f725544-6523-48cd-82d1-3fa08ff4056e"
4+
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
25
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
36
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
47
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
58
Scratch = "6c6a2e73-6563-6170-7368-637461726353"
6-
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
7-
BinaryBuilderBase = "7f725544-6523-48cd-82d1-3fa08ff4056e"
89

910
[compat]
1011
Clang = "0.18"

deps/ReactantExtra/API.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,10 @@ extern "C" PjRtDevice *ClientGetAddressableDevice(PjRtClient *client,
388388
client->LookupAddressableDevice(PjRtLocalDeviceId(device_id)));
389389
}
390390

391+
extern "C" const char *ClientGetPlatformName(PjRtClient *client) {
392+
return cstr_from_string(client->platform_name());
393+
}
394+
391395
// To keep in sync with JLAllocatorStats in src/XLA.jl
392396
struct JLAllocatorStats {
393397
int64_t num_allocs;
@@ -578,14 +582,22 @@ extern "C" MlirModule ConvertLLVMStrToMLIR(const char *lmod, MlirContext cctx) {
578582

579583
/* Note that this */
580584
extern "C" xla::PjRtLoadedExecutable *ClientCompile(PjRtClient *client,
581-
MlirModule cmod) {
585+
MlirModule cmod,
586+
int device_ordinal,
587+
int num_replicas,
588+
int num_partitions,
589+
bool use_shardy_partitioner) {
582590
auto program =
583591
std::make_unique<xla::ifrt::HloProgram>(cast<ModuleOp>(*unwrap(cmod)));
584592

585593
CompileOptions options;
586-
// options.argument_layouts;
587-
// options.executable_build_options.set_device_ordinal();
588-
// options.executable_build_options.set_result_layout();
594+
595+
if (device_ordinal >= 0) {
596+
options.executable_build_options.set_device_ordinal(device_ordinal);
597+
}
598+
options.executable_build_options.set_num_replicas(num_replicas);
599+
options.executable_build_options.set_num_partitions(num_partitions);
600+
options.executable_build_options.set_use_shardy_partitioner(use_shardy_partitioner);
589601

590602
auto addressable_devices = client->addressable_devices();
591603
if (!addressable_devices.empty()) {

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "1f193a884b1a6e149d43ed77a77a5bdb15882ba9"
12+
ENZYMEXLA_COMMIT = "fd5517f2223adcf579c165a0387231ef4931f55b"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(

deps/build_local.jl

Lines changed: 65 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,49 @@
11
# Invoke with
2-
# `julia --project=deps deps/build_local.jl [dbg/opt] [auto/cpu/cuda]`
2+
# `julia --project=deps deps/build_local.jl [--debug] [--backend=auto/cpu/cuda]`
33

44
# the pre-built ReactantExtra_jll might not be loadable on this platform
55
Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0")
66

7+
using ArgParse
8+
9+
s = ArgParseSettings()
10+
#! format: off
11+
@add_arg_table! s begin
12+
"--debug"
13+
help = "Build with debug mode (-c dbg)."
14+
action = :store_true
15+
"--backend"
16+
help = "Build with the specified backend (auto, cpu, cuda)."
17+
default = "auto"
18+
arg_type = String
19+
"--gcc_host_compiler_path"
20+
help = "Path to the gcc host compiler."
21+
default = "/usr/bin/gcc"
22+
arg_type = String
23+
"--cc"
24+
default = "/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang"
25+
arg_type = String
26+
"--hermetic_python_version"
27+
help = "Hermetic Python version."
28+
default = "3.10"
29+
arg_type = String
30+
# For GCC < 13 we need to disable these flags
31+
"--xnn_disable_avx512fp16"
32+
help = "Disable AVX512 FP16 support in XNNPACK."
33+
action = :store_true
34+
"--xnn_disable_avxvnniint8"
35+
help = "Disable AVX VNNI INT8 support in XNNPACK."
36+
action = :store_true
37+
end
38+
#! format: on
39+
parsed_args = parse_args(ARGS, s)
40+
41+
println("Parsed args:")
42+
for (k, v) in parsed_args
43+
println(" $k = $v")
44+
end
45+
println()
46+
747
using Pkg, Scratch, Preferences, Libdl
848

949
# 1. Get a scratch directory
@@ -41,27 +81,10 @@ run(
4181
# --@local_config_cuda//:cuda_compiler=nvcc
4282
# --crosstool_top="@local_config_cuda//crosstool:toolchain"
4383

44-
build_kind = if length(ARGS) 1
45-
kind = ARGS[1]
46-
if kind ("dbg", "opt")
47-
error("Invalid build kind $(kind). Valid options are 'dbg' and 'opt'")
48-
end
49-
kind
50-
else
51-
"dbg"
52-
end
84+
build_kind = parsed_args["debug"] ? "dbg" : "opt"
5385

54-
@info "Building JLL with -c $(build_kind)"
55-
56-
build_backend = if length(ARGS) 2
57-
backend = ARGS[2]
58-
if backend ("auto", "cpu", "cuda")
59-
error("Invalid build backend $(backend). Valid options are 'auto', 'cpu', and 'cuda'")
60-
end
61-
backend
62-
else
63-
"auto"
64-
end
86+
build_backend = parsed_args["backend"]
87+
@assert build_backend in ("auto", "cpu", "cuda")
6588

6689
if build_backend == "auto"
6790
build_backend = try
@@ -78,8 +101,6 @@ elseif build_backend == "cpu"
78101
""
79102
end
80103

81-
@info "Building JLL with backend $(build_backend)"
82-
83104
bazel_cmd = if !isnothing(Sys.which("bazelisk"))
84105
"bazelisk"
85106
elseif !isnothing(Sys.which("bazel"))
@@ -90,36 +111,28 @@ end
90111

91112
@info "Building JLL with $(bazel_cmd)"
92113

93-
if isempty(arg)
94-
run(
95-
Cmd(
96-
`$(bazel_cmd) build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
97-
--repo_env HERMETIC_PYTHON_VERSION="3.10"
98-
--check_visibility=false --verbose_failures :libReactantExtra.so`;
99-
dir=source_dir,
100-
),
101-
)
102-
else
103-
run(
104-
Cmd(
105-
`$(bazel_cmd) build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
106-
--repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc
107-
--repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang
108-
--repo_env HERMETIC_PYTHON_VERSION="3.10"
109-
--check_visibility=false --verbose_failures :libReactantExtra.so`;
110-
dir=source_dir,
111-
),
112-
)
114+
gcc_host_compiler_path = parsed_args["gcc_host_compiler_path"]
115+
cc = parsed_args["cc"]
116+
hermetic_python_version = parsed_args["hermetic_python_version"]
117+
118+
build_cmd_list = [bazel_cmd, "build"]
119+
!isempty(arg) && push!(build_cmd_list, arg)
120+
append!(build_cmd_list, ["-c", "$(build_kind)"])
121+
push!(build_cmd_list, "--action_env=JULIA=$(Base.julia_cmd().exec[1])")
122+
if parsed_args["xnn_disable_avx512fp16"]
123+
push!(build_cmd_list, "--define=xnn_enable_avx512fp16=false")
124+
end
125+
if parsed_args["xnn_disable_avxvnniint8"]
126+
push!(build_cmd_list, "--define=xnn_enable_avxvnniint8=false")
113127
end
114-
# env=Dict("HOME"=>ENV["HOME"], "PATH"=>joinpath(source_dir, "..")*":"*ENV["PATH"])))
128+
push!(build_cmd_list, "--repo_env=HERMETIC_PYTHON_VERSION=$(hermetic_python_version)")
129+
push!(build_cmd_list, "--repo_env=GCC_HOST_COMPILER_PATH=$(gcc_host_compiler_path)")
130+
push!(build_cmd_list, "--repo_env=CC=$(cc)")
131+
push!(build_cmd_list, "--check_visibility=false")
132+
push!(build_cmd_list, "--verbose_failures")
133+
push!(build_cmd_list, ":libReactantExtra.so")
115134

116-
run(Cmd(`rm -f libReactantExtra.dylib`; dir=joinpath(source_dir, "bazel-bin")))
117-
run(
118-
Cmd(
119-
`ln -s libReactantExtra.so libReactantExtra.dylib`;
120-
dir=joinpath(source_dir, "bazel-bin"),
121-
),
122-
)
135+
run(Cmd(Cmd(build_cmd_list); dir=source_dir))
123136

124137
# Discover built libraries
125138
built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file
@@ -129,7 +142,7 @@ end
129142
lib_path = joinpath(source_dir, "bazel-bin", only(built_libs))
130143
isfile(lib_path) || error("Could not find library $lib_path in build directory")
131144

132-
# Tell ReactReactantExtra_jllant_jll to load our library instead of the default artifact one
145+
# Tell ReactantExtra_jll to load our library instead of the default artifact one
133146
set_preferences!(
134147
joinpath(dirname(@__DIR__), "LocalPreferences.toml"),
135148
"Reactant_jll",

docs/src/api/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,10 @@ Reactant.Profiler.with_profiler
3333
Reactant.Profiler.annotate
3434
Reactant.Profiler.@annotate
3535
```
36+
37+
## Devices
38+
39+
```@docs
40+
Reactant.devices
41+
Reactant.addressable_devices
42+
```

src/Compiler.jl

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,13 @@ end
614614
@compile [optimize = ...] [no_nan = <true/false>] [sync = <true/false>] f(args...)
615615
"""
616616
macro compile(args...)
617-
default_options = Dict{Symbol,Any}(:optimize => true, :sync => false, :no_nan => false)
617+
default_options = Dict{Symbol,Any}(
618+
:optimize => true,
619+
:sync => false,
620+
:no_nan => false,
621+
:client => nothing,
622+
:device => nothing,
623+
)
618624
return esc(first(compile_call_expr(__module__, compile, default_options, args...)))
619625
end
620626

@@ -624,7 +630,13 @@ end
624630
Run @compile f(args..) then immediately execute it
625631
"""
626632
macro jit(args...)
627-
default_options = Dict{Symbol,Any}(:optimize => true, :sync => false, :no_nan => false)
633+
default_options = Dict{Symbol,Any}(
634+
:optimize => true,
635+
:sync => false,
636+
:no_nan => false,
637+
:client => nothing,
638+
:device => nothing,
639+
)
628640
compile_expr, (; compiled, args) = compile_call_expr(
629641
__module__, compile, default_options, args...
630642
)
@@ -955,7 +967,7 @@ function codegen_xla_call(exec, flatten_names, donated_args_mask, nresults)
955967
return concretized_res_names, xla_call_code
956968
end
957969

958-
function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
970+
function compile_xla(f, args; client=nothing, optimize=true, no_nan=false, device=nothing)
959971
# register MLIR dialects
960972
ctx = MLIR.IR.Context(Reactant.registry[], false)
961973
context_gc_vector[ctx] = Vector{TracedRArray}(undef, 0)
@@ -969,32 +981,69 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
969981
mod, f, args; optimize, no_nan
970982
)
971983

972-
if isnothing(client)
984+
# Resolve client and device
985+
device_ordinal = -1
986+
if device === nothing
973987
if length(linear_args) > 0
974-
for (k, _) in Iterators.filter(((_, v),) -> v isa TracedRArray, seen_args)
975-
client = XLA.client(k.data)
988+
devices_list = [
989+
XLA.device(k.data) for (k, v) in seen_args if v isa TracedRArray
990+
]
991+
if !isempty(devices_list)
992+
@assert allequal(devices_list) "All arguments must be on the same device: $(devices_list)"
993+
device = first(devices_list)
976994
end
977995
end
978-
if isnothing(client)
996+
end
997+
998+
if client === nothing
999+
if device !== nothing
1000+
client = XLA.client(device)
1001+
else
9791002
client = XLA.default_backend[]
1003+
device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
1004+
device_ordinal = XLA.default_device_idx[]
1005+
end
1006+
else
1007+
if device !== nothing
1008+
@assert client == XLA.client(device) "client ($(client)) and XLA.client(device) ($(XLA.client(device))) must be the same"
1009+
else
1010+
device = XLA.ClientGetDevice(client, XLA.default_device_idx[])
1011+
device_ordinal = XLA.default_device_idx[]
9801012
end
9811013
end
9821014

1015+
if device_ordinal < 0
1016+
device_ordinal = XLA.DeviceToClientDeviceOrdinal(device)
1017+
end
1018+
9831019
# compile MLIR module to XLA executable
984-
exec = XLA.Compile(client, mod)
985-
return exec,
986-
linear_args, linear_results, preserved_args, seen_args, concrete_result,
987-
isclosure
1020+
exec = XLA.Compile(
1021+
client,
1022+
mod;
1023+
device_ordinal,
1024+
num_replicas=1,
1025+
num_partitions=1,
1026+
use_shardy_partitioner=false,
1027+
)
1028+
return (
1029+
exec,
1030+
linear_args,
1031+
linear_results,
1032+
preserved_args,
1033+
seen_args,
1034+
concrete_result,
1035+
isclosure,
1036+
)
9881037
finally
9891038
MLIR.IR.deactivate!(ctx)
9901039
end
9911040
Base.delete!(context_gc_vector, ctx)
9921041
return results
9931042
end
9941043

995-
function compile(f, args; client=nothing, optimize=true, sync=false, no_nan=false)
1044+
function compile(f, args; sync=false, kwargs...)
9961045
exec, linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_xla(
997-
f, args; client, optimize, no_nan
1046+
f, args; kwargs...
9981047
)
9991048

10001049
preserved_args_idx = last.(preserved_args)

0 commit comments

Comments
 (0)