diff --git a/Project.toml b/Project.toml index e97e3fdf45..a0973a86f3 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.2.46" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -64,6 +65,7 @@ Adapt = "4.1" ArrayInterface = "7.17.1" CEnum = "0.5" CUDA = "5.6" +DeepDiffs = "1.2" Downloads = "1.6" EnumX = "1" Enzyme = "0.13.28" diff --git a/src/Compiler.jl b/src/Compiler.jl index 04aeb39e86..9a4097093c 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -2,6 +2,7 @@ module Compiler using Reactant_jll using Libdl: dlsym +using DeepDiffs: deepdiff import ..Reactant: Reactant, @@ -1940,24 +1941,23 @@ XLA.cost_analysis(thunk::Thunk) = XLA.cost_analysis(thunk.exec) struct MisMatchedThunkTypeError{ThunkTy,FoundTypes} <: Base.Exception end function Base.showerror( - io::IO, ece::MisMatchedThunkTypeError{Thunk{FTy,tag,ArgTypes,IsClosure},FoundTypes} -) where {FTy,tag,ArgTypes,FoundTypes,IsClosure} + io::IO, + ::MisMatchedThunkTypeError{ + Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy},FoundTypes + }, +) where {FTy,tag,ArgTypes,FoundTypes,IsClosure,ExecTy,DeviceTy} print( io, - "\nThe Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure})` exists, but no method is defined for this combination of argument types.", + "\nThe Reactant-compiled function `$(Thunk{FTy, tag, ArgTypes, IsClosure, ExecTy, DeviceTy})` exists, but no method is defined for this combination of argument types.\n\nDiff between input argument types and compiled argument types:\n\n", ) - print( - io, - "\nYou passed in arguments with types\n\t(" * - join(FoundTypes.parameters, ", ") * - ")", - ) - return print( - io, - "\nHowever the method you are calling was compiled for arguments with types\n\t(" * - join(ArgTypes.parameters, ", ") * - ")", + + str = sprint( + show, + deepdiff(join(FoundTypes.parameters, ", "), join(ArgTypes.parameters, ", ")); + context=IOContext(io), ) + println(io, strip(str, '"')) + return nothing end @generated function (thunk::Thunk{FTy,tag,ArgTypes,IsClosure,ExecTy,DeviceTy})(