From 41181d88bd2b71c3ea61ebe898354b1adca2c1d8 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 1 Mar 2024 10:53:28 +1300 Subject: [PATCH] add fixes to nested serialization to close #927 --- src/composition/learning_networks/replace.jl | 23 ++++++++++-------- src/composition/models/network_composite.jl | 25 ++++++++++++++++---- src/machines.jl | 8 +++---- test/machines.jl | 18 +++++++++++--- 4 files changed, 52 insertions(+), 22 deletions(-) diff --git a/src/composition/learning_networks/replace.jl b/src/composition/learning_networks/replace.jl index c9806269..08145262 100644 --- a/src/composition/learning_networks/replace.jl +++ b/src/composition/learning_networks/replace.jl @@ -13,9 +13,9 @@ the `model` and `args` field values as derived from the provided dictionaries. I the returned machine is hooked into the new learning network defined by the values of `newnode_given_old`. -If `serializable=true`, return a serializable copy instead (namely, -`serializable(node.mach)`) and ignore the `newmodel_given_old` dictionary (no model -replacement). +If `serializable=true`, return a serializable copy instead, but make no model replacement. +The `newmodel_given_old` dictionary is still used, but now to look up the concrete model +corresponding to the symbolic one stored in `node`'s machine. See also [`serializable`](@ref). @@ -26,9 +26,10 @@ function machine_replacement( newnode_given_old, serializable ) - # the `replace` called here is defined in src/machines.jl: - mach = serializable ? MLJBase.serializable(N.machine) : - replace(N.machine, :model => newmodel_given_old[N.machine.model]) + # the `replace` called below is defined in src/machines.jl. + newmodel = newmodel_given_old[N.machine.model] + mach = serializable ? 
MLJBase.serializable(N.machine, newmodel) : + replace(N.machine, :model => newmodel) mach.args = Tuple(newnode_given_old[arg] for arg in N.machine.args) return mach end @@ -87,9 +88,11 @@ const DOC_REPLACE_OPTIONS = - `copy_unspecified_deeply=true`: If `false`, models or sources not listed for replacement are identically equal in the original and returned node. - - `serializable=false`: If `true`, all machines in the new network are serializable. - However, all `model` replacements are ignored, and unspecified sources are always - replaced with empty ones. + - `serializable=false`: If `true`, all machines in the new network are made + serializable and the specified model replacements are only used for serialization + purposes: for each pair `s => model` (`s` assumed to be a symbolic model) each + machine with model `s` is replaced with `serializable(mach, model)`. All unspecified + sources are always replaced with empty ones. """ @@ -193,7 +196,7 @@ function _replace( # Instantiate model dictionary: model_pairs = filter(collect(pairs)) do pair - first(pair) isa Model + first(pair) isa Model || first(pair) isa Symbol end models_ = models(W) models_to_copy = setdiff(models_, first.(model_pairs)) diff --git a/src/composition/models/network_composite.jl b/src/composition/models/network_composite.jl index f5586950..dcec4f83 100644 --- a/src/composition/models/network_composite.jl +++ b/src/composition/models/network_composite.jl @@ -88,18 +88,33 @@ MLJModelInterface.fitted_params(composite::NetworkComposite, signature) = MLJModelInterface.reporting_operations(::Type{<:NetworkComposite}) = OPERATIONS # here `fitresult` has type `Signature`. -save(model::NetworkComposite, fitresult) = replace(fitresult, serializable=true) +function save(model::NetworkComposite, fitresult) + # The network includes machines with symbolic models. 
These machines need to be + # replaced by serializable versions, but we cannot naively use `serializable(mach)`, + # because the absence of the concrete model means this just returns `mach` (because + # `save(::Symbol, fitresult)` returns `fitresult`). We need to use the special + # `serializable(mach, model)` instead. This is what `replace` below does, because we + # pass it the flag `serializable=true` but we must also pass `symbol => + # concrete_model` replacements, which we calculate first: + + greatest_lower_bound = MLJBase.glb(fitresult) + machines_given_model = MLJBase.machines_given_model(greatest_lower_bound) + atomic_models = keys(machines_given_model) + pairs = [atom => getproperty(model, atom) for atom in atomic_models] + + replace(fitresult, pairs...; serializable=true) +end function MLJModelInterface.restore(model::NetworkComposite, serializable_fitresult) greatest_lower_bound = MLJBase.glb(serializable_fitresult) machines_given_model = MLJBase.machines_given_model(greatest_lower_bound) - models = keys(machines_given_model) + atomic_models = keys(machines_given_model) # the following indirectly mutates `serialiable_fiteresult`, returning it to # usefulness: - for model in models - for mach in machines_given_model[model] - mach.fitresult = restore(model, mach.fitresult) + for atom in atomic_models + for mach in machines_given_model[atom] + mach.fitresult = MLJBase.restore(getproperty(model, atom), mach.fitresult) mach.state = 1 end end diff --git a/src/machines.jl b/src/machines.jl index c194c8ae..5b2e0fd0 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -975,7 +975,7 @@ A machine returned by `serializable` is characterized by the property See also [`restore!`](@ref), [`MLJBase.save`](@ref). 
""" -function serializable(mach::Machine{<:Any, C}; verbosity=1) where C +function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED) mach.state == -1 && return mach @@ -985,7 +985,7 @@ function serializable(mach::Machine{<:Any, C}; verbosity=1) where C # involves calls to `serializable` on the old machines in the network to create the # new ones. - serializable_fitresult = save(mach.model, mach.fitresult) + serializable_fitresult = save(model, mach.fitresult) # Duplication currenty needs to happen in two steps for this to work in case of # `Composite` models. @@ -1017,9 +1017,9 @@ useable form. For an example see [`serializable`](@ref). """ -function restore!(mach::Machine) +function restore!(mach::Machine, model=mach.model) mach.state != -1 && return mach - mach.fitresult = restore(mach.model, mach.fitresult) + mach.fitresult = restore(model, mach.fitresult) mach.state = 1 return mach end diff --git a/test/machines.jl b/test/machines.jl index 11062b02..c78aa06d 100644 --- a/test/machines.jl +++ b/test/machines.jl @@ -549,10 +549,16 @@ end fit!(mach, verbosity=0) v = MLJBase.transform(mach, X).x io = IOBuffer() - MLJBase.save(io, serializable(mach)) + serialize(io, serializable(mach)) seekstart(io) mach2 = restore!(deserialize(io)) @test MLJBase.transform(mach2, X).x == v + + # using `save`/`machine`: + MLJBase.save(io, mach) + seekstart(io) + mach2 = machine(io) + @test MLJBase.transform(mach2, X).x == v end @testset "serialization for model with non-persistent fitresult in pipeline" begin @@ -564,10 +570,16 @@ end fit!(mach, verbosity=0) v = MLJBase.transform(mach, X).x io = IOBuffer() - MLJBase.save(io, serializable(mach)) + serialize(io, serializable(mach)) seekstart(io) mach2 = restore!(deserialize(io)) - @test_broken MLJBase.transform(mach2, X).x == v + @test MLJBase.transform(mach2, X).x == v + + # using `save`/`machine`: + MLJBase.save(io, mach) + seekstart(io) + 
mach2 = machine(io) + @test MLJBase.transform(mach2, X).x == v end struct ReportingDynamic <: Unsupervised end