Skip to content

Commit 0f2b278

Browse files
authored
fix: relax device types and inputs for sharding (#930)
1 parent dd9b467 commit 0f2b278

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/Sharding.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ function standardize_sharding(sharding::DimsSharding, x::Union{AbstractArray,Num
315315
end
316316

317317
function (sharding::DimsSharding)(
318-
client::XLA.AbstractClient, device::Nothing, x::Union{AbstractArray,Number}
318+
client::XLA.AbstractClient, device, x::Union{AbstractArray,Number}
319319
)
320320
return (standardize_sharding(sharding, x))(client, device, x)
321321
end

src/Types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ function ConcreteIFRTArray(
289289
arguments will be ignored."
290290
end
291291
end
292-
sharded_data, sharding = sharding(client, device, data)
292+
sharded_data, sharding = sharding(client, nothing, data)
293293
return ConcreteIFRTArray{T,N}(sharded_data, size(data), sharding)
294294
end
295295

0 commit comments

Comments
 (0)