From 4aa08831d8c73ed377e0137e7016d62320c74ce1 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 1 Mar 2024 08:27:58 +1300 Subject: [PATCH] add test to catch issue 927 --- src/composition/learning_networks/replace.jl | 1 + src/machines.jl | 6 +- test/machines.jl | 62 ++++++++++++++++++++ 3 files changed, 66 insertions(+), 3 deletions(-) diff --git a/src/composition/learning_networks/replace.jl b/src/composition/learning_networks/replace.jl index 175c4b91..c9806269 100644 --- a/src/composition/learning_networks/replace.jl +++ b/src/composition/learning_networks/replace.jl @@ -38,6 +38,7 @@ end newnode_given_old, newmach_given_old, newmodel_given_old, + serializable, node::AbstractNode) **Private method.** diff --git a/src/machines.jl b/src/machines.jl index a1a3afc5..c194c8ae 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -977,14 +977,14 @@ See also [`restore!`](@ref), [`MLJBase.save`](@ref). """ function serializable(mach::Machine{<:Any, C}; verbosity=1) where C + isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) + mach.state == -1 && return mach + # The next line of code makes `serializable` recursive, in the case that `mach.model` # is a `Composite` model: `save` duplicates the underlying learning network, which # involves calls to `serializable` on the old machines in the network to create the # new ones. - isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) - mach.state == -1 && return mach - serializable_fitresult = save(mach.model, mach.fitresult) # Duplication currenty needs to happen in two steps for this to work in case of diff --git a/test/machines.jl b/test/machines.jl index 7d0845c2..11062b02 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -508,6 +508,68 @@ end rm(filename) end +# define a model with non-persistent fitresult: +thing = [] +struct EphemeralTransformer <: Unsupervised end +function MLJModelInterface.fit(::EphemeralTransformer, verbosity, X) + view = pointer(thing) + fitresult = (thing, view) + return fitresult, nothing, NamedTuple() +end +function MLJModelInterface.transform(::EphemeralTransformer, fitresult, X) + thing, view = fitresult + return view == pointer(thing) ? X : throw(ErrorException("dead fitresult")) +end +function MLJModelInterface.save(::EphemeralTransformer, fitresult) + thing, _ = fitresult + return thing +end +function MLJModelInterface.restore(::EphemeralTransformer, serialized_fitresult) + view = pointer(thing) + return (thing, view) +end + +# commented out code just tests the transformer above has desired properties for testing: + +# # test model transforms: +# model = EphemeralTransformer() +# mach = machine(model, 42) |> fit! +# @test MLJBase.transform(mach, 27) == 27 + +# # direct serialization fails: +# io = IOBuffer() +# serialize(io, mach) +# seekstart(io) +# mach2 = deserialize(io) +# @test_throws ErrorException("dead fitresult") transform(mach2, 42) + +@testset "serialization for model with non-persistent fitresult" begin + X = (; x=randn(5)) + mach = machine(EphemeralTransformer(), X) + fit!(mach, verbosity=0) + v = MLJBase.transform(mach, X).x + io = IOBuffer() + MLJBase.save(io, serializable(mach)) + seekstart(io) + mach2 = restore!(deserialize(io)) + @test MLJBase.transform(mach2, X).x == v +end + +@testset "serialization for model with non-persistent fitresult in pipeline" begin + # https://github.com/JuliaAI/MLJBase.jl/issues/927 + X = (; x=randn(5)) + pipe = Standardizer |> EphemeralTransformer + X = (; x=randn(5)) + mach = machine(pipe, X) + fit!(mach, verbosity=0) + v = MLJBase.transform(mach, X).x + io = IOBuffer() + MLJBase.save(io, serializable(mach)) + seekstart(io) + mach2 = restore!(deserialize(io)) + @test_broken MLJBase.transform(mach2, X).x == v +end + struct ReportingDynamic <: Unsupervised end MLJBase.fit(::ReportingDynamic, _, X) = nothing, 16, NamedTuple() MLJBase.transform(::ReportingDynamic,_, X) = (X, (news=42,))