Skip to content

Commit e2932f9

Browse files
committed
Ensure GPU compatibility
1 parent 296e6fb commit e2932f9

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Adapt: Adapt, WrappedArray
1+
using Adapt: Adapt, WrappedArray, adapt
22
using ArrayLayouts: zero!
33
using BlockArrays:
44
BlockArrays,
@@ -350,3 +350,16 @@ function Base.replace_in_print_matrix(
350350
)
351351
return isstored(A, i, j) ? s : Base.replace_with_centered_mark(s)
352352
end
353+
354+
# attempt to catch things that wrap GPU arrays
355+
function Base.print_array(io::IO, X::AnyAbstractBlockSparseArray)
356+
X_cpu = adapt(Array, X)
357+
if typeof(X_cpu) === typeof(X) # prevent infinite recursion
358+
# need to specify ndims to allow specialized code for vector/matrix
359+
@allowscalar @invoke Base.print_array(
360+
io, X_cpu::AbstractArray{eltype(X_cpu),ndims(X_cpu)}
361+
)
362+
else
363+
Base.print_array(io, X_cpu)
364+
end
365+
end

test/test_basics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ arrayts = (Array, JLArray)
11241124
# spacing so it isn't easy to make the test general.
11251125
a = BlockSparseMatrix{elt,arrayt{elt,2}}([2, 2], [2, 2])
11261126
@allowscalar a[1, 2] = 12
1127-
@test @allowscalar(sprint(show, "text/plain", a)) ==
1127+
@test sprint(show, "text/plain", a) ==
11281128
"$(summary(a)):\n $(zero(eltype(a))) $(eltype(a)(12)) │ ⋅ ⋅ \n $(zero(eltype(a))) $(zero(eltype(a))) │ ⋅ ⋅ \n ───────────┼──────────\n ⋅ ⋅ │ ⋅ ⋅ \n ⋅ ⋅ │ ⋅ ⋅ "
11291129
end
11301130
end

0 commit comments

Comments
 (0)