Skip to content

Commit

Permalink
add test to catch issue 927
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Feb 29, 2024
1 parent 831abfa commit 4aa0883
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/composition/learning_networks/replace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ end
newnode_given_old,
newmach_given_old,
newmodel_given_old,
serializable,
node::AbstractNode)
**Private method.**
Expand Down
6 changes: 3 additions & 3 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -977,14 +977,14 @@ See also [`restore!`](@ref), [`MLJBase.save`](@ref).
"""
function serializable(mach::Machine{<:Any, C}; verbosity=1) where C

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach

# The next line of code makes `serializable` recursive, in the case that `mach.model`
# is a `Composite` model: `save` duplicates the underlying learning network, which
# involves calls to `serializable` on the old machines in the network to create the
# new ones.

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach

serializable_fitresult = save(mach.model, mach.fitresult)

# Duplication currenty needs to happen in two steps for this to work in case of
Expand Down
62 changes: 62 additions & 0 deletions test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,68 @@ end
rm(filename)
end

# define a model with non-persistent fitresult:
thing = []
struct EphemeralTransformer <: Unsupervised end
function MLJModelInterface.fit(::EphemeralTransformer, verbosity, X)
view = pointer(thing)
fitresult = (thing, view)
return fitresult, nothing, NamedTuple()
end
function MLJModelInterface.transform(::EphemeralTransformer, fitresult, X)
thing, view = fitresult
return view == pointer(thing) ? X : throw(ErrorException("dead fitresult"))
end
function MLJModelInterface.save(::EphemeralTransformer, fitresult)
thing, _ = fitresult
return thing
end
function MLJModelInterface.restore(::EphemeralTransformer, serialized_fitresult)
view = pointer(thing)
return (thing, view)
end

# commented out code just tests the transformer above has desired properties for testing:

# # test model transforms:
# model = EphemeralTransformer()
# mach = machine(model, 42) |> fit!
# @test MLJBase.transform(mach, 27) == 27

# # direct serialization fails:
# io = IOBuffer()
# serialize(io, mach)
# seekstart(io)
# mach2 = deserialize(io)
# @test_throws ErrorException("dead fitresult") transform(mach2, 42)

@testset "serialization for model with non-persistent fitresult" begin
X = (; x=randn(5))
mach = machine(EphemeralTransformer(), X)
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
MLJBase.save(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.transform(mach2, X).x == v
end

@testset "serialization for model with non-persistent fitresult in pipeline" begin
# https://github.com/JuliaAI/MLJBase.jl/issues/927
X = (; x=randn(5))
pipe = Standardizer |> EphemeralTransformer
X = (; x=randn(5))
mach = machine(pipe, X)
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
MLJBase.save(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test_broken MLJBase.transform(mach2, X).x == v
end

struct ReportingDynamic <: Unsupervised end
MLJBase.fit(::ReportingDynamic, _, X) = nothing, 16, NamedTuple()
MLJBase.transform(::ReportingDynamic,_, X) = (X, (news=42,))
Expand Down

0 comments on commit 4aa0883

Please sign in to comment.