Skip to content

Commit

Permalink
add fixes to nested serialization to close #927
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Feb 29, 2024
1 parent 4aa0883 commit 41181d8
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 22 deletions.
23 changes: 13 additions & 10 deletions src/composition/learning_networks/replace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ the `model` and `args` field values as derived from the provided dictionaries. I
the returned machine is hooked into the new learning network defined by the values of
`newnode_given_old`.
If `serializable=true`, return a serializable copy instead (namely,
`serializable(node.mach)`) and ignore the `newmodel_given_old` dictionary (no model
replacement).
If `serializable=true`, return a serializable copy instead, but make no model replacement.
The `newmodel_given_old` dictionary is still used, but now to look up the concrete model
corresponding to the symbolic one stored in `node`'s machine.
See also [`serializable`](@ref).
Expand All @@ -26,9 +26,10 @@ function machine_replacement(
newnode_given_old,
serializable
)
# the `replace` called here is defined in src/machines.jl:
mach = serializable ? MLJBase.serializable(N.machine) :
replace(N.machine, :model => newmodel_given_old[N.machine.model])
# the `replace` called below is defined in src/machines.jl.
newmodel = newmodel_given_old[N.machine.model]
mach = serializable ? MLJBase.serializable(N.machine, newmodel) :
replace(N.machine, :model => newmodel)
mach.args = Tuple(newnode_given_old[arg] for arg in N.machine.args)
return mach
end
Expand Down Expand Up @@ -87,9 +88,11 @@ const DOC_REPLACE_OPTIONS =
- `copy_unspecified_deeply=true`: If `false`, models or sources not listed for
replacement are identically equal in the original and returned node.
- `serializable=false`: If `true`, all machines in the new network are serializable.
However, all `model` replacements are ignored, and unspecified sources are always
replaced with empty ones.
- `serializable=false`: If `true`, all machines in the new network are made
serializable and the specified model replacements are only used for serialization
purposes: for each pair `s => model` (`s` assumed to be a symbolic model) each
machine with model `s` is replaced with `serializable(mach, model)`. All unspecified
sources are always replaced with empty ones.
"""

Expand Down Expand Up @@ -193,7 +196,7 @@ function _replace(

# Instantiate model dictionary:
model_pairs = filter(collect(pairs)) do pair
first(pair) isa Model
first(pair) isa Model || first(pair) isa Symbol
end
models_ = models(W)
models_to_copy = setdiff(models_, first.(model_pairs))
Expand Down
25 changes: 20 additions & 5 deletions src/composition/models/network_composite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,33 @@ MLJModelInterface.fitted_params(composite::NetworkComposite, signature) =
MLJModelInterface.reporting_operations(::Type{<:NetworkComposite}) = OPERATIONS

# here `fitresult` has type `Signature`.
save(model::NetworkComposite, fitresult) = replace(fitresult, serializable=true)
function save(model::NetworkComposite, fitresult)
# The network includes machines with symbolic models. These machines need to be
# replaced by serializable versions, but we cannot naively use `serializable(mach)`,
# because the absence of the concrete model means this just returns `mach` (because
# `save(::Symbol, fitresult)` returns `fitresult`). We need to use the special
# `serialiable(mach, model)` instead. This is what `replace` below does, because we
# pass it the flag `serializable=true` but we must also pass `symbol =>
# concrete_model` replacements, which we calculate first:

greatest_lower_bound = MLJBase.glb(fitresult)
machines_given_model = MLJBase.machines_given_model(greatest_lower_bound)
atomic_models = keys(machines_given_model)
pairs = [atom => getproperty(model, atom) for atom in atomic_models]

replace(fitresult, pairs...; serializable=true)
end

function MLJModelInterface.restore(model::NetworkComposite, serializable_fitresult)
greatest_lower_bound = MLJBase.glb(serializable_fitresult)
machines_given_model = MLJBase.machines_given_model(greatest_lower_bound)
models = keys(machines_given_model)
atomic_models = keys(machines_given_model)

# the following indirectly mutates `serialiable_fiteresult`, returning it to
# usefulness:
for model in models
for mach in machines_given_model[model]
mach.fitresult = restore(model, mach.fitresult)
for atom in atomic_models
for mach in machines_given_model[atom]
mach.fitresult = MLJBase.restore(getproperty(model, atom), mach.fitresult)
mach.state = 1
end
end
Expand Down
8 changes: 4 additions & 4 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ A machine returned by `serializable` is characterized by the property
See also [`restore!`](@ref), [`MLJBase.save`](@ref).
"""
function serializable(mach::Machine{<:Any, C}; verbosity=1) where C
function serializable(mach::Machine{<:Any, C}, model=mach.model; verbosity=1) where C

isdefined(mach, :fitresult) || throw(ERR_SERIALIZING_UNTRAINED)
mach.state == -1 && return mach
Expand All @@ -985,7 +985,7 @@ function serializable(mach::Machine{<:Any, C}; verbosity=1) where C
# involves calls to `serializable` on the old machines in the network to create the
# new ones.

serializable_fitresult = save(mach.model, mach.fitresult)
serializable_fitresult = save(model, mach.fitresult)

# Duplication currenty needs to happen in two steps for this to work in case of
# `Composite` models.
Expand Down Expand Up @@ -1017,9 +1017,9 @@ useable form.
For an example see [`serializable`](@ref).
"""
function restore!(mach::Machine)
function restore!(mach::Machine, model=mach.model)
mach.state != -1 && return mach
mach.fitresult = restore(mach.model, mach.fitresult)
mach.fitresult = restore(model, mach.fitresult)
mach.state = 1
return mach
end
Expand Down
18 changes: 15 additions & 3 deletions test/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,16 @@ end
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
MLJBase.save(io, serializable(mach))
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test MLJBase.transform(mach2, X).x == v

# using `save`/`machine`:
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
@test MLJBase.transform(mach2, X).x == v
end

@testset "serialization for model with non-persistent fitresult in pipeline" begin
Expand All @@ -564,10 +570,16 @@ end
fit!(mach, verbosity=0)
v = MLJBase.transform(mach, X).x
io = IOBuffer()
MLJBase.save(io, serializable(mach))
serialize(io, serializable(mach))
seekstart(io)
mach2 = restore!(deserialize(io))
@test_broken MLJBase.transform(mach2, X).x == v
@test MLJBase.transform(mach2, X).x == v

# using `save`/`machine`:
MLJBase.save(io, mach)
seekstart(io)
mach2 = machine(io)
@test MLJBase.transform(mach2, X).x == v
end

struct ReportingDynamic <: Unsupervised end
Expand Down

0 comments on commit 41181d8

Please sign in to comment.