Output sharding of structs with concrete arrays #1227

Open
DhairyaLGandhi opened this issue Apr 30, 2025 · 3 comments · May be fixed by #1275
Labels
bug Something isn't working

Comments

@DhairyaLGandhi

using Enzyme
using Reactant

struct MyModel{D}
    decoder::D
end

(m::MyModel)(x) = m.decoder * x

m = MyModel(Reactant.to_rarray(rand(128, 128)));

function loss(model, r)
    out = model(r)
    sum(out)
end

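# NOTE: `r` is not defined in this snippet; its definition (and the sharded
# version of `m`) is given in a later comment in this thread.
gr = @compile Enzyme.gradient(Reverse, Const(loss), m, r)
gr(Reverse, Const(loss), m, r)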

produces:

ERROR: TypeError: in new, expected ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}, got a value of type ConcretePJRTArray{Float64, 2, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, UnitRange{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}}
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Reactant/80BYw/src/Compiler.jl:2963 [inlined]
 [2] (::Reactant.Compiler.Thunk{…})(::ReverseMode{…}, ::Const{…}, ::MyModel{…}, ::ConcretePJRTArray{…})
   @ Reactant.Compiler ~/.julia/packages/Reactant/80BYw/src/Compiler.jl:3037
 [3] top-level scope
   @ ~/arpa/jsmo/fine_tuing/fine_tuning/reactant.jl:269
Some type information was truncated. Use `show(err)` to see complete types.

Similar error with IFRT as well:

(Assuming the XLA runtime is set to IFRT via Preferences)

ERROR: TypeError: in new, expected ConcreteIFRTArray{Float64, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}, Nothing}, got a value of type ConcreteIFRTArray{Float64, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, UnitRange{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}, Nothing}
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Reactant/80BYw/src/Compiler.jl:2963 [inlined]
 [2] (::Reactant.Compiler.Thunk{…})(::ReverseMode{…}, ::Const{…}, ::MyModel{…}, ::ConcreteIFRTArray{…})
   @ Reactant.Compiler ~/.julia/packages/Reactant/80BYw/src/Compiler.jl:3037
 [3] top-level scope
   @ ~/arpa/jsmo/fine_tuing/fine_tuning/reactant.jl:267
Some type information was truncated. Use `show(err)` to see complete types.
@DhairyaLGandhi
Author

Some relevant pieces of the lowered code that contain the type mismatch, in case that's helpful:

│   %50 = Reactant.Sharding.ShardInfo::Core.Const(Reactant.Sharding.ShardInfo)
│   %51 = (Reactant.Sharding.NamedSharding)(mesh#928::Core.PartialStruct(Reactant.Sharding.Mesh{2, UnitRange{Int64}}, Any[Vector{Int64}, Core.Const(0:1), Core.Const((:dev1, :dev2)), Core.Const((1, 2))]), Vector{Union{Nothing, Symbol}}[[:dev2], [nothing]], (true, true), (-1, -1), Vector{Union{Nothing, Tuple{Int64, Int64}}}[[nothing], [nothing]])::Core.PartialStruct(Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, UnitRange{Int64}}}, Any[Core.PartialStruct(Reactant.Sharding.Mesh{2, UnitRange{Int64}}, Any[Vector{Int64}, Core.Const(0:1), Core.Const((:dev1, :dev2)), Core.Const((1, 2))]), Core.Const(Vector{Union{Nothing, Symbol}}[[:dev2], [nothing]]), Core.Const((true, true)), Core.Const((-1, -1)), Core.Const(Vector{Union{Nothing, Tuple{Int64, Int64}}}[[nothing], [nothing]])])
│         (shard_info_2 = (%50)(%51, Tuple{UnitRange{Int64}, UnitRange{Int64}}[(1:64, 1:3), (65:128, 1:3)]))
│   %53 = Core.apply_type(Reactant.Compiler.ConcreteIFRTArray, Float64, 2)::Core.Const(ConcreteIFRTArray{Float64, 2, S, P} where {S<:Reactant.Sharding.ShardInfo, P<:Union{Nothing, Tuple{Int64, Int64}}})
│   %54 = result_buffer_m2_1::Reactant.XLA.IFRT.AsyncArray
│   %55 = (%53)(%54, (128, 128), shard_info_1::Core.PartialStruct(Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, UnitRange{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}, Any[Core.PartialStruct(Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, UnitRange{Int64}}}, Any[Core.PartialStruct(Reactant.Sharding.Mesh{2, UnitRange{Int64}}, Any[Vector{Int64}, Core.Const(0:1), Core.Const((:dev1, :dev2)), Core.Const((1, 2))]), Core.Const(Vector{Union{Nothing, Symbol}}[[nothing], [:dev1]]), Core.Const((true, true)), Core.Const((-1, -1)), Core.Const(Vector{Union{Nothing, Tuple{Int64, Int64}}}[[nothing], [nothing]])]), Core.Const(Tuple{UnitRange{Int64}, UnitRange{Int64}}[(1:128, 1:128), (1:128, 1:128)])]))::ConcreteIFRTArray{Float64, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NamedSharding{2, Reactant.Sharding.Mesh{2, UnitRange{Int64}}}, Vector{Tuple{UnitRange{Int64}, UnitRange{Int64}}}}, Nothing}
│         %new(MyModel{ConcreteIFRTArray{Float64, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}, Nothing}}, %55)
│         Core.Const(:(Core.apply_type(Reactant.Compiler.ConcreteIFRTArray, Float64, 2)))
│         Core.Const(:(result_buffer_m2_2))
│         Core.Const(:((%57)(%58, (128, 3), shard_info_2)))
│         Core.Const(:(result = Core.tuple(%56, %59)))
│         Core.Const(:(Base.getindex(args, 4)))
│         Core.Const(:(result_buffer_m2_3))
│         Core.Const(:(Reactant.Compiler.traced_setfield!(%61, :data, %62, ())))
└──       Core.Const(:(return result))
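
For readers unfamiliar with this class of error, here is a minimal, Reactant-free sketch of what the `%new(MyModel{...NoSharding...}, %55)` line appears to be doing: a bare `new` is emitted for a struct type whose parameter was fixed ahead of time, while the field value has a different concrete type. Everything below is an illustration under assumptions, not Reactant's actual code: `raw_new` is a hypothetical helper, and `Vector{Float64}` / `Matrix{Float64}` merely stand in for the `NoSharding` and `NamedSharding` array types from the report.

```julia
# Stand-in for the struct in the report; no Reactant required.
struct MyModel{D}
    decoder::D
end

# Hypothetical helper: emits a bare `new` expression, which skips the
# `convert` calls that ordinary constructors insert. This mimics the
# `%new(...)` call seen in the lowered code.
@generated function raw_new(::Type{T}, args...) where {T}
    ex = Expr(:new, :T)
    for i in 1:length(args)
        push!(ex.args, :(args[$i]))
    end
    return ex
end

unsharded_like = rand(4)      # plays the role of the NoSharding array type
sharded_like   = rand(4, 4)   # plays the role of the NamedSharding array type

# Struct type fixed from one array type, field value of another: `new` throws.
err = try
    raw_new(MyModel{typeof(unsharded_like)}, sharded_like)
    nothing
catch e
    e
end
@show err isa TypeError

# Struct type derived from the value actually being stored: `new` succeeds.
ok = raw_new(MyModel{typeof(sharded_like)}, sharded_like)
@show typeof(ok)
```

If this reading is right, the struct type used for the output would need to be computed from the output array's actual sharding rather than reused from another type, but that is only an inference from the lowered code above, not a description of the linked fix.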

@avik-pal
Collaborator

How is r defined above?

@DhairyaLGandhi
Author

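# NOTE: `devices` is not defined in the original comment; it is assumed to be a collection of two devices.
mesh = Reactant.Sharding.Mesh(reshape(devices, 1, 2), (:dev1, :dev2))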
sharding = Reactant.Sharding.NamedSharding(mesh, (:dev1, :dev2))

m = MyModel(Reactant.to_rarray(rand(128, 128); sharding = sharding));
r = Reactant.to_rarray(rand(128, 3); sharding = sharding)

@avik-pal linked a pull request on May 12, 2025 that will close this issue
@avik-pal changed the title from "Missing Sharding information with Enzyme.gradient" to "Output sharding of structs with concrete arrays" on May 12, 2025
@avik-pal added the bug label on May 12, 2025