Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a corner case of flat_params #182

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.9.2"
version = "1.9.3"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
49 changes: 33 additions & 16 deletions src/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ values, which themselves might be transparent.
Most objects of type `MLJType` are transparent.

```julia
julia> params(EnsembleModel(atom=ConstantClassifier()))
(atom = (target_type = Bool,),
julia> params(EnsembleModel(model=ConstantClassifier()))
(model = (target_type = Bool,),
weights = Float64[],
bagging_fraction = 0.8,
rng_seed = 0,
Expand All @@ -36,25 +36,42 @@ isnotaleaf(m::Model) = length(propertynames(m)) > 0
"""
flat_params(m::Model)

Recursively convert any object subtyping `Model` into a named tuple, keyed on
the property names of `m`. The named tuple is possibly nested because
`flat_params` is recursively applied to the property values, which themselves
might subtype `Model`.
Deconstruct any `Model` instance `model` as a flat named tuple, keyed on property
names. Properties of nested model instances are recursively exposed,.as shown in the
example below. For most `Model` objects, properties are synonymous with fields, but this
is not a hard requirement.

For most `Model` objects, properties are synonymous with fields, but this is
not a hard requirement.
```julia
using MLJModels
using EnsembleModels
tree = (@load DecisionTreeClassifier pkg=DecisionTree)

julia> flat_params(EnsembleModel(model=tree))
(model__max_depth = -1,
model__min_samples_leaf = 1,
model__min_samples_split = 2,
model__min_purity_increase = 0.0,
model__n_subfeatures = 0,
model__post_prune = false,
model__merge_purity_threshold = 1.0,
model__display_depth = 5,
model__feature_importance = :impurity,
model__rng = Random._GLOBAL_RNG(),
atomic_weights = Float64[],
bagging_fraction = 0.8,
rng = Random._GLOBAL_RNG(),
n = 100,
acceleration = CPU1{Nothing}(nothing),
out_of_bag_measure = Any[],)
```

julia> flat_params(EnsembleModel(atom=ConstantClassifier()))
(atom = (target_type = Bool,),
weights = Float64[],
bagging_fraction = 0.8,
rng_seed = 0,
n = 100,
parallel = true,)

"""
flat_params(m; prefix="") = flat_params(m, Val(isnotaleaf(m)); prefix=prefix)
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
function flat_params(m, ::Val{false}; prefix="")
prefix == "" && return NamedTuple()
NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
end
function flat_params(m, ::Val{true}; prefix="")
fields = propertynames(m)
prefix = prefix == "" ? "" : prefix * "__"
Expand Down
4 changes: 4 additions & 0 deletions test/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ end

struct Missy <: Model end

struct EmptyModel <: Model end

@testset "flat_params method" begin

m = ParentModel(1, "parent", ChildModel(2, "child1"),
Expand All @@ -61,5 +63,7 @@ struct Missy <: Model end
second_child__r = 3,
second_child__s = Missy()
)

@test MLJModelInterface.flat_params(EmptyModel()) == NamedTuple()
end
true